Skip to content

Commit 9f6b857

Browse files
authored
Fix DataFrame getitem when exists duplicate columns (#2581)
1 parent cc9d34f commit 9f6b857

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

mars/dataframe/indexing/getitem.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -287,18 +287,17 @@ def _set_inputs(self, inputs):
287287
def __call__(self, df):
288288
if self.col_names is not None:
289289
# if col_names is a list, return a DataFrame, else return a Series
290-
if isinstance(self._col_names, list):
291-
dtypes = df.dtypes[self._col_names]
292-
columns = parse_index(pd.Index(self._col_names), store_data=True)
290+
dtype = df.dtypes[self._col_names]
291+
if isinstance(dtype, pd.Series):
292+
columns = parse_index(dtype.index, store_data=True)
293293
return self.new_dataframe(
294294
[df],
295-
shape=(df.shape[0], len(self._col_names)),
296-
dtypes=dtypes,
295+
shape=(df.shape[0], len(dtype)),
296+
dtypes=dtype,
297297
index_value=df.index_value,
298298
columns_value=columns,
299299
)
300300
else:
301-
dtype = df.dtypes[self._col_names]
302301
return self.new_series(
303302
[df],
304303
shape=(df.shape[0],),
@@ -439,8 +438,8 @@ def tile_with_columns(cls, op):
439438
in_df = op.inputs[0]
440439
out_df = op.outputs[0]
441440
col_names = op.col_names
442-
if not isinstance(col_names, list):
443-
column_index = calc_columns_index(col_names, in_df)
441+
if not isinstance(out_df, DATAFRAME_TYPE):
442+
column_index = calc_columns_index(col_names, in_df)[0]
444443
out_chunks = []
445444
dtype = in_df.dtypes[col_names]
446445
for i in range(in_df.chunk_shape[0]):
@@ -471,7 +470,10 @@ def tile_with_columns(cls, op):
471470
# When chunk columns are ['c1', 'c2', 'c3'], ['c4', 'c5'],
472471
# selected columns are ['c2', 'c3', 'c4', 'c2'], `column_splits` will be
473472
# [(['c2', 'c3'], 0), ('c4', 1), ('c2', 0)].
473+
if not isinstance(col_names, _list_like_types):
474+
col_names = [col_names]
474475
selected_index = [calc_columns_index(col, in_df) for col in col_names]
476+
selected_index = list(itertools.chain.from_iterable(selected_index))
475477
condition = np.where(np.diff(selected_index))[0] + 1
476478
column_splits = np.split(col_names, condition)
477479
column_indexes = np.split(selected_index, condition)
@@ -482,19 +484,21 @@ def tile_with_columns(cls, op):
482484
zip(column_splits, column_indexes)
483485
):
484486
dtypes = in_df.dtypes[columns]
485-
column_nsplits.append(len(columns))
487+
column_nsplits.append(len(dtypes))
486488
for j in range(in_df.chunk_shape[0]):
487489
c = in_df.cix[(j, column_idx[0])]
488490
index_op = DataFrameIndex(
489491
col_names=list(columns), output_types=[OutputType.dataframe]
490492
)
491493
out_chunk = index_op.new_chunk(
492494
[c],
493-
shape=(c.shape[0], len(columns)),
495+
shape=(c.shape[0], len(dtypes)),
494496
index=(j, i),
495497
dtypes=dtypes,
496498
index_value=c.index_value,
497-
columns_value=parse_index(pd.Index(columns), store_data=True),
499+
columns_value=parse_index(
500+
pd.Index(dtypes.index), store_data=True
501+
),
498502
)
499503
out_chunks[j].append(out_chunk)
500504
out_chunks = [item for cl in out_chunks for item in cl]
@@ -525,8 +529,6 @@ def execute(cls, ctx, op):
525529
mask = op.mask
526530
if hasattr(mask, "reindex_like"):
527531
mask = mask.reindex_like(df).fillna(False)
528-
if mask.ndim == 2:
529-
mask = mask[df.columns.tolist()]
530532
ctx[op.outputs[0].key] = df[mask]
531533

532534
@classmethod

mars/dataframe/indexing/tests/test_indexing_execution.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,9 @@ def test_dataframe_getitem(setup):
414414
df7 = df[1:7:2]
415415
pd.testing.assert_frame_equal(df7.execute().fetch(), data[1:7:2])
416416

417+
df8 = df[["c1", "c1"]]["c1"]
418+
pd.testing.assert_frame_equal(df8.execute().fetch(), data[["c1", "c1"]]["c1"])
419+
417420
series3 = df["c1"][0]
418421
assert series3.execute().fetch() == data["c1"][0]
419422

mars/dataframe/indexing/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,14 @@ def calc_columns_index(column_name, df):
2424
:return: chunk index on the columns axis
2525
"""
2626
column_nsplits = df.nsplits[1]
27-
column_loc = df.columns_value.to_pandas().get_loc(column_name)
28-
return np.searchsorted(np.cumsum(column_nsplits), column_loc + 1)
27+
# if has duplicate columns, will return multiple values
28+
columns = df.columns_value.to_pandas().to_numpy()
29+
column_locs = (columns == column_name).nonzero()[0]
30+
31+
return [
32+
np.searchsorted(np.cumsum(column_nsplits), column_loc + 1)
33+
for column_loc in column_locs
34+
]
2935

3036

3137
def convert_labels_into_positions(pandas_index, labels):

0 commit comments

Comments
 (0)