Skip to content

Commit 81e4d64

Browse files
fix: Fix bug with DataFrame.agg for string values (#1870)
1 parent 1c45ccb commit 81e4d64

File tree

3 files changed

+82
-17
lines changed

3 files changed

+82
-17
lines changed

bigframes/core/blocks.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,7 +2030,7 @@ def _generate_resample_label(
20302030
return block.set_index([resample_label_id])
20312031

20322032
def _create_stack_column(self, col_label: typing.Tuple, stack_labels: pd.Index):
2033-
dtype = None
2033+
input_dtypes = []
20342034
input_columns: list[Optional[str]] = []
20352035
for uvalue in utils.index_as_tuples(stack_labels):
20362036
label_to_match = (*col_label, *uvalue)
@@ -2040,15 +2040,18 @@ def _create_stack_column(self, col_label: typing.Tuple, stack_labels: pd.Index):
20402040
matching_ids = self.label_to_col_id.get(label_to_match, [])
20412041
input_id = matching_ids[0] if len(matching_ids) > 0 else None
20422042
if input_id:
2043-
if dtype and dtype != self._column_type(input_id):
2044-
raise NotImplementedError(
2045-
"Cannot stack columns with non-matching dtypes."
2046-
)
2047-
else:
2048-
dtype = self._column_type(input_id)
2043+
input_dtypes.append(self._column_type(input_id))
20492044
input_columns.append(input_id)
20502045
# Input column i is the first one that matches stack label i (None if no column matches)
2051-
return tuple(input_columns), dtype or pd.Float64Dtype()
2046+
if len(input_dtypes) > 0:
2047+
output_dtype = bigframes.dtypes.lcd_type(*input_dtypes)
2048+
if output_dtype is None:
2049+
raise NotImplementedError(
2050+
"Cannot stack columns with non-matching dtypes."
2051+
)
2052+
else:
2053+
output_dtype = pd.Float64Dtype()
2054+
return tuple(input_columns), output_dtype
20522055

20532056
def _column_type(self, col_id: str) -> bigframes.dtypes.Dtype:
20542057
col_offset = self.value_columns.index(col_id)

bigframes/dataframe.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3004,14 +3004,44 @@ def agg(
30043004
if utils.is_dict_like(func):
30053005
# Must check dict-like first because dictionaries are list-like
30063006
# according to Pandas.
3007-
agg_cols = []
3008-
for col_label, agg_func in func.items():
3009-
agg_cols.append(self[col_label].agg(agg_func))
3010-
3011-
from bigframes.core.reshape import api as reshape
3012-
3013-
return reshape.concat(agg_cols, axis=1)
30143007

3008+
aggs = []
3009+
labels = []
3010+
funcnames = []
3011+
for col_label, agg_func in func.items():
3012+
agg_func_list = agg_func if utils.is_list_like(agg_func) else [agg_func]
3013+
col_id = self._block.resolve_label_exact(col_label)
3014+
if col_id is None:
3015+
raise KeyError(f"Column {col_label} does not exist")
3016+
for agg_func in agg_func_list:
3017+
agg_op = agg_ops.lookup_agg_func(typing.cast(str, agg_func))
3018+
agg_expr = (
3019+
ex.UnaryAggregation(agg_op, ex.deref(col_id))
3020+
if isinstance(agg_op, agg_ops.UnaryAggregateOp)
3021+
else ex.NullaryAggregation(agg_op)
3022+
)
3023+
aggs.append(agg_expr)
3024+
labels.append(col_label)
3025+
funcnames.append(agg_func)
3026+
3027+
# if any list in dict values, format output differently
3028+
if any(utils.is_list_like(v) for v in func.values()):
3029+
new_index, _ = self.columns.reindex(labels)
3030+
new_index = utils.combine_indices(new_index, pandas.Index(funcnames))
3031+
agg_block, _ = self._block.aggregate(
3032+
aggregations=aggs, column_labels=new_index
3033+
)
3034+
return DataFrame(agg_block).stack().droplevel(0, axis="index")
3035+
else:
3036+
new_index, _ = self.columns.reindex(labels)
3037+
agg_block, _ = self._block.aggregate(
3038+
aggregations=aggs, column_labels=new_index
3039+
)
3040+
return bigframes.series.Series(
3041+
agg_block.transpose(
3042+
single_row_mode=True, original_row_index=pandas.Index([None])
3043+
)
3044+
)
30153045
elif utils.is_list_like(func):
30163046
aggregations = [agg_ops.lookup_agg_func(f) for f in func]
30173047

@@ -3027,7 +3057,7 @@ def agg(
30273057
)
30283058
)
30293059

3030-
else:
3060+
else: # function name string
30313061
return bigframes.series.Series(
30323062
self._block.aggregate_all_and_stack(
30333063
agg_ops.lookup_agg_func(typing.cast(str, func))

tests/system/small/test_dataframe.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5538,7 +5538,7 @@ def test_astype_invalid_type_fail(scalars_dfs):
55385538
bf_df.astype(123)
55395539

55405540

5541-
def test_agg_with_dict(scalars_dfs):
5541+
def test_agg_with_dict_lists(scalars_dfs):
55425542
bf_df, pd_df = scalars_dfs
55435543
agg_funcs = {
55445544
"int64_too": ["min", "max"],
@@ -5553,6 +5553,38 @@ def test_agg_with_dict(scalars_dfs):
55535553
)
55545554

55555555

5556+
def test_agg_with_dict_list_and_str(scalars_dfs):
5557+
bf_df, pd_df = scalars_dfs
5558+
agg_funcs = {
5559+
"int64_too": ["min", "max"],
5560+
"int64_col": "sum",
5561+
}
5562+
5563+
bf_result = bf_df.agg(agg_funcs).to_pandas()
5564+
pd_result = pd_df.agg(agg_funcs)
5565+
5566+
pd.testing.assert_frame_equal(
5567+
bf_result, pd_result, check_dtype=False, check_index_type=False
5568+
)
5569+
5570+
5571+
def test_agg_with_dict_strs(scalars_dfs):
5572+
bf_df, pd_df = scalars_dfs
5573+
agg_funcs = {
5574+
"int64_too": "min",
5575+
"int64_col": "sum",
5576+
"float64_col": "max",
5577+
}
5578+
5579+
bf_result = bf_df.agg(agg_funcs).to_pandas()
5580+
pd_result = pd_df.agg(agg_funcs)
5581+
pd_result.index = pd_result.index.astype("string[pyarrow]")
5582+
5583+
pd.testing.assert_series_equal(
5584+
bf_result, pd_result, check_dtype=False, check_index_type=False
5585+
)
5586+
5587+
55565588
def test_agg_with_dict_containing_non_existing_col_raise_key_error(scalars_dfs):
55575589
bf_df, _ = scalars_dfs
55585590
agg_funcs = {

0 commit comments

Comments
 (0)