Skip to content

Commit 47906c1

Browse files
authored
Merge pull request rapidsai#21221 from rapidsai/main
Forward-merge main into pandas3
2 parents 2d26afa + fcc552e commit 47906c1

File tree

8 files changed

+243
-299
lines changed

8 files changed

+243
-299
lines changed

python/cudf/cudf/core/dataframe.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,19 @@ def _setitem_tuple_arg(self, key, value):
438438

439439

440440
class _DataFrameAtIndexer(_DataFrameLocIndexer):
441-
pass
441+
@_performance_tracking
442+
def __getitem__(self, key):
443+
indexing_utils.validate_scalar_key(
444+
key, "Invalid call for scalar access (getting)!"
445+
)
446+
return super().__getitem__(key)
447+
448+
@_performance_tracking
449+
def __setitem__(self, key, value):
450+
indexing_utils.validate_scalar_key(
451+
key, "Invalid call for scalar access (getting)!"
452+
)
453+
return super().__setitem__(key, value)
442454

443455

444456
class _DataFrameIlocIndexer(_DataFrameIndexer):
@@ -508,7 +520,19 @@ def _setitem_tuple_arg(self, key, value):
508520

509521

510522
class _DataFrameiAtIndexer(_DataFrameIlocIndexer):
511-
pass
523+
@_performance_tracking
524+
def __getitem__(self, key):
525+
indexing_utils.validate_scalar_key(
526+
key, "iAt based indexing can only have integer indexers"
527+
)
528+
return super().__getitem__(key)
529+
530+
@_performance_tracking
531+
def __setitem__(self, key, value):
532+
indexing_utils.validate_scalar_key(
533+
key, "iAt based indexing can only have integer indexers"
534+
)
535+
return super().__setitem__(key, value)
512536

513537

514538
@_performance_tracking

python/cudf/cudf/core/indexing_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from cudf.api.types import (
1515
_is_scalar_or_zero_d_array,
1616
is_integer,
17+
is_list_like,
1718
)
1819
from cudf.core.column.column import as_column
1920
from cudf.core.copy_types import BooleanMask, GatherMap
@@ -69,6 +70,30 @@ class ScalarIndexer:
6970
)
7071

7172

73+
def validate_scalar_key(key: Any, error_msg: str) -> None:
74+
"""Validate that key contains only scalar values for .at/.iat indexers.
75+
76+
Parameters
77+
----------
78+
key : Any
79+
The key to validate
80+
error_msg : str
81+
The error message to raise if validation fails
82+
83+
Raises
84+
------
85+
ValueError
86+
If the key contains list-like indexers
87+
"""
88+
if not isinstance(key, tuple):
89+
if is_list_like(key):
90+
raise ValueError(error_msg)
91+
else:
92+
for k in key:
93+
if is_list_like(k):
94+
raise ValueError(error_msg)
95+
96+
7297
# Helpers for code-sharing between loc and iloc paths
7398
def expand_key(
7499
key: Any, frame: DataFrame | Series, method_type: Literal["iloc", "loc"]

python/cudf/cudf/core/series.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,19 @@ def __setitem__(self, key, value):
255255

256256

257257
class _SeriesiAtIndexer(_SeriesIlocIndexer):
258-
pass
258+
@_performance_tracking
259+
def __getitem__(self, key):
260+
indexing_utils.validate_scalar_key(
261+
key, "iAt based indexing can only have integer indexers"
262+
)
263+
return super().__getitem__(key)
264+
265+
@_performance_tracking
266+
def __setitem__(self, key, value):
267+
indexing_utils.validate_scalar_key(
268+
key, "iAt based indexing can only have integer indexers"
269+
)
270+
return super().__setitem__(key, value)
259271

260272

261273
class _SeriesLocIndexer(_FrameIndexer):
@@ -379,7 +391,19 @@ def _loc_to_iloc(self, arg):
379391

380392

381393
class _SeriesAtIndexer(_SeriesLocIndexer):
382-
pass
394+
@_performance_tracking
395+
def __getitem__(self, key):
396+
indexing_utils.validate_scalar_key(
397+
key, "Invalid call for scalar access (getting)!"
398+
)
399+
return super().__getitem__(key)
400+
401+
@_performance_tracking
402+
def __setitem__(self, key, value):
403+
indexing_utils.validate_scalar_key(
404+
key, "Invalid call for scalar access (getting)!"
405+
)
406+
return super().__setitem__(key, value)
383407

384408

385409
class Series(SingleColumnFrame, IndexedFrame):

python/cudf/cudf/pandas/_wrappers/pandas.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# with this module https://github.com/rapidsai/cudf/issues/14521#issue-2015198786
1919
import pyarrow.dataset as ds # noqa: F401
2020
from pandas._testing import at, getitem, iat, iloc, loc, setitem
21+
from pandas.compat._optional import import_optional_dependency
2122
from pandas.tseries.holiday import (
2223
AbstractHolidayCalendar as pd_AbstractHolidayCalendar,
2324
EasterMonday as pd_EasterMonday,
@@ -314,6 +315,20 @@ def _DataFrame_columns(self):
314315
return result
315316

316317

318+
def _to_xarray(self):
319+
# Call xarray conversion functions directly with self (the proxy object).
320+
# We must pass the proxy (self), not the slow pandas object, because xarray
321+
# does isinstance checks against pd.MultiIndex and pd.api.extensions.ExtensionArray.
322+
# After cudf.pandas.install(), these refer to proxy classes. The slow object
323+
# contains real pandas types that don't pass isinstance checks against the proxy
324+
# classes.
325+
xr = import_optional_dependency("xarray")
326+
if self.ndim == 1:
327+
return xr.DataArray.from_series(self)
328+
else:
329+
return xr.Dataset.from_dataframe(self)
330+
331+
317332
DataFrame = make_final_proxy_type(
318333
"DataFrame",
319334
cudf.DataFrame,
@@ -346,6 +361,7 @@ def _DataFrame_columns(self):
346361
"flags": _FastSlowAttribute("flags", private=True),
347362
"memory_usage": _FastSlowAttribute("memory_usage"),
348363
"__sizeof__": _FastSlowAttribute("__sizeof__"),
364+
"to_xarray": _to_xarray,
349365
},
350366
)
351367

@@ -419,6 +435,7 @@ def _argsort(self, *args, **kwargs):
419435
"_accessors": set(),
420436
"dtype": property(_Series_dtype),
421437
"argsort": _argsort,
438+
"to_xarray": _to_xarray,
422439
"attrs": _FastSlowAttribute("attrs"),
423440
"_mgr": _FastSlowAttribute("_mgr", private=True),
424441
"array": _FastSlowAttribute("array", private=True),
@@ -537,6 +554,7 @@ def Index__setattr__(self, name, value):
537554
fast_to_slow=lambda fast: fast,
538555
slow_to_fast=lambda slow: slow,
539556
additional_attributes={
557+
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
540558
"__from_arrow__": _FastSlowAttribute("__from_arrow__"),
541559
"__hash__": _FastSlowAttribute("__hash__"),
542560
"pyarrow_dtype": _FastSlowAttribute("pyarrow_dtype"),
@@ -575,6 +593,10 @@ def Index__setattr__(self, name, value):
575593
pd.Categorical,
576594
fast_to_slow=_Unusable(),
577595
slow_to_fast=_Unusable(),
596+
bases=(pd.api.extensions.ExtensionArray,),
597+
additional_attributes={
598+
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
599+
},
578600
)
579601

580602
CategoricalDtype = make_final_proxy_type(
@@ -614,7 +636,9 @@ def Index__setattr__(self, name, value):
614636
pd.arrays.DatetimeArray,
615637
fast_to_slow=_Unusable(),
616638
slow_to_fast=_Unusable(),
639+
bases=(pd.api.extensions.ExtensionArray,),
617640
additional_attributes={
641+
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
618642
"_data": _FastSlowAttribute("_data", private=True),
619643
"_mask": _FastSlowAttribute("_mask", private=True),
620644
},
@@ -687,7 +711,9 @@ def Index__setattr__(self, name, value):
687711
pd.arrays.TimedeltaArray,
688712
fast_to_slow=_Unusable(),
689713
slow_to_fast=_Unusable(),
714+
bases=(pd.api.extensions.ExtensionArray,),
690715
additional_attributes={
716+
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
691717
"_data": _FastSlowAttribute("_data", private=True),
692718
"_mask": _FastSlowAttribute("_mask", private=True),
693719
},
@@ -717,6 +743,7 @@ def Index__setattr__(self, name, value):
717743
pd.arrays.PeriodArray,
718744
fast_to_slow=_Unusable(),
719745
slow_to_fast=_Unusable(),
746+
bases=(pd.api.extensions.ExtensionArray,),
720747
additional_attributes={
721748
"_data": _FastSlowAttribute("_data", private=True),
722749
"_mask": _FastSlowAttribute("_mask", private=True),
@@ -799,11 +826,13 @@ def Index__setattr__(self, name, value):
799826
pd.arrays.StringArray,
800827
fast_to_slow=_Unusable(),
801828
slow_to_fast=_Unusable(),
829+
bases=(pd.api.extensions.ExtensionArray,),
802830
additional_attributes={
803831
"_data": _FastSlowAttribute("_data", private=True),
804832
"_mask": _FastSlowAttribute("_mask", private=True),
805833
"__array__": _FastSlowAttribute("__array__"),
806834
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
835+
"__arrow_array__": _FastSlowAttribute("__arrow_array__"),
807836
},
808837
)
809838

@@ -835,6 +864,7 @@ def Index__setattr__(self, name, value):
835864
pd.core.arrays.string_arrow.ArrowStringArray,
836865
fast_to_slow=_Unusable(),
837866
slow_to_fast=_Unusable(),
867+
bases=(pd.api.extensions.ExtensionArray,),
838868
additional_attributes={
839869
"_pa_array": _FastSlowAttribute("_pa_array", private=True),
840870
"__array__": _FastSlowAttribute("__array__", private=True),
@@ -844,6 +874,7 @@ def Index__setattr__(self, name, value):
844874
"__abs__": _FastSlowAttribute("__abs__"),
845875
"__contains__": _FastSlowAttribute("__contains__"),
846876
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
877+
"__arrow_array__": _FastSlowAttribute("__arrow_array__"),
847878
},
848879
)
849880

@@ -874,6 +905,7 @@ def Index__setattr__(self, name, value):
874905
pd.arrays.BooleanArray,
875906
fast_to_slow=_Unusable(),
876907
slow_to_fast=_Unusable(),
908+
bases=(pd.api.extensions.ExtensionArray,),
877909
additional_attributes={
878910
"_data": _FastSlowAttribute("_data", private=True),
879911
"_mask": _FastSlowAttribute("_mask", private=True),
@@ -898,6 +930,7 @@ def Index__setattr__(self, name, value):
898930
pd.arrays.IntegerArray,
899931
fast_to_slow=_Unusable(),
900932
slow_to_fast=_Unusable(),
933+
bases=(pd.api.extensions.ExtensionArray,),
901934
additional_attributes={
902935
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
903936
"_data": _FastSlowAttribute("_data", private=True),
@@ -1018,7 +1051,9 @@ def Index__setattr__(self, name, value):
10181051
pd.arrays.IntervalArray,
10191052
fast_to_slow=_Unusable(),
10201053
slow_to_fast=_Unusable(),
1054+
bases=(pd.api.extensions.ExtensionArray,),
10211055
additional_attributes={
1056+
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
10221057
"_data": _FastSlowAttribute("_data", private=True),
10231058
"_mask": _FastSlowAttribute("_mask", private=True),
10241059
},
@@ -1052,6 +1087,7 @@ def Index__setattr__(self, name, value):
10521087
pd.arrays.FloatingArray,
10531088
fast_to_slow=_Unusable(),
10541089
slow_to_fast=_Unusable(),
1090+
bases=(pd.api.extensions.ExtensionArray,),
10551091
additional_attributes={
10561092
"__array_ufunc__": _FastSlowAttribute("__array_ufunc__"),
10571093
"_data": _FastSlowAttribute("_data", private=True),
@@ -1081,6 +1117,7 @@ def Index__setattr__(self, name, value):
10811117
},
10821118
)
10831119

1120+
10841121
SeriesGroupBy = make_intermediate_proxy_type(
10851122
"SeriesGroupBy",
10861123
cudf.core.groupby.groupby.SeriesGroupBy,

0 commit comments

Comments
 (0)