Skip to content

Commit 6a7d275

Browse files
committed
fix and test more of Polars adapter
1 parent 3ba63c4 commit 6a7d275

File tree

7 files changed

+45
-16
lines changed

7 files changed

+45
-16
lines changed

data_algebra/data_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,18 @@ def table_is_keyed_by_columns(self, table, *, column_names: Iterable[str]) -> bo
7272
:return: True if rows are uniquely keyed by values in named columns
7373
"""
7474

75+
@abc.abstractmethod
76+
def concat_rows(self, frame_list: List):
77+
"""
78+
Concatenate rows from frame_list
79+
"""
80+
81+
@abc.abstractmethod
82+
def concat_columns(self, frame_list):
83+
"""
84+
Concatenate columns from frame_list
85+
"""
86+
7587
# evaluate
7688

7789
@abc.abstractmethod

data_algebra/pandas_base.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,14 +456,24 @@ def bad_column_positions(self, x):
456456
self.pd.isnull(x), numpy.logical_or(numpy.isnan(x), numpy.isinf(x))
457457
)
458458
return self.pd.isnull(x)
459+
460+
def concat_rows(self, frame_list: List):
461+
"""
462+
Concatenate rows from frame_list
463+
"""
464+
frame_list = list(frame_list)
465+
assert len(frame_list) > 0
466+
if len(frame_list) == 1:
467+
return self.clean_copy(frame_list[0])
468+
res = self.pd.concat(frame_list, axis=0)
469+
return res
459470

460-
def concat_columns(self, frame_list):
471+
def concat_columns(self, frame_list: List):
461472
"""
462-
Concatinate columns from frame_list
473+
Concatenate columns from frame_list
463474
"""
464475
frame_list = list(frame_list)
465-
if len(frame_list) <= 0:
466-
return None
476+
assert len(frame_list) > 0
467477
if len(frame_list) == 1:
468478
return self.clean_copy(frame_list[0])
469479
res = self.pd.concat(frame_list, axis=1)

data_algebra/polars_model.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,6 @@ def _populate_expr_impl_map() -> Dict[int, Dict[str, Callable]]:
289289
"%/%": lambda a, b: a / b,
290290
"around": lambda a, b: a.round(b),
291291
"coalesce": lambda a, b: pl.when(a.is_null()).then(b).otherwise(a),
292-
"concat": lambda a, b: a.concat(b),
293292
"date_diff": lambda a, b: a.date_diff(b),
294293
"is_in": lambda a, b: a.is_in(b),
295294
"mod": lambda a, b: a % b,
@@ -367,6 +366,7 @@ def __init__(self, *, use_lazy_eval: bool = True):
367366
}
368367
self._expr_impl_map = _populate_expr_impl_map()
369368
self._impl_map_arbitrary_arity = {
369+
"concat": lambda *args: pl.concat_str(args),
370370
"fmax": lambda *args: pl.max(args),
371371
"fmin": lambda *args: pl.min(args),
372372
"maximum": lambda *args: pl.max(args),
@@ -440,9 +440,19 @@ def bad_column_positions(self, x):
440440
"""
441441
return x.is_null()
442442

443+
def concat_rows(self, frame_list: List):
444+
"""
445+
Concatenate rows from frame_list
446+
"""
447+
frame_list = list(frame_list)
448+
assert len(frame_list) > 0
449+
if len(frame_list) == 1:
450+
return frame_list[0]
451+
pl.concat(frame_list, how="vertical")
452+
443453
def concat_columns(self, frame_list):
444454
"""
445-
Concatinate columns from frame_list
455+
Concatenate columns from frame_list
446456
"""
447457
frame_list = list(frame_list)
448458
if len(frame_list) <= 0:

data_algebra/solutions.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def xicor_score_variables_plan(
224224
assert isinstance(n_rep, int)
225225
record_map = RecordMap(
226226
blocks_out=RecordSpecification(
227-
control_table=data_algebra.data_model.default_data_model().pd.DataFrame(
227+
control_table=data_algebra.data_model.default_data_model().data_frame(
228228
{
229229
"variable_name": x_vars,
230230
"x": x_vars,
@@ -237,7 +237,7 @@ def xicor_score_variables_plan(
237237
),
238238
strict=False,
239239
)
240-
rep_frame = data_algebra.data_model.default_data_model().pd.DataFrame({"rep": range(n_rep)})
240+
rep_frame = data_algebra.data_model.default_data_model().data_frame({"rep": range(n_rep)})
241241
grouped_calc = (
242242
xicor_query(
243243
d
@@ -529,13 +529,12 @@ def replicate_rows_query(
529529
assert power_key_colname not in d.column_names
530530
# get a pandas namespace
531531
local_data_model = data_algebra.data_model.default_data_model()
532-
pd = local_data_model.pd
533532
# build powers of 2 until max_count is met or exceeded
534533
powers = list(range(int(numpy.ceil(numpy.log(max_count) / numpy.log(2))) + 1))
535534
# replicate each power the number of times it specifies
536-
count_frame = pd.concat(
535+
count_frame = local_data_model.concat_rows(
537536
[
538-
pd.DataFrame(
537+
local_data_model.data_frame(
539538
{
540539
power_key_colname: f"p{p}",
541540
seq_column_name: range(int(2**p)),

tests/test_braid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
def test_braid():
1313
pd = data_algebra.data_model.default_data_model().pd
1414
d_state = pd.DataFrame({
15-
't': [1, 3, 5],
15+
't': [1., 3., 5.],
1616
'state': ['a', 'b', 'c'],
1717
})
1818
d_event = pd.DataFrame({
19-
't': [1, 4],
19+
't': [1., 4.],
2020
'value': [10, 20],
2121
})
2222
ops = data_algebra.solutions.braid_data(
@@ -31,7 +31,7 @@ def test_braid():
3131
res = ops.eval({'d_state': d_state, 'd_event': d_event})
3232
# print(data_algebra.util.pandas_to_example_str(res))
3333
expect = pd.DataFrame({
34-
't': [1, 1, 3, 4, 5],
34+
't': [1., 1., 3., 4., 5.],
3535
'state': ['a', 'a', 'b', 'b', 'c'],
3636
'value': [None, 10, None, 20, None],
3737
'record_type': ['state_row', 'event_row', 'state_row', 'event_row', 'state_row'],

tests/test_expand_rows.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,4 @@ def test_replicate_rows_query():
3131
ops=ops,
3232
data={'d': d, 'rt': rt},
3333
expect=expect,
34-
try_on_Polars=False, # TODO: turn this on
3534
)

tests/test_idioms.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,6 @@ def test_idiom_concat_op():
389389
)
390390

391391
data_algebra.test_util.check_transform(ops=ops, data=d, expect=expect,
392-
try_on_Polars=False, # TODO: turn this on
393392
)
394393

395394

0 commit comments

Comments
 (0)