diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 85b8245272..b2947f7493 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -26,6 +26,7 @@ import traceback import typing from typing import ( + Any, Callable, Dict, Hashable, @@ -91,6 +92,7 @@ import bigframes.session SingleItemValue = Union[bigframes.series.Series, int, float, str, Callable] + MultiItemValue = Union["DataFrame", Sequence[int | float | str | Callable]] LevelType = typing.Hashable LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]] @@ -884,8 +886,13 @@ def __delitem__(self, key: str): df = self.drop(columns=[key]) self._set_block(df._get_block()) - def __setitem__(self, key: str, value: SingleItemValue): - df = self._assign_single_item(key, value) + def __setitem__( + self, key: str | list[str], value: SingleItemValue | MultiItemValue + ): + if isinstance(key, list): + df = self._assign_multi_items(key, value) + else: + df = self._assign_single_item(key, value) self._set_block(df._get_block()) __setitem__.__doc__ = inspect.getdoc(vendored_pandas_frame.DataFrame.__setitem__) @@ -2212,7 +2219,7 @@ def assign(self, **kwargs) -> DataFrame: def _assign_single_item( self, k: str, - v: SingleItemValue, + v: SingleItemValue | MultiItemValue, ) -> DataFrame: if isinstance(v, bigframes.series.Series): return self._assign_series_join_on_index(k, v) @@ -2230,7 +2237,33 @@ def _assign_single_item( elif utils.is_list_like(v): return self._assign_single_item_listlike(k, v) else: - return self._assign_scalar(k, v) + return self._assign_scalar(k, v) # type: ignore + + def _assign_multi_items( + self, + k: list[str], + v: SingleItemValue | MultiItemValue, + ) -> DataFrame: + value_sources: Sequence[Any] = [] + if isinstance(v, DataFrame): + value_sources = [v[col] for col in v.columns] + elif isinstance(v, bigframes.series.Series): + # For behavior consistency with Pandas. + raise ValueError("Columns must be same length as key") + elif isinstance(v, Sequence): + value_sources = v + else: + # We assign the same scalar value to all target columns. + value_sources = [v] * len(k) + + if len(value_sources) != len(k): + raise ValueError("Columns must be same length as key") + + # Repeatedly assign columns in order. + result = self._assign_single_item(k[0], value_sources[0]) + for target, source in zip(k[1:], value_sources[1:]): + result = result._assign_single_item(target, source) + return result def _assign_single_item_listlike(self, k: str, v: Sequence) -> DataFrame: given_rows = len(v) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 51f4674ba4..c7f9627531 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -1138,6 +1138,67 @@ def test_assign_new_column_w_setitem_list_error(scalars_dfs): bf_df["new_col"] = [1, 2, 3] +@pytest.mark.parametrize( + ("key", "value"), + [ + pytest.param(["int64_col", "int64_too"], 1, id="scalar_to_existing_column"), + pytest.param( + ["int64_col", "int64_too"], [1, 2], id="sequence_to_existing_column" + ), + pytest.param( + ["int64_col", "new_col"], [1, 2], id="sequence_to_partial_new_column" + ), + pytest.param( + ["new_col", "new_col_too"], [1, 2], id="sequence_to_full_new_column" + ), + ], +) +def test_setitem_multicolumn_with_literals(scalars_dfs, key, value): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df.copy() + pd_result = scalars_pandas_df.copy() + + bf_result[key] = value + pd_result[key] = value + + pd.testing.assert_frame_equal(pd_result, bf_result.to_pandas(), check_dtype=False) + + +def test_setitem_multicolumn_with_literals_different_lengths_raise_error(scalars_dfs): + scalars_df, _ = scalars_dfs + bf_result = scalars_df.copy() + + with pytest.raises(ValueError): + bf_result[["int64_col", "int64_too"]] = [1] + + +def test_setitem_multicolumn_with_dataframes(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df.copy() + pd_result = scalars_pandas_df.copy() + + bf_result[["int64_col", "int64_too"]] = bf_result[["int64_too", "int64_col"]] / 2 + pd_result[["int64_col", "int64_too"]] = pd_result[["int64_too", "int64_col"]] / 2 + + pd.testing.assert_frame_equal(pd_result, bf_result.to_pandas(), check_dtype=False) + + +def test_setitem_multicolumn_with_dataframes_series_on_rhs_raise_error(scalars_dfs): + scalars_df, _ = scalars_dfs + bf_result = scalars_df.copy() + + with pytest.raises(ValueError): + bf_result[["int64_col", "int64_too"]] = bf_result["int64_col"] / 2 + + +def test_setitem_multicolumn_with_dataframes_different_lengths_raise_error(scalars_dfs): + scalars_df, _ = scalars_dfs + bf_result = scalars_df.copy() + + with pytest.raises(ValueError): + bf_result[["int64_col"]] = bf_result[["int64_col", "int64_too"]] / 2 + + def test_assign_existing_column(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs kwargs = {"int64_col": 2} diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 44ca558070..953ece9beb 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -7626,11 +7626,43 @@ def __setitem__(self, key, value): [3 rows x 5 columns] + You can assign a scalar to multiple columns. + + >>> df[["age", "new_age"]] = 25 + >>> df + name age location country new_age + 0 alpha 25 WA USA 25 + 1 beta 25 NY USA 25 + 2 gamma 25 CA USA 25 + + [3 rows x 5 columns] + + You can use a sequence of scalars for assignment of multiple columns: + + >>> df[["age", "is_happy"]] = [20, True] + >>> df + name age location country new_age is_happy + 0 alpha 20 WA USA 25 True + 1 beta 20 NY USA 25 True + 2 gamma 20 CA USA 25 True + + [3 rows x 6 columns] + + You can use a dataframe for assignment of multiple columns: + >>> df[["age", "new_age"]] = df[["new_age", "age"]] + >>> df + name age location country new_age is_happy + 0 alpha 25 WA USA 20 True + 1 beta 25 NY USA 20 True + 2 gamma 25 CA USA 20 True + + [3 rows x 6 columns] + Args: key (column index): It can be a new column to be inserted, or an existing column to be modified. - value (scalar or Series): + value (scalar, Sequence, DataFrame, or Series): Value to be assigned to the column """ raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)