Skip to content

Commit 346579d

Browse files
committed
chore: improve a few things in _util.py
1 parent 79aab8b commit 346579d

File tree

4 files changed

+20
-16
lines changed

4 files changed

+20
-16
lines changed

mismo/_util.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,19 @@ def to_ibis_values(x):
7474

7575

7676
def cases(
77-
*case_result_pairs: tuple[ir.BooleanValue, ir.Value],
77+
first_branch: tuple[ir.BooleanValue, ir.Value],
78+
*other_branches: tuple[ir.BooleanValue, ir.Value],
7879
else_: ir.Value | None = None,
7980
) -> ir.Value:
80-
"""A more concise way to write a case statement."""
81+
"""A compat wrapper to support ibis<10.0.0, which did not have ibis.cases()."""
8182
try:
8283
# ibis.cases() was added in ibis 10.0.0
8384
cases = getattr(ibis, "cases")
84-
return cases(*case_result_pairs, else_=else_)
85+
return cases(first_branch, *other_branches, else_=else_)
8586
except AttributeError:
8687
builder = ibis.case()
87-
for case, result in case_result_pairs:
88+
builder = builder.when(*first_branch)
89+
for case, result in other_branches:
8890
builder = builder.when(case, result)
8991
return builder.else_(else_).end()
9092

@@ -140,17 +142,17 @@ def bind_one(
140142
return vals[0]
141143

142144

143-
def ensure_ibis(
145+
def ensure_val(
144146
val: Any, type: str | dt.DataType | None = None
145147
) -> ir.Value | ibis.Deferred:
146-
"""Ensure that `val` is an ibis expression."""
147-
if isinstance(val, ir.Expr) or isinstance(val, ibis.Deferred):
148+
"""Ensure that `val` is an ibis Value or Deferred."""
149+
if isinstance(val, ibis.Value) or isinstance(val, ibis.Deferred):
148150
return val
149151
return ibis.literal(val, type=type)
150152

151153

152154
def get_name(x) -> str:
153-
"""Find a suitable string representation of `x` to use as a blocker name."""
155+
"""Find a suitable string representation of `x`."""
154156
if isinstance(x, Deferred):
155157
return x.__repr__()
156158
try:
@@ -303,8 +305,8 @@ def optional_import(pip_name: str):
303305
"""
304306
Raises a more helpful ImportError when an optional dep is missing.
305307
306-
with optional_import():
307-
import some_optional_dep
308+
with optional_import("scikit-learn"):
309+
import sklearn
308310
"""
309311
try:
310312
yield
@@ -486,7 +488,9 @@ def _warn():
486488
f()
487489

488490

489-
def check_schemas_equal(a: ibis.Schema | ibis.Table, b: ibis.Schema | ibis.Table):
491+
def check_schemas_equal(
492+
a: ibis.Schema | ibis.Table, b: ibis.Schema | ibis.Table, /
493+
) -> None:
490494
if isinstance(a, ibis.Table):
491495
a = a.schema()
492496
if isinstance(b, ibis.Table):

mismo/text/_features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def ngrams(string: ir.StringValue, n: int) -> ir.ArrayValue:
6363
"""
6464
if n < 1:
6565
raise ValueError("n must be greater than 0")
66-
string = _util.ensure_ibis(string, "string")
66+
string = _util.ensure_val(string, "string")
6767
pattern = "." * n
6868
# if you just do _re_extract_all("abcdef", "..."), you get ["abc", "def"].
6969
# So to get the "bcd" and the "cde", we need to offset the string

mismo/text/_similarity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def double_metaphone(s: ir.StringValue) -> ir.ArrayValue[ir.StringValue]:
2424
>>> double_metaphone(None).execute() is None
2525
True
2626
"""
27-
s = _util.ensure_ibis(s, "string")
27+
s = _util.ensure_val(s, "string")
2828
return _dm_udf(s)
2929

3030

@@ -103,8 +103,8 @@ def damerau_levenshtein_ratio(
103103

104104

105105
def _dist_ratio(s1, s2, dist):
106-
s1 = _util.ensure_ibis(s1, "string")
107-
s2 = _util.ensure_ibis(s2, "string")
106+
s1 = _util.ensure_val(s1, "string")
107+
s2 = _util.ensure_val(s2, "string")
108108
lenmax = ibis.greatest(s1.length(), s2.length())
109109
return (lenmax - dist(s1, s2)) / lenmax
110110

mismo/text/_strings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def norm_whitespace(texts: ir.StringValue) -> ir.StringValue:
1010
"""
1111
Strip leading/trailing whitespace, replace multiple whitespace with a single space.
1212
"""
13-
texts = _util.ensure_ibis(texts, "string")
13+
texts = _util.ensure_val(texts, "string")
1414
return texts.strip().re_replace(r"\s+", " ")
1515

1616

0 commit comments

Comments
 (0)