Skip to content

Commit 19f2ba2

Browse files
committed
work on Polars join
1 parent 6ee5e33 commit 19f2ba2

File tree

4 files changed

+26
-8
lines changed

4 files changed

+26
-8
lines changed

data_algebra/polars_model.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -884,12 +884,18 @@ def _natural_join_step(self, op: data_algebra.data_ops_types.OperatorPlatform, *
884884
how = op.jointype.lower()
885885
if how == "full":
886886
how = "outer"
887-
coalesce_columns = (
888-
set(op.sources[0].columns_produced()).intersection(op.sources[1].columns_produced())
889-
- set(op.on_a))
890887
if how != "right":
888+
coalesce_columns = (
889+
set(op.sources[0].columns_produced()).intersection(op.sources[1].columns_produced())
890+
- set(op.on_a))
891+
orphan_keys = [c for c in op.on_b if c not in set(op.on_a)]
892+
input_right = inputs[1]
893+
if len(orphan_keys) > 0:
894+
input_right = input_right.with_columns([
895+
pl.col(c).alias(f"{c}_da_join_tmp_key") for c in orphan_keys
896+
])
891897
res = inputs[0].join(
892-
inputs[1],
898+
input_right,
893899
left_on=op.on_a,
894900
right_on=op.on_b,
895901
how=how,
@@ -903,10 +909,21 @@ def _natural_join_step(self, op: data_algebra.data_ops_types.OperatorPlatform, *
903909
.alias(c)
904910
for c in coalesce_columns
905911
])
912+
if len(orphan_keys) > 0:
913+
res = res.rename({f"{c}_da_join_tmp_key": c for c in orphan_keys})
906914
else:
907915
# simulate right join with left join
916+
coalesce_columns = (
917+
set(op.sources[0].columns_produced()).intersection(op.sources[1].columns_produced())
918+
- set(op.on_b))
919+
orphan_keys = [c for c in op.on_a if c not in set(op.on_b)]
920+
input_right = inputs[0]
921+
if len(orphan_keys) > 0:
922+
input_right = input_right.with_columns([
923+
pl.col(c).alias(f"{c}_da_join_tmp_key") for c in orphan_keys
924+
])
908925
res = inputs[1].join(
909-
inputs[0],
926+
input_right,
910927
left_on=op.on_b,
911928
right_on=op.on_a,
912929
how="left",
@@ -920,6 +937,8 @@ def _natural_join_step(self, op: data_algebra.data_ops_types.OperatorPlatform, *
920937
.alias(c)
921938
for c in coalesce_columns
922939
])
940+
if len(orphan_keys) > 0:
941+
res = res.rename({f"{c}_da_join_tmp_key": c for c in orphan_keys})
923942
res = res.select(op.columns_produced())
924943
return res
925944

tests/test_expression_expectations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_expression_expectations_1():
4848
valid_for_empty=False,
4949
empty_produces_empty=False,
5050
models_to_skip=to_skip,
51-
try_on_Polars=False, # TODO: turn this on
51+
try_on_Polars=False, # TODO: complete coverage, and turn this on
5252
)
5353
for op, op_class, exp, ops, expect in u_results:
5454
# re-run, but don't check value

tests/test_idioms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_idiom_cross_join():
247247
)
248248
data_algebra.test_util.check_transform(
249249
ops=ops, data={table_name_d: d, table_name_e: e}, expect=expect,
250-
try_on_Polars=False, # TODO: turn this on
250+
try_on_Polars=False, # TODO: get empty case to match and turn this on
251251
)
252252

253253

tests/test_join_conditions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,4 @@ def test_join_conditions_on_join():
6060
assert data_algebra.test_util.equivalent_frames(res, expect)
6161
data_algebra.test_util.check_transform(
6262
ops=ops, data={"d1": d1, "d2": d2}, expect=expect,
63-
try_on_Polars=False, # TODO: turn this on, to get it to work we have to copy columns with name changes in join conditions
6463
)

0 commit comments

Comments
 (0)