Skip to content

Commit bacb2b5

Browse files
author
Xuye (Chris) Qin
authored
Avoid iterative tiling for df.loc[:, fields] (#2685)
1 parent a446eaa commit bacb2b5

File tree

3 files changed

+58
-25
lines changed

3 files changed

+58
-25
lines changed

mars/dataframe/indexing/index_lib.py

Lines changed: 43 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -242,17 +242,12 @@ def preprocess(self, index_info: IndexInfo, context: IndexHandlerContext) -> Non
242242
index_value = [tileable.index_value, tileable.columns_value][input_axis]
243243

244244
# check if chunks have unknown shape
245-
check = False
246-
if index_value.has_value():
247-
# index_value has value,
248-
check = True
249-
elif self._slice_all(index_info.raw_index):
250-
# if slice on all data
251-
check = True
252-
253-
if check:
254-
if any(np.isnan(ns) for ns in tileable.nsplits[input_axis]):
255-
yield []
245+
if (
246+
not self._slice_all(index_info.raw_index)
247+
and index_value.has_value()
248+
and any(np.isnan(ns) for ns in tileable.nsplits[input_axis])
249+
): # pragma: no cover
250+
yield []
256251

257252
def set_chunk_index_info(
258253
cls,
@@ -297,6 +292,27 @@ def set_chunk_index_info(
297292

298293
chunk_index_info.set(ChunkIndexAxisInfo(**kw))
299294

295+
def _process_slice_all_index(
296+
self,
297+
tileable: Tileable,
298+
index_info: IndexInfo,
299+
input_axis: int,
300+
context: IndexHandlerContext,
301+
) -> None:
302+
index_to_info = context.chunk_index_to_info.copy()
303+
for chunk_index, chunk_index_info in index_to_info.items():
304+
i = chunk_index[input_axis]
305+
size = tileable.nsplits[input_axis][i]
306+
self.set_chunk_index_info(
307+
context,
308+
index_info,
309+
chunk_index,
310+
chunk_index_info,
311+
i,
312+
slice(None),
313+
size,
314+
)
315+
300316
def _process_has_value_index(
301317
self,
302318
tileable: Tileable,
@@ -306,17 +322,14 @@ def _process_has_value_index(
306322
context: IndexHandlerContext,
307323
) -> None:
308324
pd_index = index_value.to_pandas()
309-
if self._slice_all(index_info.raw_index):
310-
slc = slice(None)
311-
else:
312-
# turn label-based slice into position-based slice
313-
start, end = pd_index.slice_locs(
314-
index_info.raw_index.start,
315-
index_info.raw_index.stop,
316-
index_info.raw_index.step,
317-
kind="loc",
318-
)
319-
slc = slice(start, end, index_info.raw_index.step)
325+
# turn label-based slice into position-based slice
326+
start, end = pd_index.slice_locs(
327+
index_info.raw_index.start,
328+
index_info.raw_index.stop,
329+
index_info.raw_index.step,
330+
kind="loc",
331+
)
332+
slc = slice(start, end, index_info.raw_index.step)
320333

321334
cum_nsplit = [0] + np.cumsum(tileable.nsplits[index_info.input_axis]).tolist()
322335
# split position-based slice into chunk slices
@@ -379,7 +392,9 @@ def process(self, index_info: IndexInfo, context: IndexHandlerContext) -> None:
379392
else:
380393
index_value = [tileable.index_value, tileable.columns_value][input_axis]
381394

382-
if index_value.has_value() or self._slice_all(index_info.raw_index):
395+
if self._slice_all(index_info.raw_index):
396+
self._process_slice_all_index(tileable, index_info, input_axis, context)
397+
elif index_value.has_value():
383398
self._process_has_value_index(
384399
tileable, index_info, index_value, input_axis, context
385400
)
@@ -829,10 +844,13 @@ def parse(self, raw_index, context: IndexHandlerContext) -> IndexInfo:
829844
def preprocess(self, index_info: IndexInfo, context: IndexHandlerContext) -> None:
830845
tileable = context.tileable
831846
op = context.op
832-
if has_unknown_shape(tileable):
833-
yield []
834847

835848
input_axis = index_info.input_axis
849+
850+
# check unknown shape
851+
if any(np.isnan(s) for s in tileable.nsplits[input_axis]):
852+
yield []
853+
836854
if tileable.ndim == 2:
837855
index_value = [tileable.index_value, tileable.columns_value][input_axis]
838856
else:

mars/dataframe/indexing/tests/test_indexing.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -645,6 +645,12 @@ def test_dataframe_loc():
645645
for loc_chunk, chunk in zip(tiled_loc_df.chunks, tiled_df.chunks):
646646
assert loc_chunk.index_value.key == chunk.index_value.key
647647

648+
# test loc on filtered df
649+
df2 = df[df["x"] < 1]
650+
loc_df = df2.loc[:, ["y", "x"]]
651+
tiled_loc_df = tile(loc_df)
652+
assert len(tiled_loc_df.chunks) == 3
653+
648654

649655
def test_loc_use_iloc():
650656
raw = pd.DataFrame([[1, 3, 3], [4, 2, 6], [7, 8, 9]], columns=["x", "y", "z"])

mars/dataframe/indexing/tests/test_indexing_execution.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -312,6 +312,10 @@ def test_loc_getitem(setup):
312312
result = df.execute().fetch()
313313
expected = raw2.loc[:, "b"]
314314
pd.testing.assert_series_equal(result, expected)
315+
df = df2.loc[:, ["b", "a"]]
316+
result = df.execute().fetch()
317+
expected = raw2.loc[:, ["b", "a"]]
318+
pd.testing.assert_frame_equal(result, expected)
315319

316320
# 'b' is non-unique
317321
df = df3.loc[:, "b"]
@@ -336,6 +340,11 @@ def test_loc_getitem(setup):
336340
result = df.execute().fetch()
337341
expected = raw2.loc[[3, 0, 1], ["c", "a", "d"]]
338342
pd.testing.assert_frame_equal(result, expected)
343+
df = df2[df2["a"] < 10]
344+
df = df.loc[[3, 0, 1], ["c", "a", "d"]]
345+
result = df.execute().fetch()
346+
expected = raw2.loc[[3, 0, 1], ["c", "a", "d"]]
347+
pd.testing.assert_frame_equal(result, expected)
339348

340349
# label-based fancy index, asc sorted
341350
df = df2.loc[[0, 1, 3], ["a", "c", "d"]]

0 commit comments

Comments
 (0)