Skip to content

Commit 0df26c5

Browse files
author
Xuye (Chris) Qin
authored
Fix df.loc[:] to make sure same index_value key generated (#2643)
1 parent 3c0a4ca commit 0df26c5

File tree

5 files changed

+32
-7
lines changed

5 files changed

+32
-7
lines changed

mars/dataframe/base/rechunk.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,14 @@ def compute_rechunk(a, chunk_size):
209209
calc_sliced_size(s, chunk_slice[0]) for s in old_chunk.shape
210210
)
211211
new_index_value = indexing_index_value(
212-
old_chunk.index_value, chunk_slice[0]
212+
old_chunk.index_value, chunk_slice[0], rechunk=True
213213
)
214214
if is_dataframe:
215215
new_columns_value = indexing_index_value(
216-
old_chunk.columns_value, chunk_slice[1], store_data=True
216+
old_chunk.columns_value,
217+
chunk_slice[1],
218+
store_data=True,
219+
rechunk=True,
217220
)
218221
merge_chunk_op = DataFrameIlocGetItem(
219222
list(chunk_slice),

mars/dataframe/indexing/loc.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ...serialization.serializables import KeyField, ListField
2727
from ...tensor.datasource import asarray
2828
from ...tensor.utils import calc_sliced_size, filter_inputs
29-
from ...utils import lazy_import
29+
from ...utils import lazy_import, is_full_slice
3030
from ..core import IndexValue, DATAFRAME_TYPE
3131
from ..operands import DataFrameOperand, DataFrameOperandMixin
3232
from ..utils import parse_index
@@ -154,7 +154,13 @@ def _calc_slice_param(
154154
axis: int,
155155
) -> Dict:
156156
param = dict()
157-
if input_index_value.has_value():
157+
if is_full_slice(index):
158+
# full slice on this axis
159+
param["shape"] = inp.shape[axis]
160+
param["index_value"] = input_index_value
161+
if axis == 1:
162+
param["dtypes"] = inp.dtypes
163+
elif input_index_value.has_value():
158164
start, end = pd_index.slice_locs(
159165
index.start, index.stop, index.step, kind="loc"
160166
)

mars/dataframe/indexing/tests/test_indexing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def test_iloc_getitem():
122122
df4 = tile(df4)
123123
assert isinstance(df4, DATAFRAME_TYPE)
124124
assert isinstance(df4.op, DataFrameIlocGetItem)
125+
assert df4.index_value.key == df2.index_value.key
125126
assert df4.shape == (3, 1)
126127
assert df4.chunk_shape == (2, 1)
127128
assert df4.chunks[0].shape == (2, 1)
@@ -479,6 +480,7 @@ def test_dataframe_loc():
479480
df2.index_value.to_pandas(), df.index_value.to_pandas()
480481
)
481482
assert df2.name == "y"
483+
assert df2.index_value.key == df.index_value.key
482484

483485
df2 = tile(df2)
484486
assert len(df2.chunks) == 2

mars/dataframe/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ..core import Entity, ExecutableTuple
2828
from ..lib.mmh3 import hash as mmh_hash
2929
from ..tensor.utils import dictify_chunk_size, normalize_chunk_sizes
30-
from ..utils import tokenize, sbytes, lazy_import, ModulePlaceholder
30+
from ..utils import tokenize, sbytes, lazy_import, ModulePlaceholder, is_full_slice
3131

3232
try:
3333
import pyarrow as pa
@@ -804,9 +804,13 @@ def filter_index_value(index_value, min_max, store_data=False):
804804
return parse_index(pd_index[f], store_data=store_data)
805805

806806

807-
def indexing_index_value(index_value, indexes, store_data=False):
807+
def indexing_index_value(index_value, indexes, store_data=False, rechunk=False):
808808
pd_index = index_value.to_pandas()
809-
if not index_value.has_value():
809+
# when rechunk is True, the output index shall be treated
810+
# differently from the input one
811+
if not rechunk and isinstance(indexes, slice) and is_full_slice(indexes):
812+
return index_value
813+
elif not index_value.has_value():
810814
new_index_value = parse_index(pd_index, indexes, store_data=store_data)
811815
new_index_value._index_value._min_val = index_value.min_val
812816
new_index_value._index_value._min_val_close = index_value.min_val_close

mars/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,3 +1519,13 @@ def flatten_dict_to_nested_dict(flatten_dict: Dict, sep=".") -> Dict:
15191519
else:
15201520
sub_nested_dict = sub_nested_dict[sub_key]
15211521
return nested_dict
1522+
1523+
1524+
def is_full_slice(slc: Any) -> bool:
    """Check if the input is a full slice, i.e. one that selects a whole axis.

    A slice counts as full when it has no upper bound, starts at the
    beginning (``start`` is ``None`` or ``0``) and walks forward one element
    at a time (``step`` is ``None`` or ``1``).  Examples: ``[:]``, ``[0:]``,
    ``[::1]`` and ``[0::1]``.

    Parameters
    ----------
    slc : Any
        Object to inspect.  Anything that is not a ``slice`` is never a
        full slice.

    Returns
    -------
    bool
        ``True`` if ``slc`` is a full slice, ``False`` otherwise.

    Notes
    -----
    NOTE(review): ``start == 0`` is a positional notion.  For label-based
    indexing a start of ``0`` only selects everything when ``0`` is the first
    label -- callers doing label-based slicing should confirm this holds.
    """
    if not isinstance(slc, slice):
        return False
    starts_at_beginning = slc.start is None or slc.start == 0
    # an explicit step of 1 walks the axis exactly like the default step
    unit_step = slc.step is None or slc.step == 1
    # bool() keeps the declared return type even if start/step are
    # numpy integers whose comparison yields numpy.bool_
    return bool(starts_at_beginning and slc.stop is None and unit_step)

0 commit comments

Comments
 (0)