Skip to content

Commit 56efd30

Browse files
authored
Fix as_index when calling groupby-agg (#2676)
1 parent 0009146 commit 56efd30

File tree

3 files changed

+37
-15
lines changed

3 files changed

+37
-15
lines changed

mars/dataframe/groupby/aggregation.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,16 @@ def _get_index_levels(self, groupby, mock_index):
281281
pd_index = mock_index
282282
return 1 if not isinstance(pd_index, pd.MultiIndex) else len(pd_index.levels)
283283

284+
def _fix_as_index(self, result_index: pd.Index):
285+
# check whether as_index=False actually takes effect
286+
if isinstance(result_index, pd.MultiIndex):
287+
# if MultiIndex, as_index=False definitely takes no effect
288+
self.groupby_params["as_index"] = True
289+
elif result_index.name is not None:
290+
# if not MultiIndex and agg_df.index has a name
291+
# means as_index=False takes no effect
292+
self.groupby_params["as_index"] = True
293+
284294
def _call_dataframe(self, groupby, input_df):
285295
agg_df = build_mock_agg_result(
286296
groupby, self.groupby_params, self.raw_func, **self.raw_func_kw
@@ -291,13 +301,7 @@ def _call_dataframe(self, groupby, input_df):
291301
index_value.value.should_be_monotonic = True
292302

293303
# check whether as_index=False actually takes effect
294-
if isinstance(agg_df.index, pd.MultiIndex):
295-
# if MultiIndex, as_index=False definitely takes no effect
296-
self.groupby_params["as_index"] = True
297-
elif agg_df.index.name is not None:
298-
# if not MultiIndex and agg_df.index has a name
299-
# means as_index=False takes no effect
300-
self.groupby_params["as_index"] = True
304+
self._fix_as_index(agg_df.index)
301305

302306
# determine num of indices to group in intermediate steps
303307
self._index_levels = self._get_index_levels(groupby, agg_df.index)
@@ -315,6 +319,10 @@ def _call_series(self, groupby, in_series):
315319
agg_result = build_mock_agg_result(
316320
groupby, self.groupby_params, self.raw_func, **self.raw_func_kw
317321
)
322+
323+
# check whether as_index=False actually takes effect
324+
self._fix_as_index(agg_result.index)
325+
318326
index_value = parse_index(
319327
agg_result.index, groupby.key, groupby.index_value.key
320328
)

mars/dataframe/groupby/tests/test_groupby_execution.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,13 @@ def test_groupby_getitem(setup):
315315
expected.sort_values(["c1", "c2"]).reset_index(drop=True),
316316
)
317317

318+
r = mdf.groupby(["c1", "c2"], as_index=False)["c3"].agg(["sum"])
319+
expected = raw.groupby(["c1", "c2"], as_index=False)["c3"].agg(["sum"])
320+
pd.testing.assert_frame_equal(
321+
r.execute().fetch().sort_values(["c1", "c2"]),
322+
expected.sort_values(["c1", "c2"]),
323+
)
324+
318325

319326
def test_dataframe_groupby_agg(setup):
320327
agg_funs = [

mars/dataframe/reduction/core.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -860,12 +860,12 @@ def add_function(self, func, ndim, cols=None, func_name=None):
860860
self._output_key_to_post_steps[step.output_key] = step
861861
self._update_col_dict(self._output_key_to_post_cols, step.output_key, cols)
862862

863-
@functools.lru_cache(100)
864-
def _compile_expr_function(self, py_src):
863+
def _compile_expr_function(self, py_src: str, local_consts: dict):
865864
from ... import tensor, dataframe
866865

867866
result_store = dict()
868-
global_vars = globals()
867+
global_vars = globals().copy()
868+
global_vars.update(local_consts)
869869
global_vars.update(dict(mt=tensor, md=dataframe, array=np.array, nan=np.nan))
870870
exec(
871871
py_src, global_vars, result_store
@@ -989,24 +989,24 @@ def _compile_function(self, func, func_name=None, ndim=1) -> ReductionSteps:
989989
assert len(initial_inputs) == 1
990990
input_key = initial_inputs[0].key
991991

992-
func_str, _ = self._generate_function_str(t.inputs[0])
992+
func_str, _, local_consts = self._generate_function_str(t.inputs[0])
993993
pre_funcs.append(
994994
ReductionPreStep(
995995
input_key,
996996
agg_input_key,
997997
None,
998-
self._compile_expr_function(func_str),
998+
self._compile_expr_function(func_str, local_consts),
999999
)
10001000
)
10011001
# collect function output after agg
1002-
func_str, input_keys = self._generate_function_str(func_ret)
1002+
func_str, input_keys, local_consts = self._generate_function_str(func_ret)
10031003
post_funcs.append(
10041004
ReductionPostStep(
10051005
input_keys,
10061006
func_ret.key,
10071007
func_name,
10081008
None,
1009-
self._compile_expr_function(func_str),
1009+
self._compile_expr_function(func_str, local_consts),
10101010
)
10111011
)
10121012
if len(_func_compile_cache) > 100: # pragma: no cover
@@ -1034,6 +1034,7 @@ def _generate_function_str(self, out_tileable):
10341034

10351035
input_key_to_var = OrderedDict()
10361036
local_key_to_var = dict()
1037+
local_consts_to_val = dict()
10371038
ref_counts = dict()
10381039
ref_visited = set()
10391040
local_lines = []
@@ -1086,7 +1087,12 @@ def _interpret_var(v):
10861087
# get representation for variables
10871088
if hasattr(v, "key"):
10881089
return keys_to_vars[v.key]
1089-
return v
1090+
elif isinstance(v, (int, bool, str, bytes, np.integer, np.bool_)):
1091+
return repr(v)
1092+
else:
1093+
const_name = f"_const_{len(local_consts_to_val)}"
1094+
local_consts_to_val[const_name] = v
1095+
return const_name
10901096

10911097
func_name = func_name_raw = getattr(t.op, "_func_name", None)
10921098
rfunc_name = getattr(t.op, "_rfunc_name", func_name)
@@ -1187,6 +1193,7 @@ def _interpret_var(v):
11871193
f" {lines_str}\n"
11881194
f" return {local_key_to_var[out_tileable.key]}",
11891195
list(input_key_to_var.keys()),
1196+
local_consts_to_val,
11901197
)
11911198

11921199
def compile(self) -> ReductionSteps:

0 commit comments

Comments
 (0)