Skip to content

Commit 0378d34

Browse files
committed
feat: overhaul many things
1 parent c12b7f6 commit 0378d34

File tree

21 files changed

+344
-267
lines changed

21 files changed

+344
-267
lines changed

mismo/_counts_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ibis import _
88
from ibis.expr import types as ir
99

10-
from mismo.types._table_wrapper import TableWrapper
10+
from mismo.types._wrapper import TableWrapper
1111

1212
if TYPE_CHECKING:
1313
import altair as alt

mismo/_util.py

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,28 @@ def cases(
8989
return builder.else_(else_).end()
9090

9191

92-
@overload
93-
def bind(t: ibis.Table, ref: Any, /) -> tuple[ibis.Value, ...]: ...
94-
@overload
95-
def bind(t: ibis.Deferred, ref: Any, /) -> tuple[ibis.Deferred]: ...
92+
IntoValue = str | int | ibis.Deferred | ibis.Value | Callable[[ibis.Table], ibis.Value]
9693

9794

98-
def bind(t: ibis.Deferred | ibis.Table, ref: Any) -> tuple[ibis.Value, ...]:
95+
@overload
96+
def bind(
97+
t: ibis.Table,
98+
ref: IntoValue | Iterable[IntoValue] | Mapping[str, IntoValue],
99+
/,
100+
) -> tuple[ibis.Value, ...]: ...
101+
@overload
102+
def bind(
103+
t: ibis.Deferred,
104+
ref: IntoValue | Iterable[IntoValue] | Mapping[str, IntoValue],
105+
/,
106+
) -> tuple[ibis.Deferred]: ...
107+
108+
109+
def bind(
110+
t: ibis.Deferred | ibis.Table,
111+
ref: IntoValue | Iterable[IntoValue] | Mapping[str, IntoValue],
112+
/,
113+
) -> tuple[ibis.Value, ...]:
99114
"""Reference into a table to get Columns and Scalars.
100115
101116
ibis._.bind(ref) does not work because it returns another Deferred.
@@ -110,49 +125,21 @@ def bind(t: ibis.Deferred | ibis.Table, ref: Any) -> tuple[ibis.Value, ...]:
110125

111126

112127
@overload
113-
def bind_one(t: ibis.Table, ref: Any, /) -> ibis.Value: ...
128+
def bind_one(t: ibis.Table, ref: IntoValue, /) -> ibis.Value: ...
114129
@overload
115-
def bind_one(t: ibis.Deferred, ref: Any, /) -> ibis.Deferred: ...
130+
def bind_one(t: ibis.Deferred, ref: IntoValue, /) -> ibis.Deferred: ...
116131

117132

118-
def bind_one(t: ibis.Deferred | ibis.Table, ref: Any) -> ibis.Value | ibis.Deferred:
133+
def bind_one(
134+
t: ibis.Deferred | ibis.Table, ref: IntoValue, /
135+
) -> ibis.Value | ibis.Deferred:
119136
"""Like bind(), but ensure that exactly one value is returned."""
120137
vals = bind(t, ref)
121138
if len(vals) != 1:
122139
raise ValueError(f"Expected 1 value, got {len(vals)} from {ref}")
123140
return vals[0]
124141

125142

126-
def get_column(
127-
t: ir.Table, ref: Any, *, on_many: Literal["error", "struct"] = "error"
128-
) -> ir.Column:
129-
"""Get a column from a table using some sort of reference to the column.
130-
131-
ref can be a string, a Deferred, a callable, an ibis selector, etc.
132-
133-
Parameters
134-
----------
135-
t :
136-
The table
137-
ref :
138-
The reference to the column
139-
on_many :
140-
What to do if ref returns multiple columns. If "error", raise an error.
141-
If "struct", return a StructColumn containing all the columns.
142-
"""
143-
cols = bind(t, ref)
144-
if isinstance(t, ibis.Deferred):
145-
# This is by definition a single column
146-
return cols[0]
147-
if len(cols) != 1:
148-
if on_many == "error":
149-
raise ValueError(f"Expected 1 column, got {len(cols)}")
150-
if on_many == "struct":
151-
return ibis.struct({c.get_name(): c for c in cols})
152-
raise ValueError(f"on_many must be 'error' or 'struct'. Got {on_many}")
153-
return cols[0]
154-
155-
156143
def ensure_ibis(
157144
val: Any, type: str | dt.DataType | None = None
158145
) -> ir.Value | ibis.Deferred:

mismo/arrays/_array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def array_filter_isin_other(
110110
The table with a new column named following `result_format` with the
111111
filtered array.
112112
""" # noqa E501
113-
array_col = _util.get_column(t, array)
114-
t = t.mutate(__array=array_col, __id=ibis.row_number())
113+
array_val = _util.bind_one(t, array)
114+
t = t.mutate(__array=array_val, __id=ibis.row_number())
115115
temp = t.select("__id", __unnested=_.__array.unnest())
116116
# When we re-.collect() items below, the order matters,
117117
# but the .filter() can mess up the order, so we need to
@@ -131,7 +131,7 @@ def array_filter_isin_other(
131131
[], _.__filtered
132132
)
133133
).drop("__array")
134-
result_name = result_format.format(name=array_col.get_name())
134+
result_name = result_format.format(name=array_val.get_name())
135135
return re_joined.rename({result_name: "__filtered"})
136136

137137

mismo/lib/email/_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66
from ibis.expr import types as ir
77

8-
from mismo._util import cases, get_column
8+
from mismo._util import bind_one, cases
99
from mismo.arrays import array_combinations, array_min
1010
from mismo.compare import MatchLevel
1111
from mismo.linker import UnnestLinker
@@ -164,7 +164,7 @@ def __init__(
164164
def prepare_for_fast_linking(self, t: ir.Table) -> ir.Table:
165165
"""Add a column with the parsed and normalized email addresses."""
166166
return t.mutate(
167-
get_column(t, self.column)
167+
bind_one(t, self.column)
168168
.map(
169169
lambda email: ParsedEmail(
170170
clean_email(email, normalize=True)

mismo/lib/geo/_latlon.py

Lines changed: 104 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from __future__ import annotations
22

3+
from collections.abc import Mapping
34
import dataclasses
45
import math
5-
from typing import Callable
6+
from typing import Callable, Literal, TypedDict
67

78
import ibis
8-
from ibis import Deferred
99
from ibis.expr import types as ir
1010

1111
from mismo import _util, linkage, types
@@ -57,6 +57,61 @@ def haversine(theta: ir.FloatingValue) -> ir.FloatingValue:
5757
return (R_earth * 2) * a.sqrt().asin()
5858

5959

60+
Coordinates = Mapping[Literal["lat", "lon"], ir.FloatingColumn]
61+
CoordinatesMapping = Mapping[Literal["lat", "lon"], _util.IntoValue]
62+
63+
64+
class CoordinatesDict(TypedDict):
65+
lat: ir.FloatingColumn
66+
lon: ir.FloatingColumn
67+
68+
69+
IntoCoordinates = (
70+
str
71+
| ibis.Deferred
72+
| Coordinates
73+
| CoordinatesMapping
74+
| Callable[[ir.Table], "IntoCoordinates"]
75+
)
76+
77+
78+
def default_resolver() -> CoordinatesMapping:
79+
return {"lat": "lat", "lon": "lon"}
80+
81+
82+
def get_coordinate_pair(
83+
table: ir.Table,
84+
mapping: IntoCoordinates,
85+
) -> CoordinatesDict:
86+
"""Get the coordinates from a table.
87+
88+
Parameters
89+
----------
90+
table
91+
The table to get the coordinates from.
92+
mapping
93+
A mapping of column names to use for the coordinates,
94+
or a function that takes a table and returns the coordinate pair.
95+
96+
Returns
97+
-------
98+
coordinates
99+
The coordinates.
100+
"""
101+
if callable(mapping):
102+
called = mapping(table)
103+
return get_coordinate_pair(table, called)
104+
105+
if isinstance(mapping, str) or isinstance(mapping, ibis.Deferred):
106+
coords_col = _util.bind_one(table, mapping)
107+
lat: ir.FloatingColumn = coords_col["lat"] # ty:ignore[not-subscriptable]
108+
lon: ir.FloatingColumn = coords_col["lon"] # ty:ignore[not-subscriptable]
109+
else:
110+
lat: ir.FloatingColumn = _util.bind_one(table, mapping["lat"]) # ty:ignore[invalid-assignment]
111+
lon: ir.FloatingColumn = _util.bind_one(table, mapping["lon"]) # ty:ignore[invalid-assignment]
112+
return CoordinatesDict(lat=lat, lon=lon)
113+
114+
60115
@dataclasses.dataclass(frozen=True)
61116
class CoordinateLinker:
62117
"""Links two locations together if they are within a certain distance.
@@ -108,9 +163,8 @@ class CoordinateLinker:
108163
... )
109164
>>> linker = CoordinateLinker(
110165
... distance_km=1,
111-
... left_coord="latlon",
112-
... right_lat="latitude",
113-
... right_lon="longitude",
166+
... left_resolver="latlon",
167+
... right_resolver={"lat": "latitude", "lon": "longitude"},
114168
... )
115169
>>> linker(left, right).links
116170
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
@@ -130,127 +184,66 @@ class CoordinateLinker:
130184
"""
131185
The (approx) max distance in kilometers that two coords will be blocked together.
132186
"""
133-
coord: str | Deferred | Callable[[ir.Table], ir.StructColumn] | None = None
134-
"""The column in both tables containing the `struct<lat: float, lon: float>` coordinates.""" # noqa: E501
135-
lat: str | Deferred | Callable[[ir.Table], ir.FloatingColumn] | None = None
136-
"""The column in both tables containing the latitude coordinates."""
137-
lon: str | Deferred | Callable[[ir.Table], ir.FloatingColumn] | None = None
138-
"""The column in both tables containing the longitude coordinates."""
139-
left_coord: str | Deferred | Callable[[ir.Table], ir.StructColumn] | None = None
140-
"""The column in the left tables containing the `struct<lat: float, lon: float>` coordinates.""" # noqa: E501
141-
right_coord: str | Deferred | Callable[[ir.Table], ir.StructColumn] | None = None
142-
"""The column in the right tables containing the `struct<lat: float, lon: float>` coordinates.""" # noqa: E501
143-
left_lat: str | Deferred | Callable[[ir.Table], ir.FloatingColumn] | None = None
144-
"""The column in the left tables containing the latitude coordinates."""
145-
left_lon: str | Deferred | Callable[[ir.Table], ir.FloatingColumn] | None = None
146-
"""The column in the left tables containing the longitude coordinates."""
147-
right_lat: str | Deferred | Callable[[ir.Table], ir.FloatingColumn] | None = None
148-
"""The column in the right tables containing the latitude coordinates."""
149-
right_lon: str | Deferred | Callable[[ir.Table], ir.FloatingColumn] | None = None
150-
"""The column in the right tables containing the longitude coordinates."""
151187
max_pairs: int | None = None
152188
"""The maximum number of pairs that any single block of coordinates can contain.
153189
154190
eg if you have 1000 records all with the same coordinates, this would
155191
naively result in ~(1000 * 1000) / 2 = 500_000 pairs.
156192
If we set max_pairs to less than this, this group of records will be skipped.
157193
"""
158-
159-
def __post_init__(self):
160-
ok_subsets = [
161-
{"coord"},
162-
{"left_coord", "right_coord"},
163-
{"left_coord", "right_lat", "right_lon"},
164-
{"left_lat", "left_lon", "right_coord"},
165-
{"left_lat", "left_lon", "right_lat", "right_lon"},
166-
{"lat", "lon"},
167-
{"lat", "right_lat", "right_lon"},
168-
{"left_lat", "left_lon", "lon"},
169-
]
170-
options = [
171-
"coord",
172-
"left_coord",
173-
"right_coord",
174-
"lat",
175-
"lon",
176-
"left_lat",
177-
"left_lon",
178-
"right_lat",
179-
"right_lon",
180-
]
181-
present = {k for k in options if getattr(self, k) is not None}
182-
if present not in ok_subsets:
183-
ok_subsets_str = "\n".join("- " + str(s) for s in ok_subsets)
184-
raise ValueError(
185-
"You must specify exactly one of the following subsets of options:\n"
186-
+ ok_subsets_str
187-
+ f"\nYou provided:\n{present}"
188-
)
189-
190-
def _left_coord(
191-
self, left: ir.Table
192-
) -> tuple[ir.FloatingColumn, ir.FloatingColumn]:
193-
if self.coord is not None:
194-
left_coord = _util.get_column(left, self.coord, on_many="struct")
195-
return (left_coord.lat, left_coord.lon)
196-
if self.left_coord is not None:
197-
left_coord = _util.get_column(left, self.left_coord, on_many="struct")
198-
left_lat = left_coord.lat
199-
left_lon = left_coord.lon
200-
if self.lat is not None:
201-
left_lat = _util.get_column(left, self.lat)
202-
if self.lon is not None:
203-
left_lon = _util.get_column(left, self.lon)
204-
if self.left_lat is not None:
205-
left_lat = _util.get_column(left, self.left_lat)
206-
if self.left_lon is not None:
207-
left_lon = _util.get_column(left, self.left_lon)
208-
return left_lat, left_lon
209-
210-
def _right_coord(
211-
self, right: ir.Table
212-
) -> tuple[ir.FloatingColumn, ir.FloatingColumn]:
213-
if self.coord is not None:
214-
right_coord = _util.get_column(right, self.coord, on_many="struct")
215-
return (right_coord.lat, right_coord.lon)
216-
if self.right_coord is not None:
217-
right_coord = _util.get_column(right, self.right_coord, on_many="struct")
218-
right_lat = right_coord.lat
219-
right_lon = right_coord.lon
220-
if self.lat is not None:
221-
right_lat = _util.get_column(right, self.lat)
222-
if self.lon is not None:
223-
right_lon = _util.get_column(right, self.lon)
224-
if self.right_lat is not None:
225-
right_lat = _util.get_column(right, self.right_lat)
226-
if self.right_lon is not None:
227-
right_lon = _util.get_column(right, self.right_lon)
228-
return right_lat, right_lon
229-
230-
def _join_keys(
231-
self, left: ir.Table | ibis.Deferred, right: ir.Table | ibis.Deferred
232-
) -> tuple[tuple[ir.Column, ir.Column], tuple[ir.Column, ir.Column]]:
233-
left_lat, left_lon = self._left_coord(left)
234-
right_lat, right_lon = self._right_coord(right)
194+
left_resolver: IntoCoordinates = dataclasses.field(default_factory=default_resolver)
195+
"""A specification of how to get the lat and lon values from the left table.
196+
197+
Can be:
198+
- A `str` or `ibis.Deferred`, which are assumed to point to a column
199+
containing a struct with `lat` and `lon` fields.
200+
- A `Mapping` of `{"lat": ..., "lon": ...}` where the values are
201+
column names or `ibis.Deferred` expressions.
202+
- A callable that takes a table and returns one of the above.
203+
"""
204+
right_resolver: IntoCoordinates = dataclasses.field(
205+
default_factory=default_resolver
206+
)
207+
"""See `left_resolver`, but for the right table."""
208+
209+
def hash_coord(
210+
self, coord: CoordinatesDict, /
211+
) -> tuple[ir.IntegerValue, ir.IntegerValue]:
212+
lat, lon = coord["lat"], coord["lon"]
235213
# We have to use a grid size of ~3x the precision to avoid
236214
# two points falling right on either side of a grid cell boundary
237215
grid_size = self.distance_km * 3
238-
left_lat_key, left_lon_key = _bin_lat_lon(left_lat, left_lon, grid_size)
239-
right_lat_key, right_lon_key = _bin_lat_lon(right_lat, right_lon, grid_size)
240-
return (
241-
left_lat_key.name("lat_binned"),
242-
right_lat_key.name("lat_binned"),
243-
), (
244-
left_lon_key.name("lon_binned"),
245-
right_lon_key.name("lon_binned"),
246-
)
216+
lat_key, lon_key = _bin_lat_lon(lat, lon, grid_size)
217+
return lat_key, lon_key
218+
219+
def hash_left(self, t: ir.Table, /) -> tuple[ir.IntegerValue, ir.IntegerValue]:
220+
coords = get_coordinate_pair(t, self.left_resolver)
221+
return self.hash_coord(coords)
222+
223+
def hash_right(self, t: ir.Table, /) -> tuple[ir.IntegerValue, ir.IntegerValue]:
224+
coords = get_coordinate_pair(t, self.right_resolver)
225+
return self.hash_coord(coords)
247226

248227
@property
249228
def _key_linker(self) -> KeyLinker:
250229
import mismo
251230

252-
lat_key, lon_key = self._join_keys(ibis._, ibis._)
253-
return mismo.KeyLinker([lat_key, lon_key], max_pairs=self.max_pairs)
231+
def lat_key_left(t: ir.Table) -> ir.IntegerValue:
232+
return self.hash_left(t)[0]
233+
234+
def lon_key_left(t: ir.Table) -> ir.IntegerValue:
235+
return self.hash_left(t)[1]
236+
237+
def lat_key_right(t: ir.Table) -> ir.IntegerValue:
238+
return self.hash_right(t)[0]
239+
240+
def lon_key_right(t: ir.Table) -> ir.IntegerValue:
241+
return self.hash_right(t)[1]
242+
243+
return mismo.KeyLinker(
244+
[(lat_key_left, lat_key_right), (lon_key_left, lon_key_right)],
245+
max_pairs=self.max_pairs,
246+
)
254247

255248
def __join_condition__(self, left: ir.Table, right: ir.Table) -> ir.BooleanValue:
256249
return self._key_linker.__join_condition__(left, right)

0 commit comments

Comments
 (0)