Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion python/pyspark/pandas/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pandas.api.types import is_list_like
import numpy as np

from pyspark.loose_version import LooseVersion
from pyspark.sql import functions as F, Column as PySparkColumn
from pyspark.sql.types import BooleanType, LongType, DataType
from pyspark.sql.utils import is_remote
Expand Down Expand Up @@ -720,7 +721,13 @@ def __setitem__(self, key: Any, value: Any) -> None:

cond, limit, remaining_index = self._select_rows(rows_sel)
missing_keys: List[Name] = []
_, data_spark_columns, _, _, _ = self._select_cols(cols_sel, missing_keys=missing_keys)
(
selected_column_labels,
data_spark_columns,
_,
_,
_,
) = self._select_cols(cols_sel, missing_keys=missing_keys)

if cond is None:
cond = F.lit(True)
Expand All @@ -737,6 +744,30 @@ def __setitem__(self, key: Any, value: Any) -> None:
if isinstance(value, Series):
value = value.spark.column
else:
if (
# Only apply this behavior for pandas 3+, where CoW semantics changed.
LooseVersion(pd.__version__) >= "3.0.0"
# Only for multi-column assignment (single-column assignment is unaffected).
and len(selected_column_labels) > 1
# Column selector must be list-like (e.g. ["shield", "max_speed"]), not scalar label access.
and is_list_like(cols_sel)
# Excludes string/bytes (single label), tuple (e.g. MultiIndex label),
# and slice selectors; keeps this narrowly on explicit column lists.
and not isinstance(cols_sel, (str, bytes, tuple, slice))
# Only trigger when cached/anchored Series exist on the frame,
# matching the problematic case where views were materialized before assignment.
and hasattr(self._psdf_or_psser, "_psseries")
):
selected_column_labels_set = set(selected_column_labels)
selected_labels_in_internal_order = [
label
for label in self._internal.column_labels
if label in selected_column_labels_set
]
if selected_column_labels != selected_labels_in_internal_order:
# If requested columns are in different order than the DataFrame’s internal order,
# it returns early (no-op), matching pandas 3 behavior for that edge case.
return
value = F.lit(value)

new_data_spark_columns = []
Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,10 @@ def __init__( # type: ignore[no-untyped-def]
assert not copy
assert fastpath is no_default

self._anchor = data
if LooseVersion(pd.__version__) < "3.0.0":
self._anchor = data
else:
self._anchor = DataFrame(data)
self._col_label = index

elif isinstance(data, Series):
Expand Down
8 changes: 5 additions & 3 deletions python/pyspark/pandas/tests/indexes/test_indexing_iloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pandas as pd

from pyspark import pandas as ps
from pyspark.loose_version import LooseVersion
from pyspark.pandas.exceptions import SparkPandasIndexingError, SparkPandasNotImplementedError
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
Expand Down Expand Up @@ -180,9 +181,10 @@ def test_frame_iloc_setitem(self):
)
psdf = ps.from_pandas(pdf)

pdf.iloc[:, 0] = pdf
psdf.iloc[:, 0] = psdf
self.assert_eq(psdf, pdf)
if LooseVersion(pd.__version__) < "3.0.0":
pdf.iloc[:, 0] = pdf
psdf.iloc[:, 0] = psdf
self.assert_eq(psdf, pdf)

def test_series_iloc_setitem(self):
pdf = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}, index=["cobra", "viper", "sidewinder"])
Expand Down
159 changes: 99 additions & 60 deletions python/pyspark/pandas/tests/indexes/test_indexing_loc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pandas as pd

from pyspark import pandas as ps
from pyspark.loose_version import LooseVersion
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils

Expand Down Expand Up @@ -240,54 +241,65 @@ def test_loc_timestamp_str(self):
self.assert_eq(pdf.B.loc["2011":"2015"], psdf.B.loc["2011":"2015"])

def test_frame_loc_setitem(self):
def check(op, check_ser, almost):
    """Apply *op* to twin pandas / pandas-on-Spark frames and compare results.

    When ``check_ser`` is true, column Series are materialized *before* the
    mutation so the comparison also verifies that previously anchored Series
    reflect (or are unaffected by) the in-place update.
    """
    pdf = pd.DataFrame(
        [[1, 2], [4, 5], [7, 8]],
        index=["cobra", "viper", "sidewinder"],
        columns=["max_speed", "shield"],
    )
    psdf = ps.from_pandas(pdf)

    # Capture Series views up front only when requested.
    tracked = []
    if check_ser:
        tracked = [
            (psdf.max_speed, pdf.max_speed),
            (psdf.shield, pdf.shield),
        ]

    op(pdf)
    op(psdf)

    self.assert_eq(psdf, pdf, almost=almost)
    for psser, pser in tracked:
        self.assert_eq(psser, pser)

def op0(df):
    """Set both columns to 10 for rows 'viper' and 'sidewinder'."""
    rows = ["viper", "sidewinder"]
    cols = ["max_speed", "shield"]
    df.loc[rows, cols] = 10

def op1(df):
    """Same as op0 but with the column selector in reversed order."""
    rows = ["viper", "sidewinder"]
    cols = ["shield", "max_speed"]
    df.loc[rows, cols] = 10

def op2(df):
    """Set the single 'shield' column to 50 for two rows."""
    rows = ["viper", "sidewinder"]
    df.loc[rows, "shield"] = 50

def op3(df):
    """Scalar cell assignment: one row label, one column label."""
    row, col = "cobra", "max_speed"
    df.loc[row, col] = 30

def op4(df):
    """Negate 'max_speed' on rows where it is below 5 (boolean-mask setitem)."""
    slow = df.max_speed < 5
    df.loc[slow, "max_speed"] = -df.max_speed

def op5(df):
    """Negate 'max_speed' on rows where it is below 2 (single-row mask hit)."""
    mask = df.max_speed < 2
    df.loc[mask, "max_speed"] = -df.max_speed

def op6(df):
    """Create a brand-new 'min_speed' column, filled with 0, for every row."""
    new_col = "min_speed"
    df.loc[:, new_col] = 0

# Run every loc-setitem variant twice: once with Series views captured
# before the mutation (check_ser=True) and once without — see check().
for check_ser in [True, False]:
    for op in [op0, op1, op2, op3, op4, op5, (op6, True)]:
        # A (callable, True) tuple marks an op whose frame comparison
        # needs assert_eq(..., almost=True); only op6, which adds a new
        # column, is marked.
        if isinstance(op, tuple):
            op, almost = op
        else:
            op, almost = op, False
        with self.subTest(check_ser=check_ser, op=op.__name__):
            check(op, check_ser=check_ser, almost=almost)

pdf = pd.DataFrame(
[[1, 2], [4, 5], [7, 8]],
index=["cobra", "viper", "sidewinder"],
columns=["max_speed", "shield"],
)
psdf = ps.from_pandas(pdf)

pser1 = pdf.max_speed
pser2 = pdf.shield
psser1 = psdf.max_speed
psser2 = psdf.shield

pdf.loc[["viper", "sidewinder"], ["shield", "max_speed"]] = 10
psdf.loc[["viper", "sidewinder"], ["shield", "max_speed"]] = 10
self.assert_eq(psdf, pdf)
self.assert_eq(psser1, pser1)
self.assert_eq(psser2, pser2)

pdf.loc[["viper", "sidewinder"], "shield"] = 50
psdf.loc[["viper", "sidewinder"], "shield"] = 50
self.assert_eq(psdf, pdf)
self.assert_eq(psser1, pser1)
self.assert_eq(psser2, pser2)

pdf.loc["cobra", "max_speed"] = 30
psdf.loc["cobra", "max_speed"] = 30
self.assert_eq(psdf, pdf)
self.assert_eq(psser1, pser1)
self.assert_eq(psser2, pser2)

pdf.loc[pdf.max_speed < 5, "max_speed"] = -pdf.max_speed
psdf.loc[psdf.max_speed < 5, "max_speed"] = -psdf.max_speed
self.assert_eq(psdf, pdf)
self.assert_eq(psser1, pser1)
self.assert_eq(psser2, pser2)

pdf.loc[pdf.max_speed < 2, "max_speed"] = -pdf.max_speed
psdf.loc[psdf.max_speed < 2, "max_speed"] = -psdf.max_speed
self.assert_eq(psdf, pdf)
self.assert_eq(psser1, pser1)
self.assert_eq(psser2, pser2)

pdf.loc[:, "min_speed"] = 0
psdf.loc[:, "min_speed"] = 0
self.assert_eq(psdf, pdf, almost=True)
self.assert_eq(psser1, pser1)
self.assert_eq(psser2, pser2)

with self.assertRaisesRegex(ValueError, "Incompatible indexer with Series"):
psdf.loc["cobra", "max_speed"] = -psdf.max_speed
with self.assertRaisesRegex(ValueError, "shape mismatch"):
Expand All @@ -296,23 +308,49 @@ def test_frame_loc_setitem(self):
psdf.loc[:, "max_speed"] = psdf

# multi-index columns
columns = pd.MultiIndex.from_tuples(
[("x", "max_speed"), ("x", "shield"), ("y", "min_speed")]
)
pdf.columns = columns
psdf.columns = columns
def check(op, check_ser):
    """MultiIndex-column variant of the setitem checker.

    Builds twin pandas / pandas-on-Spark frames with two-level columns,
    applies *op* to both, and compares. With ``check_ser`` true, two column
    Series are captured before the mutation and re-compared afterwards.
    Frame comparison always uses ``almost=True``.
    """
    pdf = pd.DataFrame(
        [[1, 2, 0], [4, 5, 0], [7, 8, 0]],
        index=["cobra", "viper", "sidewinder"],
        columns=pd.MultiIndex.from_tuples(
            [("x", "max_speed"), ("x", "shield"), ("y", "min_speed")]
        ),
    )
    psdf = ps.from_pandas(pdf)

    tracked = []
    if check_ser:
        tracked = [
            (psdf[("x", "max_speed")], pdf[("x", "max_speed")]),
            (psdf[("x", "shield")], pdf[("x", "shield")]),
        ]

    op(pdf)
    op(psdf)

    self.assert_eq(psdf, pdf, almost=True)
    for psser, pser in tracked:
        self.assert_eq(psser, pser)

def mop0(df):
    """Create column ('y', 'shield') as the negation of ('x', 'shield')."""
    source = df[("x", "shield")]
    df.loc[:, ("y", "shield")] = -source

def mop1(df):
    """Assign 100 under a single top-level label on MultiIndex columns."""
    label = "z"
    df.loc[:, label] = 100

# Exercise both MultiIndex-column setitem ops, with and without Series
# views captured before the mutation (check_ser toggles that in check()).
for check_ser in [True, False]:
    for op in [mop0, mop1]:
        with self.subTest(check_ser=check_ser, op=op.__name__):
            check(op, check_ser=check_ser)

pdf.loc[:, ("y", "shield")] = -pdf[("x", "shield")]
psdf.loc[:, ("y", "shield")] = -psdf[("x", "shield")]
self.assert_eq(psdf, pdf, almost=True)
self.assert_eq(psser1, pser1)
self.assert_eq(psser2, pser2)

pdf.loc[:, "z"] = 100
psdf.loc[:, "z"] = 100
self.assert_eq(psdf, pdf, almost=True)
self.assert_eq(psser1, pser1)
self.assert_eq(psser2, pser2)
pdf = pd.DataFrame(
[[1, 2, 0], [4, 5, 0], [7, 8, 0]],
index=["cobra", "viper", "sidewinder"],
columns=pd.MultiIndex.from_tuples(
[("x", "max_speed"), ("x", "shield"), ("y", "min_speed")]
),
)
psdf = ps.from_pandas(pdf)

with self.assertRaisesRegex(KeyError, "Key length \\(3\\) exceeds index depth \\(2\\)"):
psdf.loc[:, [("x", "max_speed", "foo")]] = -psdf[("x", "shield")]
Expand All @@ -322,9 +360,10 @@ def test_frame_loc_setitem(self):
)
psdf = ps.from_pandas(pdf)

pdf.loc[:, "max_speed"] = pdf
psdf.loc[:, "max_speed"] = psdf
self.assert_eq(psdf, pdf)
if LooseVersion(pd.__version__) < "3.0.0":
pdf.loc[:, "max_speed"] = pdf
psdf.loc[:, "max_speed"] = psdf
self.assert_eq(psdf, pdf)

def test_series_loc_setitem(self):
pdf = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}, index=["cobra", "viper", "sidewinder"])
Expand Down