Skip to content

Commit 3e26bba

Browse files
committed
refactor: update various modules to use new wrapper and improve code
Update _util.py with code simplifications Update arrays, linker, sets, and tf modules to use new wrapper Add assert_equal utility function to tests
1 parent fad8153 commit 3e26bba

File tree

7 files changed

+49
-55
lines changed

7 files changed

+49
-55
lines changed

mismo/_counts_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ibis import _
88
from ibis.expr import types as ir
99

10-
from mismo.types._table_wrapper import TableWrapper
10+
from mismo.types._wrapper import TableWrapper
1111

1212
if TYPE_CHECKING:
1313
import altair as alt

mismo/_util.py

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,28 @@ def cases(
8989
return builder.else_(else_).end()
9090

9191

92-
@overload
93-
def bind(t: ibis.Table, ref: Any, /) -> tuple[ibis.Value, ...]: ...
94-
@overload
95-
def bind(t: ibis.Deferred, ref: Any, /) -> tuple[ibis.Deferred]: ...
92+
IntoValue = str | int | ibis.Deferred | ibis.Value | Callable[[ibis.Table], ibis.Value]
9693

9794

98-
def bind(t: ibis.Deferred | ibis.Table, ref: Any) -> tuple[ibis.Value, ...]:
95+
@overload
96+
def bind(
97+
t: ibis.Table,
98+
ref: IntoValue | Iterable[IntoValue] | Mapping[str, IntoValue],
99+
/,
100+
) -> tuple[ibis.Value, ...]: ...
101+
@overload
102+
def bind(
103+
t: ibis.Deferred,
104+
ref: IntoValue | Iterable[IntoValue] | Mapping[str, IntoValue],
105+
/,
106+
) -> tuple[ibis.Deferred]: ...
107+
108+
109+
def bind(
110+
t: ibis.Deferred | ibis.Table,
111+
ref: IntoValue | Iterable[IntoValue] | Mapping[str, IntoValue],
112+
/,
113+
) -> tuple[ibis.Value, ...]:
99114
"""Reference into a table to get Columns and Scalars.
100115
101116
ibis._.bind(ref) does not work because it returns another Deferred.
@@ -110,49 +125,21 @@ def bind(t: ibis.Deferred | ibis.Table, ref: Any) -> tuple[ibis.Value, ...]:
110125

111126

112127
@overload
113-
def bind_one(t: ibis.Table, ref: Any, /) -> ibis.Value: ...
128+
def bind_one(t: ibis.Table, ref: IntoValue, /) -> ibis.Value: ...
114129
@overload
115-
def bind_one(t: ibis.Deferred, ref: Any, /) -> ibis.Deferred: ...
130+
def bind_one(t: ibis.Deferred, ref: IntoValue, /) -> ibis.Deferred: ...
116131

117132

118-
def bind_one(t: ibis.Deferred | ibis.Table, ref: Any) -> ibis.Value | ibis.Deferred:
133+
def bind_one(
134+
t: ibis.Deferred | ibis.Table, ref: IntoValue, /
135+
) -> ibis.Value | ibis.Deferred:
119136
"""Like bind(), but ensure that exactly one value is returned."""
120137
vals = bind(t, ref)
121138
if len(vals) != 1:
122139
raise ValueError(f"Expected 1 value, got {len(vals)} from {ref}")
123140
return vals[0]
124141

125142

126-
def get_column(
127-
t: ir.Table, ref: Any, *, on_many: Literal["error", "struct"] = "error"
128-
) -> ir.Column:
129-
"""Get a column from a table using some sort of reference to the column.
130-
131-
ref can be a string, a Deferred, a callable, an ibis selector, etc.
132-
133-
Parameters
134-
----------
135-
t :
136-
The table
137-
ref :
138-
The reference to the column
139-
on_many :
140-
What to do if ref returns multiple columns. If "error", raise an error.
141-
If "struct", return a StructColumn containing all the columns.
142-
"""
143-
cols = bind(t, ref)
144-
if isinstance(t, ibis.Deferred):
145-
# This is by definition a single column
146-
return cols[0]
147-
if len(cols) != 1:
148-
if on_many == "error":
149-
raise ValueError(f"Expected 1 column, got {len(cols)}")
150-
if on_many == "struct":
151-
return ibis.struct({c.get_name(): c for c in cols})
152-
raise ValueError(f"on_many must be 'error' or 'struct'. Got {on_many}")
153-
return cols[0]
154-
155-
156143
def ensure_ibis(
157144
val: Any, type: str | dt.DataType | None = None
158145
) -> ir.Value | ibis.Deferred:

mismo/arrays/_array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def array_filter_isin_other(
110110
The table with a new column named following `result_format` with the
111111
filtered array.
112112
""" # noqa E501
113-
array_col = _util.get_column(t, array)
114-
t = t.mutate(__array=array_col, __id=ibis.row_number())
113+
array_val = _util.bind_one(t, array)
114+
t = t.mutate(__array=array_val, __id=ibis.row_number())
115115
temp = t.select("__id", __unnested=_.__array.unnest())
116116
# When we re-.collect() items below, the order matters,
117117
# but the .filter() can mess up the order, so we need to
@@ -131,7 +131,7 @@ def array_filter_isin_other(
131131
[], _.__filtered
132132
)
133133
).drop("__array")
134-
result_name = result_format.format(name=array_col.get_name())
134+
result_name = result_format.format(name=array_val.get_name())
135135
return re_joined.rename({result_name: "__filtered"})
136136

137137

mismo/linker/_lsh.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from ibis import _
77
from ibis.expr import types as ir
88

9-
from mismo._util import get_column
9+
from mismo._util import bind_one
1010
from mismo.arrays import array_choice
1111
from mismo.linker import _common
1212

@@ -79,8 +79,8 @@ def __init__(
7979

8080
def __call__(self, left: ir.Table, right: ir.Table) -> ir.Table:
8181
"""Block two tables using Minhash LSH."""
82-
left_terms = get_column(left, self.terms_column)
83-
right_terms = get_column(right, self.terms_column)
82+
left_terms = bind_one(left, self.terms_column)
83+
right_terms = bind_one(right, self.terms_column)
8484
keys_name = self.keys_column.format(terms_column=left_terms.get_name())
8585
left = left.mutate(
8686
minhash_lsh_keys(

mismo/sets/_tfidf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def add_array_value_counts(
166166
│ NULL │ NULL │
167167
└──────────────────────────────┴──────────────────────────────────┘
168168
""" # noqa: E501
169-
t = t.mutate(__terms=_util.get_column(t, column))
169+
t = t.mutate(__terms=_util.bind_one(t, column))
170170
normalized = (
171171
t.select("__terms")
172172
.distinct()

mismo/tests/util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import ibis
66
from ibis.expr import types as ir
7+
from ibis.tests.util import assert_equal as ibis_assert_equal
78
import pandas as pd
89
import pytest
910

@@ -47,6 +48,10 @@ def assert_tables_equal(
4748
assert left_records == right_records
4849

4950

51+
def assert_equal(left: ibis.Value, right: ibis.Value) -> None:
52+
ibis_assert_equal(left, right)
53+
54+
5055
def make_record_approx(record: dict) -> dict:
5156
return {k: make_float_comparable(v) for k, v in record.items()}
5257

mismo/tf/_tf.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ibis.expr import types as ir
1010

1111
from mismo import _util
12-
from mismo.types._table_wrapper import TableWrapper
12+
from mismo.types._wrapper import TableWrapper
1313

1414
T = TypeVar("T", bound=ir.Column)
1515

@@ -52,24 +52,24 @@ def add_frequencies(
5252
self,
5353
table: ibis.Table,
5454
*,
55-
column: str | ibis.Deferred | ibis.Column | None = None,
55+
column: str | ibis.Deferred | ibis.Value | None = None,
5656
name_as: str | None = None,
57-
default: Literal["1/N"] | int | float = "1/N",
57+
default: Literal["1/N"] | int | float | ibis.Scalar = "1/N",
5858
) -> ibis.Table:
5959
"""Add frequency columns to the given table."""
6060
if name_as is None:
6161
name_as = f"frequency_{self.name}"
6262
if column is None:
6363
column = self.name
6464

65+
default_resolved: ibis.Scalar
6566
if default == "1/N" or default == "1/n":
6667
n_total = table.count().as_scalar()
67-
default = (1 / n_total).cast("float64")
68+
default_resolved = (1 / n_total).cast("float64")
6869
elif isinstance(default, ibis.Scalar):
69-
default = default.cast("float64")
70+
default_resolved = default.cast("float64") # ty:ignore[invalid-assignment]
7071
else:
71-
default = ibis.literal(default, "float64")
72-
72+
default_resolved = ibis.literal(default, "float64")
7373
table_column = _util.bind_one(table, column)
7474

7575
unique_name = _util.unique_name("join_key")
@@ -82,7 +82,7 @@ def add_frequencies(
8282
.as_table()
8383
.distinct()
8484
.anti_join(stats_raw, unique_name)
85-
.mutate(default.name(name_as))
85+
.mutate(default_resolved.name(name_as))
8686
)
8787
stats = ibis.union(stats_raw, filler)
8888
assert stats.columns == (unique_name, name_as)
@@ -93,6 +93,8 @@ def add_frequencies(
9393

9494

9595
class TermFrequencyModel:
96+
columns: dict[str, ir.Column]
97+
9698
def __init__(
9799
self,
98100
columns: ibis.Table
@@ -102,7 +104,7 @@ def __init__(
102104
/,
103105
):
104106
if isinstance(columns, Mapping):
105-
self.columns = columns
107+
self.columns = dict(columns)
106108
elif isinstance(columns, ibis.Table):
107109
self.columns = {col: columns[col] for col in columns.columns}
108110
elif isinstance(columns, ibis.Column):

0 commit comments

Comments
 (0)