11from __future__ import annotations
22
3+ from collections .abc import Mapping
34import dataclasses
45import math
5- from typing import Callable
6+ from typing import Callable , Literal , TypedDict
67
78import ibis
8- from ibis import Deferred
99from ibis .expr import types as ir
1010
1111from mismo import _util , linkage , types
@@ -57,6 +57,61 @@ def haversine(theta: ir.FloatingValue) -> ir.FloatingValue:
5757 return (R_earth * 2 ) * a .sqrt ().asin ()
5858
5959
# Already-resolved coordinates: a mapping with "lat" and "lon" keys whose
# values are floating-point ibis columns.
Coordinates = Mapping[Literal["lat", "lon"], ir.FloatingColumn]
# Unresolved coordinates: a mapping with "lat" and "lon" keys whose values are
# `_util.IntoValue`s. `get_coordinate_pair` binds each value to a table with
# `_util.bind_one`.
CoordinatesMapping = Mapping[Literal["lat", "lon"], _util.IntoValue]
62+
63+
class CoordinatesDict(TypedDict):
    """A resolved latitude/longitude pair, as returned by `get_coordinate_pair`."""

    # The latitude column.
    lat: ir.FloatingColumn
    # The longitude column.
    lon: ir.FloatingColumn
67+
68+
# Anything `get_coordinate_pair` accepts as a coordinate specification:
# - a `str` or `ibis.Deferred`, which is bound to the table and then indexed
#   with "lat" and "lon",
# - a mapping with "lat" and "lon" keys (already-resolved columns, or values
#   that still need to be bound to the table),
# - a callable that takes the table and returns any of the above.
# The callable's return type is written as a string because the alias refers
# to itself.
IntoCoordinates = (
    str
    | ibis.Deferred
    | Coordinates
    | CoordinatesMapping
    | Callable[[ir.Table], "IntoCoordinates"]
)
76+
77+
def default_resolver() -> CoordinatesMapping:
    """Return the default coordinate specification.

    The result maps the "lat" key to the column name "lat" and the "lon" key
    to the column name "lon". A new dict is built on every call, which makes
    this function usable as a dataclass ``default_factory``.
    """
    return dict(lat="lat", lon="lon")
80+
81+
def get_coordinate_pair(
    table: ir.Table,
    mapping: IntoCoordinates,
) -> CoordinatesDict:
    """Get the coordinates from a table.

    Parameters
    ----------
    table
        The table to get the coordinates from.
    mapping
        How to find the coordinates in ``table``. One of:

        - a ``str`` or ``ibis.Deferred``, which is bound to ``table`` and then
          indexed with ``"lat"`` and ``"lon"``,
        - a mapping with ``"lat"`` and ``"lon"`` keys, whose values are each
          bound to ``table``,
        - a callable that takes ``table`` and returns any of the above
          (including another callable, which is resolved recursively).

    Returns
    -------
    coordinates
        A ``CoordinatesDict`` holding the resolved ``lat`` and ``lon``.
    """
    # The str/Deferred check must come BEFORE the callable() check:
    # ibis.Deferred objects are themselves callable (calling one builds a new
    # deferred call expression), so testing callable() first would "call" the
    # Deferred, get another Deferred back, and recurse without end.
    if isinstance(mapping, (str, ibis.Deferred)):
        coords_col = _util.bind_one(table, mapping)
        lat: ir.FloatingColumn = coords_col["lat"]  # ty:ignore[not-subscriptable]
        lon: ir.FloatingColumn = coords_col["lon"]  # ty:ignore[not-subscriptable]
    elif callable(mapping):
        # A user-supplied resolver: call it, then resolve whatever it returned.
        return get_coordinate_pair(table, mapping(table))
    else:
        lat: ir.FloatingColumn = _util.bind_one(table, mapping["lat"])  # ty:ignore[invalid-assignment]
        lon: ir.FloatingColumn = _util.bind_one(table, mapping["lon"])  # ty:ignore[invalid-assignment]
    return CoordinatesDict(lat=lat, lon=lon)
113+
114+
60115@dataclasses .dataclass (frozen = True )
61116class CoordinateLinker :
62117 """Links two locations together if they are within a certain distance.
@@ -108,9 +163,8 @@ class CoordinateLinker:
108163 ... )
109164 >>> linker = CoordinateLinker(
110165 ... distance_km=1,
111- ... left_coord="latlon",
112- ... right_lat="latitude",
113- ... right_lon="longitude",
166+ ... left_resolver="latlon",
167+ ... right_resolver={"lat": "latitude", "lon": "longitude"},
114168 ... )
115169 >>> linker(left, right).links
116170 ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
@@ -130,127 +184,66 @@ class CoordinateLinker:
130184 """
131185 The (approx) max distance in kilometers that two coords will be blocked together.
132186 """
133- coord : str | Deferred | Callable [[ir .Table ], ir .StructColumn ] | None = None
134- """The column in both tables containing the `struct<lat: float, lon: float>` coordinates.""" # noqa: E501
135- lat : str | Deferred | Callable [[ir .Table ], ir .FloatingColumn ] | None = None
136- """The column in both tables containing the latitude coordinates."""
137- lon : str | Deferred | Callable [[ir .Table ], ir .FloatingColumn ] | None = None
138- """The column in both tables containing the longitude coordinates."""
139- left_coord : str | Deferred | Callable [[ir .Table ], ir .StructColumn ] | None = None
140- """The column in the left tables containing the `struct<lat: float, lon: float>` coordinates.""" # noqa: E501
141- right_coord : str | Deferred | Callable [[ir .Table ], ir .StructColumn ] | None = None
142- """The column in the right tables containing the `struct<lat: float, lon: float>` coordinates.""" # noqa: E501
143- left_lat : str | Deferred | Callable [[ir .Table ], ir .FloatingColumn ] | None = None
144- """The column in the left tables containing the latitude coordinates."""
145- left_lon : str | Deferred | Callable [[ir .Table ], ir .FloatingColumn ] | None = None
146- """The column in the left tables containing the longitude coordinates."""
147- right_lat : str | Deferred | Callable [[ir .Table ], ir .FloatingColumn ] | None = None
148- """The column in the right tables containing the latitude coordinates."""
149- right_lon : str | Deferred | Callable [[ir .Table ], ir .FloatingColumn ] | None = None
150- """The column in the right tables containing the longitude coordinates."""
151187 max_pairs : int | None = None
152188 """The maximum number of pairs that any single block of coordinates can contain.
153189
154190 eg if you have 1000 records all with the same coordinates, this would
155191 naively result in ~(1000 * 1000) / 2 = 500_000 pairs.
156192 If we set max_pairs to less than this, this group of records will be skipped.
157193 """
158-
159- def __post_init__ (self ):
160- ok_subsets = [
161- {"coord" },
162- {"left_coord" , "right_coord" },
163- {"left_coord" , "right_lat" , "right_lon" },
164- {"left_lat" , "left_lon" , "right_coord" },
165- {"left_lat" , "left_lon" , "right_lat" , "right_lon" },
166- {"lat" , "lon" },
167- {"lat" , "right_lat" , "right_lon" },
168- {"left_lat" , "left_lon" , "lon" },
169- ]
170- options = [
171- "coord" ,
172- "left_coord" ,
173- "right_coord" ,
174- "lat" ,
175- "lon" ,
176- "left_lat" ,
177- "left_lon" ,
178- "right_lat" ,
179- "right_lon" ,
180- ]
181- present = {k for k in options if getattr (self , k ) is not None }
182- if present not in ok_subsets :
183- ok_subsets_str = "\n " .join ("- " + str (s ) for s in ok_subsets )
184- raise ValueError (
185- "You must specify exactly one of the following subsets of options:\n "
186- + ok_subsets_str
187- + f"\n You provided:\n { present } "
188- )
189-
190- def _left_coord (
191- self , left : ir .Table
192- ) -> tuple [ir .FloatingColumn , ir .FloatingColumn ]:
193- if self .coord is not None :
194- left_coord = _util .get_column (left , self .coord , on_many = "struct" )
195- return (left_coord .lat , left_coord .lon )
196- if self .left_coord is not None :
197- left_coord = _util .get_column (left , self .left_coord , on_many = "struct" )
198- left_lat = left_coord .lat
199- left_lon = left_coord .lon
200- if self .lat is not None :
201- left_lat = _util .get_column (left , self .lat )
202- if self .lon is not None :
203- left_lon = _util .get_column (left , self .lon )
204- if self .left_lat is not None :
205- left_lat = _util .get_column (left , self .left_lat )
206- if self .left_lon is not None :
207- left_lon = _util .get_column (left , self .left_lon )
208- return left_lat , left_lon
209-
210- def _right_coord (
211- self , right : ir .Table
212- ) -> tuple [ir .FloatingColumn , ir .FloatingColumn ]:
213- if self .coord is not None :
214- right_coord = _util .get_column (right , self .coord , on_many = "struct" )
215- return (right_coord .lat , right_coord .lon )
216- if self .right_coord is not None :
217- right_coord = _util .get_column (right , self .right_coord , on_many = "struct" )
218- right_lat = right_coord .lat
219- right_lon = right_coord .lon
220- if self .lat is not None :
221- right_lat = _util .get_column (right , self .lat )
222- if self .lon is not None :
223- right_lon = _util .get_column (right , self .lon )
224- if self .right_lat is not None :
225- right_lat = _util .get_column (right , self .right_lat )
226- if self .right_lon is not None :
227- right_lon = _util .get_column (right , self .right_lon )
228- return right_lat , right_lon
229-
230- def _join_keys (
231- self , left : ir .Table | ibis .Deferred , right : ir .Table | ibis .Deferred
232- ) -> tuple [tuple [ir .Column , ir .Column ], tuple [ir .Column , ir .Column ]]:
233- left_lat , left_lon = self ._left_coord (left )
234- right_lat , right_lon = self ._right_coord (right )
194+ left_resolver : IntoCoordinates = dataclasses .field (default_factory = default_resolver )
195+ """A specification of how to get the lat and lon values from the left table.
196+
197+ Can be:
198+ - A `str` or `ibis.Deferred`, which are assumed to point to a column
199+ containing a struct with `lat` and `lon` fields.
200+ - A `Mapping` of `{"lat": ..., "lon": ...}` where the values are
201+ column names or `ibis.Deferred` expressions.
202+ - A callable that takes a table and returns one of the above.
203+ """
204+ right_resolver : IntoCoordinates = dataclasses .field (
205+ default_factory = default_resolver
206+ )
207+ """See `left_resolver`, but for the right table."""
208+
209+ def hash_coord (
210+ self , coord : CoordinatesDict , /
211+ ) -> tuple [ir .IntegerValue , ir .IntegerValue ]:
212+ lat , lon = coord ["lat" ], coord ["lon" ]
235213 # We have to use a grid size of ~3x the precision to avoid
236214 # two points falling right on either side of a grid cell boundary
237215 grid_size = self .distance_km * 3
238- left_lat_key , left_lon_key = _bin_lat_lon (left_lat , left_lon , grid_size )
239- right_lat_key , right_lon_key = _bin_lat_lon (right_lat , right_lon , grid_size )
240- return (
241- left_lat_key .name ("lat_binned" ),
242- right_lat_key .name ("lat_binned" ),
243- ), (
244- left_lon_key .name ("lon_binned" ),
245- right_lon_key .name ("lon_binned" ),
246- )
216+ lat_key , lon_key = _bin_lat_lon (lat , lon , grid_size )
217+ return lat_key , lon_key
218+
219+ def hash_left (self , t : ir .Table , / ) -> tuple [ir .IntegerValue , ir .IntegerValue ]:
220+ coords = get_coordinate_pair (t , self .left_resolver )
221+ return self .hash_coord (coords )
222+
223+ def hash_right (self , t : ir .Table , / ) -> tuple [ir .IntegerValue , ir .IntegerValue ]:
224+ coords = get_coordinate_pair (t , self .right_resolver )
225+ return self .hash_coord (coords )
247226
248227 @property
249228 def _key_linker (self ) -> KeyLinker :
250229 import mismo
251230
252- lat_key , lon_key = self ._join_keys (ibis ._ , ibis ._ )
253- return mismo .KeyLinker ([lat_key , lon_key ], max_pairs = self .max_pairs )
231+ def lat_key_left (t : ir .Table ) -> ir .IntegerValue :
232+ return self .hash_left (t )[0 ]
233+
234+ def lon_key_left (t : ir .Table ) -> ir .IntegerValue :
235+ return self .hash_left (t )[1 ]
236+
237+ def lat_key_right (t : ir .Table ) -> ir .IntegerValue :
238+ return self .hash_right (t )[0 ]
239+
240+ def lon_key_right (t : ir .Table ) -> ir .IntegerValue :
241+ return self .hash_right (t )[1 ]
242+
243+ return mismo .KeyLinker (
244+ [(lat_key_left , lat_key_right ), (lon_key_left , lon_key_right )],
245+ max_pairs = self .max_pairs ,
246+ )
254247
255248 def __join_condition__ (self , left : ir .Table , right : ir .Table ) -> ir .BooleanValue :
256249 return self ._key_linker .__join_condition__ (left , right )
0 commit comments