Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import traceback
import typing
from typing import (
Any,
Callable,
Dict,
Hashable,
Expand Down Expand Up @@ -91,6 +92,7 @@
import bigframes.session

SingleItemValue = Union[bigframes.series.Series, int, float, str, Callable]
# Spelled with ``Union`` (matching ``SingleItemValue``) instead of ``|``: this
# is a runtime expression, and ``int | float`` between plain classes raises
# TypeError on Python versions older than 3.10.
MultiItemValue = Union["DataFrame", Sequence[Union[int, float, str, Callable]]]

LevelType = typing.Hashable
LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]]
Expand Down Expand Up @@ -884,8 +886,13 @@ def __delitem__(self, key: str):
df = self.drop(columns=[key])
self._set_block(df._get_block())

def __setitem__(self, key: str, value: SingleItemValue):
df = self._assign_single_item(key, value)
def __setitem__(
    self, key: str | list[str], value: SingleItemValue | MultiItemValue
):
    # A list key targets several columns at once; any other key names a
    # single column.
    assign = (
        self._assign_multi_items
        if isinstance(key, list)
        else self._assign_single_item
    )
    updated = assign(key, value)
    # Adopt the resulting block so the assignment is visible on this frame.
    self._set_block(updated._get_block())

__setitem__.__doc__ = inspect.getdoc(vendored_pandas_frame.DataFrame.__setitem__)
Expand Down Expand Up @@ -2212,7 +2219,7 @@ def assign(self, **kwargs) -> DataFrame:
def _assign_single_item(
self,
k: str,
v: SingleItemValue,
v: SingleItemValue | MultiItemValue,
) -> DataFrame:
if isinstance(v, bigframes.series.Series):
return self._assign_series_join_on_index(k, v)
Expand All @@ -2230,7 +2237,33 @@ def _assign_single_item(
elif utils.is_list_like(v):
return self._assign_single_item_listlike(k, v)
else:
return self._assign_scalar(k, v)
return self._assign_scalar(k, v) # type: ignore

def _assign_multi_items(
    self,
    k: list[str],
    v: SingleItemValue | MultiItemValue,
) -> DataFrame:
    """Assign ``v`` to every column named in ``k`` and return the new frame.

    Args:
        k: Target column labels, in assignment order.
        v: A DataFrame (its columns are assigned positionally to ``k``), a
            non-string sequence with one item per target column, or any
            other value, which is assigned as-is to every target column.

    Returns:
        The DataFrame produced by assigning the columns one after another.
        With an empty ``k`` no assignment happens and ``self`` is returned.

    Raises:
        ValueError: If ``v`` is a Series, or if the number of value sources
            differs from the number of target columns.
    """
    value_sources: Sequence[Any]
    if isinstance(v, DataFrame):
        value_sources = [v[col] for col in v.columns]
    elif isinstance(v, bigframes.series.Series):
        # For behavior consistency with Pandas.
        raise ValueError("Columns must be same length as key")
    elif isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
        # str/bytes are Sequences too, but pandas treats them as scalars;
        # without the exclusion "xy" would be split into "x" and "y".
        value_sources = v
    else:
        # We assign the same scalar value to all target columns.
        value_sources = [v] * len(k)

    if len(value_sources) != len(k):
        raise ValueError("Columns must be same length as key")

    # Repeatedly assign columns in order. Starting from ``self`` (instead of
    # indexing ``k[0]``) keeps an empty key list from raising IndexError.
    result = self
    for target, source in zip(k, value_sources):
        result = result._assign_single_item(target, source)
    return result

def _assign_single_item_listlike(self, k: str, v: Sequence) -> DataFrame:
given_rows = len(v)
Expand Down
61 changes: 61 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,67 @@ def test_assign_new_column_w_setitem_list_error(scalars_dfs):
bf_df["new_col"] = [1, 2, 3]


@pytest.mark.parametrize(
    ("key", "value"),
    [
        pytest.param(
            ["int64_col", "int64_too"],
            1,
            id="scalar_to_existing_column",
        ),
        pytest.param(
            ["int64_col", "int64_too"],
            [1, 2],
            id="sequence_to_existing_column",
        ),
        pytest.param(
            ["int64_col", "new_col"],
            [1, 2],
            id="sequence_to_partial_new_column",
        ),
        pytest.param(
            ["new_col", "new_col_too"],
            [1, 2],
            id="sequence_to_full_new_column",
        ),
    ],
)
def test_setitem_multicolumn_with_literals(scalars_dfs, key, value):
    """Multi-column ``__setitem__`` with literal values matches pandas."""
    # Work on copies so the shared fixture frames are not mutated.
    bf_df, pd_df = (frame.copy() for frame in scalars_dfs)

    bf_df[key] = value
    pd_df[key] = value

    actual = bf_df.to_pandas()
    pd.testing.assert_frame_equal(pd_df, actual, check_dtype=False)


def test_setitem_multicolumn_with_literals_different_lengths_raise_error(scalars_dfs):
    """One literal value for two target columns is rejected."""
    scalars_df, _ = scalars_dfs
    bf_result = scalars_df.copy()

    # Pin the message so an unrelated ValueError cannot satisfy the test.
    with pytest.raises(ValueError, match="Columns must be same length as key"):
        bf_result[["int64_col", "int64_too"]] = [1]


def test_setitem_multicolumn_with_dataframes(scalars_dfs):
    """Assigning a two-column DataFrame to two columns matches pandas."""
    # Work on copies so the shared fixture frames are not mutated.
    bf_df, pd_df = (frame.copy() for frame in scalars_dfs)

    # The right-hand side lists the columns in swapped order on purpose.
    bf_halved = bf_df[["int64_too", "int64_col"]] / 2
    bf_df[["int64_col", "int64_too"]] = bf_halved
    pd_halved = pd_df[["int64_too", "int64_col"]] / 2
    pd_df[["int64_col", "int64_too"]] = pd_halved

    actual = bf_df.to_pandas()
    pd.testing.assert_frame_equal(pd_df, actual, check_dtype=False)


def test_setitem_multicolumn_with_dataframes_series_on_rhs_raise_error(scalars_dfs):
    """A Series on the right-hand side of a multi-column assignment is rejected."""
    scalars_df, _ = scalars_dfs
    bf_result = scalars_df.copy()

    # Pin the message so an unrelated ValueError cannot satisfy the test.
    with pytest.raises(ValueError, match="Columns must be same length as key"):
        bf_result[["int64_col", "int64_too"]] = bf_result["int64_col"] / 2


def test_setitem_multicolumn_with_dataframes_different_lengths_raise_error(scalars_dfs):
    """A two-column DataFrame assigned to a single target column is rejected."""
    scalars_df, _ = scalars_dfs
    bf_result = scalars_df.copy()

    # Pin the message so an unrelated ValueError cannot satisfy the test.
    with pytest.raises(ValueError, match="Columns must be same length as key"):
        bf_result[["int64_col"]] = bf_result[["int64_col", "int64_too"]] / 2


def test_assign_existing_column(scalars_dfs):
scalars_df, scalars_pandas_df = scalars_dfs
kwargs = {"int64_col": 2}
Expand Down
34 changes: 33 additions & 1 deletion third_party/bigframes_vendored/pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7626,11 +7626,43 @@ def __setitem__(self, key, value):
<BLANKLINE>
[3 rows x 5 columns]

You can assign a scalar to multiple columns:

>>> df[["age", "new_age"]] = 25
>>> df
name age location country new_age
0 alpha 25 WA USA 25
1 beta 25 NY USA 25
2 gamma 25 CA USA 25
<BLANKLINE>
[3 rows x 5 columns]

You can use a sequence of scalars for assignment of multiple columns:

>>> df[["age", "is_happy"]] = [20, True]
>>> df
name age location country new_age is_happy
0 alpha 20 WA USA 25 True
1 beta 20 NY USA 25 True
2 gamma 20 CA USA 25 True
<BLANKLINE>
[3 rows x 6 columns]

You can use a dataframe for assignment of multiple columns:

>>> df[["age", "new_age"]] = df[["new_age", "age"]]
>>> df
name age location country new_age is_happy
0 alpha 25 WA USA 20 True
1 beta 25 NY USA 20 True
2 gamma 25 CA USA 20 True
<BLANKLINE>
[3 rows x 6 columns]

Args:
key (column index):
It can be a new column to be inserted, or an existing column to
be modified.
value (scalar or Series):
value (scalar, Sequence, DataFrame, or Series):
Value to be assigned to the column
"""
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)