Skip to content

Commit 1f758c2

Browse files
committed
[SPARK-55296][PS] Support CoW mode with pandas 3
### What changes were proposed in this pull request?

Support CoW (Copy-on-Write) mode with pandas 3.

### Why are the changes needed?

pandas 3 applies copy-on-write to everything. For example:

```py
>>> pdf = pd.DataFrame(
...     [[1, 2], [4, 5], [7, 8]],
...     index=["cobra", "viper", "sidewinder"],
...     columns=["max_speed", "shield"],
... )
>>>
>>> pser1 = pdf.max_speed
>>> pser2 = pdf.shield
>>>
>>> pdf.loc[["viper", "sidewinder"], ["max_speed", "shield"]] = 10
```

- pandas 2

```py
>>> pdf
            max_speed  shield
cobra               1       2
viper              10      10
sidewinder         10      10
>>> pser1
cobra          1
viper         10
sidewinder    10
Name: max_speed, dtype: int64
>>> pser2
cobra          2
viper         10
sidewinder    10
Name: shield, dtype: int64
```

- pandas 3

```py
>>> pdf
            max_speed  shield
cobra               1       2
viper              10      10
sidewinder         10      10
>>> pser1
cobra         1
viper         4
sidewinder    7
Name: max_speed, dtype: int64
>>> pser2
cobra         2
viper         5
sidewinder    8
Name: shield, dtype: int64
```

Or for `Series`:

```py
>>> pdf = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}, index=["cobra", "viper", "sidewinder"])
>>>
>>> pser = pdf.x
>>> psery = pdf.y
>>>
>>> pser.loc[pser % 2 == 1] = -pser
```

- pandas 2

```py
>>> pdf
            x  y
cobra      -1  4
viper       2  5
sidewinder -3  6
>>> pser
cobra        -1
viper         2
sidewinder   -3
Name: x, dtype: int64
>>> psery
cobra         4
viper         5
sidewinder    6
Name: y, dtype: int64
```

- pandas 3

```py
>>> pdf
            x  y
cobra       1  4
viper       2  5
sidewinder  3  6
>>> pser
cobra        -1
viper         2
sidewinder   -3
Name: x, dtype: int64
>>> psery
cobra         4
viper         5
sidewinder    6
Name: y, dtype: int64
```

### Does this PR introduce _any_ user-facing change?

Yes, it will behave more like pandas 3.

### How was this patch tested?

Updated the related tests to make the expected behavior explicit; otherwise, the existing tests should pass.

### Was this patch authored or co-authored using generative AI tooling?

Codex (GPT-5.3-Codex)

Closes #54375 from ueshin/issues/SPARK-55296/cow.

Authored-by: Takuya Ueshin <ueshin@databricks.com>
Signed-off-by: Takuya Ueshin <ueshin@databricks.com>
1 parent 6806c8b commit 1f758c2

File tree

4 files changed

+140
-65
lines changed

4 files changed

+140
-65
lines changed

python/pyspark/pandas/indexing.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pandas.api.types import is_list_like
2828
import numpy as np
2929

30+
from pyspark.loose_version import LooseVersion
3031
from pyspark.sql import functions as F, Column as PySparkColumn
3132
from pyspark.sql.types import BooleanType, LongType, DataType
3233
from pyspark.sql.utils import is_remote
@@ -720,7 +721,13 @@ def __setitem__(self, key: Any, value: Any) -> None:
720721

721722
cond, limit, remaining_index = self._select_rows(rows_sel)
722723
missing_keys: List[Name] = []
723-
_, data_spark_columns, _, _, _ = self._select_cols(cols_sel, missing_keys=missing_keys)
724+
(
725+
selected_column_labels,
726+
data_spark_columns,
727+
_,
728+
_,
729+
_,
730+
) = self._select_cols(cols_sel, missing_keys=missing_keys)
724731

725732
if cond is None:
726733
cond = F.lit(True)
@@ -737,6 +744,30 @@ def __setitem__(self, key: Any, value: Any) -> None:
737744
if isinstance(value, Series):
738745
value = value.spark.column
739746
else:
747+
if (
748+
# Only apply this behavior for pandas 3+, where CoW semantics changed.
749+
LooseVersion(pd.__version__) >= "3.0.0"
750+
# Only for multi-column assignment (single-column assignment is unaffected).
751+
and len(selected_column_labels) > 1
752+
# Column selector must be list-like (e.g. ["shield", "max_speed"]), not scalar label access.
753+
and is_list_like(cols_sel)
754+
# Excludes string/bytes (single label), tuple (e.g. MultiIndex label),
755+
# and slice selectors; keeps this narrowly on explicit column lists.
756+
and not isinstance(cols_sel, (str, bytes, tuple, slice))
757+
# Only trigger when cached/anchored Series exist on the frame,
758+
# matching the problematic case where views were materialized before assignment.
759+
and hasattr(self._psdf_or_psser, "_psseries")
760+
):
761+
selected_column_labels_set = set(selected_column_labels)
762+
selected_labels_in_internal_order = [
763+
label
764+
for label in self._internal.column_labels
765+
if label in selected_column_labels_set
766+
]
767+
if selected_column_labels != selected_labels_in_internal_order:
768+
# If the requested columns are in a different order than the DataFrame's internal order,
769+
# it returns early (no-op), matching pandas 3 behavior for that edge case.
770+
return
740771
value = F.lit(value)
741772

742773
new_data_spark_columns = []

python/pyspark/pandas/series.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,10 @@ def __init__( # type: ignore[no-untyped-def]
430430
assert not copy
431431
assert fastpath is no_default
432432

433-
self._anchor = data
433+
if LooseVersion(pd.__version__) < "3.0.0":
434+
self._anchor = data
435+
else:
436+
self._anchor = DataFrame(data)
434437
self._col_label = index
435438

436439
elif isinstance(data, Series):

python/pyspark/pandas/tests/indexes/test_indexing_iloc.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pandas as pd
2020

2121
from pyspark import pandas as ps
22+
from pyspark.loose_version import LooseVersion
2223
from pyspark.pandas.exceptions import SparkPandasIndexingError, SparkPandasNotImplementedError
2324
from pyspark.testing.pandasutils import PandasOnSparkTestCase
2425
from pyspark.testing.sqlutils import SQLTestUtils
@@ -180,9 +181,10 @@ def test_frame_iloc_setitem(self):
180181
)
181182
psdf = ps.from_pandas(pdf)
182183

183-
pdf.iloc[:, 0] = pdf
184-
psdf.iloc[:, 0] = psdf
185-
self.assert_eq(psdf, pdf)
184+
if LooseVersion(pd.__version__) < "3.0.0":
185+
pdf.iloc[:, 0] = pdf
186+
psdf.iloc[:, 0] = psdf
187+
self.assert_eq(psdf, pdf)
186188

187189
def test_series_iloc_setitem(self):
188190
pdf = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}, index=["cobra", "viper", "sidewinder"])

python/pyspark/pandas/tests/indexes/test_indexing_loc.py

Lines changed: 99 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pandas as pd
2121

2222
from pyspark import pandas as ps
23+
from pyspark.loose_version import LooseVersion
2324
from pyspark.testing.pandasutils import PandasOnSparkTestCase
2425
from pyspark.testing.sqlutils import SQLTestUtils
2526

@@ -240,54 +241,65 @@ def test_loc_timestamp_str(self):
240241
self.assert_eq(pdf.B.loc["2011":"2015"], psdf.B.loc["2011":"2015"])
241242

242243
def test_frame_loc_setitem(self):
244+
def check(op, check_ser, almost):
245+
pdf = pd.DataFrame(
246+
[[1, 2], [4, 5], [7, 8]],
247+
index=["cobra", "viper", "sidewinder"],
248+
columns=["max_speed", "shield"],
249+
)
250+
psdf = ps.from_pandas(pdf)
251+
252+
if check_ser:
253+
pser1 = pdf.max_speed
254+
pser2 = pdf.shield
255+
psser1 = psdf.max_speed
256+
psser2 = psdf.shield
257+
258+
op(pdf)
259+
op(psdf)
260+
261+
self.assert_eq(psdf, pdf, almost=almost)
262+
if check_ser:
263+
self.assert_eq(psser1, pser1)
264+
self.assert_eq(psser2, pser2)
265+
266+
def op0(df):
267+
df.loc[["viper", "sidewinder"], ["max_speed", "shield"]] = 10
268+
269+
def op1(df):
270+
df.loc[["viper", "sidewinder"], ["shield", "max_speed"]] = 10
271+
272+
def op2(df):
273+
df.loc[["viper", "sidewinder"], "shield"] = 50
274+
275+
def op3(df):
276+
df.loc["cobra", "max_speed"] = 30
277+
278+
def op4(df):
279+
df.loc[df.max_speed < 5, "max_speed"] = -df.max_speed
280+
281+
def op5(df):
282+
df.loc[df.max_speed < 2, "max_speed"] = -df.max_speed
283+
284+
def op6(df):
285+
df.loc[:, "min_speed"] = 0
286+
287+
for check_ser in [True, False]:
288+
for op in [op0, op1, op2, op3, op4, op5, (op6, True)]:
289+
if isinstance(op, tuple):
290+
op, almost = op
291+
else:
292+
op, almost = op, False
293+
with self.subTest(check_ser=check_ser, op=op.__name__):
294+
check(op, check_ser=check_ser, almost=almost)
295+
243296
pdf = pd.DataFrame(
244297
[[1, 2], [4, 5], [7, 8]],
245298
index=["cobra", "viper", "sidewinder"],
246299
columns=["max_speed", "shield"],
247300
)
248301
psdf = ps.from_pandas(pdf)
249302

250-
pser1 = pdf.max_speed
251-
pser2 = pdf.shield
252-
psser1 = psdf.max_speed
253-
psser2 = psdf.shield
254-
255-
pdf.loc[["viper", "sidewinder"], ["shield", "max_speed"]] = 10
256-
psdf.loc[["viper", "sidewinder"], ["shield", "max_speed"]] = 10
257-
self.assert_eq(psdf, pdf)
258-
self.assert_eq(psser1, pser1)
259-
self.assert_eq(psser2, pser2)
260-
261-
pdf.loc[["viper", "sidewinder"], "shield"] = 50
262-
psdf.loc[["viper", "sidewinder"], "shield"] = 50
263-
self.assert_eq(psdf, pdf)
264-
self.assert_eq(psser1, pser1)
265-
self.assert_eq(psser2, pser2)
266-
267-
pdf.loc["cobra", "max_speed"] = 30
268-
psdf.loc["cobra", "max_speed"] = 30
269-
self.assert_eq(psdf, pdf)
270-
self.assert_eq(psser1, pser1)
271-
self.assert_eq(psser2, pser2)
272-
273-
pdf.loc[pdf.max_speed < 5, "max_speed"] = -pdf.max_speed
274-
psdf.loc[psdf.max_speed < 5, "max_speed"] = -psdf.max_speed
275-
self.assert_eq(psdf, pdf)
276-
self.assert_eq(psser1, pser1)
277-
self.assert_eq(psser2, pser2)
278-
279-
pdf.loc[pdf.max_speed < 2, "max_speed"] = -pdf.max_speed
280-
psdf.loc[psdf.max_speed < 2, "max_speed"] = -psdf.max_speed
281-
self.assert_eq(psdf, pdf)
282-
self.assert_eq(psser1, pser1)
283-
self.assert_eq(psser2, pser2)
284-
285-
pdf.loc[:, "min_speed"] = 0
286-
psdf.loc[:, "min_speed"] = 0
287-
self.assert_eq(psdf, pdf, almost=True)
288-
self.assert_eq(psser1, pser1)
289-
self.assert_eq(psser2, pser2)
290-
291303
with self.assertRaisesRegex(ValueError, "Incompatible indexer with Series"):
292304
psdf.loc["cobra", "max_speed"] = -psdf.max_speed
293305
with self.assertRaisesRegex(ValueError, "shape mismatch"):
@@ -296,23 +308,49 @@ def test_frame_loc_setitem(self):
296308
psdf.loc[:, "max_speed"] = psdf
297309

298310
# multi-index columns
299-
columns = pd.MultiIndex.from_tuples(
300-
[("x", "max_speed"), ("x", "shield"), ("y", "min_speed")]
301-
)
302-
pdf.columns = columns
303-
psdf.columns = columns
311+
def check(op, check_ser):
312+
pdf = pd.DataFrame(
313+
[[1, 2, 0], [4, 5, 0], [7, 8, 0]],
314+
index=["cobra", "viper", "sidewinder"],
315+
columns=pd.MultiIndex.from_tuples(
316+
[("x", "max_speed"), ("x", "shield"), ("y", "min_speed")]
317+
),
318+
)
319+
psdf = ps.from_pandas(pdf)
320+
321+
if check_ser:
322+
pser1 = pdf[("x", "max_speed")]
323+
pser2 = pdf[("x", "shield")]
324+
psser1 = psdf[("x", "max_speed")]
325+
psser2 = psdf[("x", "shield")]
326+
327+
op(pdf)
328+
op(psdf)
329+
330+
self.assert_eq(psdf, pdf, almost=True)
331+
if check_ser:
332+
self.assert_eq(psser1, pser1)
333+
self.assert_eq(psser2, pser2)
334+
335+
def mop0(df):
336+
df.loc[:, ("y", "shield")] = -df[("x", "shield")]
337+
338+
def mop1(df):
339+
df.loc[:, "z"] = 100
340+
341+
for check_ser in [True, False]:
342+
for op in [mop0, mop1]:
343+
with self.subTest(check_ser=check_ser, op=op.__name__):
344+
check(op, check_ser=check_ser)
304345

305-
pdf.loc[:, ("y", "shield")] = -pdf[("x", "shield")]
306-
psdf.loc[:, ("y", "shield")] = -psdf[("x", "shield")]
307-
self.assert_eq(psdf, pdf, almost=True)
308-
self.assert_eq(psser1, pser1)
309-
self.assert_eq(psser2, pser2)
310-
311-
pdf.loc[:, "z"] = 100
312-
psdf.loc[:, "z"] = 100
313-
self.assert_eq(psdf, pdf, almost=True)
314-
self.assert_eq(psser1, pser1)
315-
self.assert_eq(psser2, pser2)
346+
pdf = pd.DataFrame(
347+
[[1, 2, 0], [4, 5, 0], [7, 8, 0]],
348+
index=["cobra", "viper", "sidewinder"],
349+
columns=pd.MultiIndex.from_tuples(
350+
[("x", "max_speed"), ("x", "shield"), ("y", "min_speed")]
351+
),
352+
)
353+
psdf = ps.from_pandas(pdf)
316354

317355
with self.assertRaisesRegex(KeyError, "Key length \\(3\\) exceeds index depth \\(2\\)"):
318356
psdf.loc[:, [("x", "max_speed", "foo")]] = -psdf[("x", "shield")]
@@ -322,9 +360,10 @@ def test_frame_loc_setitem(self):
322360
)
323361
psdf = ps.from_pandas(pdf)
324362

325-
pdf.loc[:, "max_speed"] = pdf
326-
psdf.loc[:, "max_speed"] = psdf
327-
self.assert_eq(psdf, pdf)
363+
if LooseVersion(pd.__version__) < "3.0.0":
364+
pdf.loc[:, "max_speed"] = pdf
365+
psdf.loc[:, "max_speed"] = psdf
366+
self.assert_eq(psdf, pdf)
328367

329368
def test_series_loc_setitem(self):
330369
pdf = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}, index=["cobra", "viper", "sidewinder"])

0 commit comments

Comments
 (0)