Skip to content

Commit fd95406

Browse files
committed
work on Polars ifelse and where
1 parent 6a7d275 commit fd95406

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

data_algebra/polars_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,10 @@ def _populate_expr_impl_map() -> Dict[int, Dict[str, Callable]]:
309309
"parse_datetime": lambda x, format : x.cast(str).str.strptime(pl.Datetime, fmt=format, strict=False).cast(pl.Datetime),
310310
}
311311
impl_map_3 = {
312-
"if_else": lambda a, b, c: pl.when(a).then(b).otherwise(c),
312+
"if_else": lambda a, b, c: pl.when(a.is_null()).then(pl.lit(None)).otherwise(pl.when(a).then(b).otherwise(c)),
313313
"mapv": _mapv,
314314
"trimstr": lambda a, b, c: a.trimstr(b, c),
315-
"where": lambda a, b, c: pl.when(a).then(b).otherwise(c),
315+
"where": lambda a, b, c: pl.when(a.is_null()).then(c).otherwise(pl.when(a).then(b).otherwise(c)),
316316
}
317317
impl_map = {
318318
0: impl_map_0,
@@ -436,7 +436,7 @@ def drop_indices(self, df) -> None:
436436

437437
def bad_column_positions(self, x):
438438
"""
439-
Return vector indicating which entries are bad (null or nan) (vectorized).
439+
Return vector indicating which entries are null (vectorized).
440440
"""
441441
return x.is_null()
442442

tests/test_if_else.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,5 +139,4 @@ def test_if_else_where():
139139
assert data_algebra.test_util.equivalent_frames(res_sqlite, expect)
140140
sqlite_handle.close()
141141
data_algebra.test_util.check_transform(ops=ops, data=d, expect=expect,
142-
try_on_Polars=False, # TODO: turn this on
143142
)

tests/test_polars.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -640,10 +640,10 @@ def test_is_inf_polars():
640640
ops = (
641641
data_algebra.descr(d=d)
642642
.extend({
643-
'is_inf': 'a.is_inf().if_else(1, 0)',
644-
'is_nan': 'a.is_nan().if_else(1, 0)',
645-
'is_bad': 'a.is_bad().if_else(1, 0)',
646-
'is_null': 'a.is_null().if_else(1, 0)',
643+
'is_inf': 'a.is_inf().where(1, 0)',
644+
'is_nan': 'a.is_nan().where(1, 0)',
645+
'is_bad': 'a.is_bad().where(1, 0)',
646+
'is_null': 'a.is_null().where(1, 0)',
647647
})
648648
)
649649
res_polars = ops.transform(d)

0 commit comments

Comments
 (0)