Skip to content

Commit 186e8d3

Browse files
committed
tests for merge in place
1 parent dcaa489 commit 186e8d3

File tree

4 files changed

+50
-11
lines changed

4 files changed

+50
-11
lines changed

modin/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
AsyncReadMode,
2020
AutoSwitchBackend,
2121
Backend,
22+
BackendMergeCastInPlace,
2223
BenchmarkMode,
2324
CIAWSAccessKeyID,
2425
CIAWSSecretAccessKey,
@@ -78,6 +79,7 @@
7879
"GpuCount",
7980
"Memory",
8081
"Backend",
82+
"BackendMergeCastInPlace",
8183
"Execution",
8284
"AutoSwitchBackend",
8385
"ShowBackendSwitchProgress",

modin/config/envvars.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,31 @@ def disable(cls) -> None:
13851385
cls.put(False)
13861386

13871387

1388+
class BackendMergeCastInPlace(EnvironmentVariable, type=bool):
1389+
"""
1390+
Whether to cast a DataFrame in-place when performing a merge when using hybrid mode.
1391+
1392+
This flag modifies the behavior of a cast performed on operations involving more
1393+
than one type of query compiler. If enabled the actual cast will be performed in-place
1394+
and the input DataFrame will have a new backend. If disabled the original DataFrame
1395+
will remain on the same underlying engine.
1396+
1397+
"""
1398+
1399+
varname = "MODIN_BACKEND_MERGE_CAST_IN_PLACE"
1400+
default = True
1401+
1402+
@classmethod
1403+
def enable(cls) -> None:
1404+
"""Enable casting in place when performing a merge operation between two different compilers."""
1405+
cls.put(True)
1406+
1407+
@classmethod
1408+
def disable(cls) -> None:
1409+
"""Disable casting in place when performing a merge operation between two different compilers."""
1410+
cls.put(False)
1411+
1412+
13881413
class DynamicPartitioning(EnvironmentVariable, type=bool):
13891414
"""
13901415
Set to true to use Modin's dynamic-partitioning implementation where possible.

modin/core/storage_formats/pandas/query_compiler_caster.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from pandas.core.indexes.frozen import FrozenList
3232
from typing_extensions import Self
3333

34-
from modin.config import AutoSwitchBackend, Backend
34+
from modin.config import AutoSwitchBackend, Backend, BackendMergeCastInPlace
3535
from modin.config import context as config_context
3636
from modin.core.storage_formats.base.query_compiler import (
3737
BaseQueryCompiler,
@@ -1120,12 +1120,18 @@ def cast_to_qc(arg):
11201120
and arg.get_backend() != result_backend
11211121
):
11221122
return arg
1123-
arg.set_backend(
1124-
result_backend,
1125-
switch_operation=f"{_normalize_class_name(class_of_wrapped_fn)}.{name}",
1126-
inplace=True,
1127-
)
1128-
cast = arg
1123+
if BackendMergeCastInPlace.get():
1124+
arg.set_backend(
1125+
result_backend,
1126+
switch_operation=f"{_normalize_class_name(class_of_wrapped_fn)}.{name}",
1127+
inplace=True,
1128+
)
1129+
cast = arg
1130+
else:
1131+
cast = arg.set_backend(
1132+
result_backend,
1133+
switch_operation=f"{_normalize_class_name(class_of_wrapped_fn)}.{name}",
1134+
)
11291135
inplace_update_trackers.append(
11301136
InplaceUpdateTracker(
11311137
input_castable=arg,
@@ -1157,7 +1163,9 @@ def cast_to_qc(arg):
11571163
original_qc,
11581164
new_castable,
11591165
) in inplace_update_trackers:
1160-
new_castable._copy_into(original_castable)
1166+
new_qc = new_castable._get_query_compiler()
1167+
if BackendMergeCastInPlace.get() or original_qc is not new_qc:
1168+
new_castable._copy_into(original_castable)
11611169

11621170
return _maybe_switch_backend_post_op(
11631171
result,

modin/tests/pandas/native_df_interoperability/test_compiler_caster.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -664,9 +664,13 @@ def test_merge_in_place(default_df, lazy_df, cloud_df):
664664
# Both arguments now have the same qc type
665665
assert type(lazy_df) is type(default_df)
666666

667-
df = cloud_df.merge(lazy_df)
668-
assert type(df) is type(cloud_df)
669-
assert type(lazy_df) is type(cloud_df)
667+
with config_context(BackendMergeCastInPlace=False):
668+
lazy_df = lazy_df.move_to("Lazy")
669+
cloud_df = cloud_df.move_to("Cloud")
670+
df = cloud_df.merge(lazy_df)
671+
assert type(df) is type(cloud_df)
672+
assert lazy_df.get_backend() == "Lazy"
673+
assert cloud_df.get_backend() == "Cloud"
670674

671675

672676
def test_information_asymmetry(default_df, cloud_df, eager_df, lazy_df):

0 commit comments

Comments
 (0)