
Commit e083147

Implement Index.putmask
1 parent f853afd commit e083147

File tree

6 files changed: +160 -3 lines changed

python/docs/source/reference/pyspark.pandas/indexing.rst

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ Modifying and computations
     Index.min
     Index.max
     Index.map
+    Index.putmask
     Index.rename
     Index.repeat
     Index.take

python/pyspark/pandas/indexes/base.py

Lines changed: 116 additions & 1 deletion
@@ -36,8 +36,16 @@
 from pandas.api.types import CategoricalDtype, is_hashable
 from pandas._libs import lib
 
+from pyspark.sql.functions import pandas_udf
 from pyspark.sql import functions as F, Column
-from pyspark.sql.types import FractionalType, IntegralType, TimestampType, TimestampNTZType
+from pyspark.sql.types import (
+    FractionalType,
+    IntegralType,
+    TimestampType,
+    TimestampNTZType,
+    BooleanType,
+    StringType,
+)
 
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
 from pyspark.pandas._typing import Dtype, Label, Name, Scalar
@@ -46,6 +54,7 @@
 from pyspark.pandas.frame import DataFrame
 from pyspark.pandas.missing.indexes import MissingPandasLikeIndex
 from pyspark.pandas.series import Series, first_series
+from pyspark.pandas.typedef import infer_pd_series_spark_type
 from pyspark.pandas.spark import functions as SF
 from pyspark.pandas.spark.accessors import SparkIndexMethods
 from pyspark.pandas.utils import (
@@ -1927,6 +1936,112 @@ def argmin(self) -> int:
             .first()[0]
         )
 
+    def putmask(
+        self, mask: Union[Series, "Index", List, Tuple], value: Union[Series, "Index", List, Tuple]
+    ) -> "Index":
+        """
+        Return a new Index of the values set with the mask.
+
+        .. note:: this API can be pretty expensive since it is based on
+             a global sequence internally.
+
+        Parameters
+        ----------
+        mask : array-like
+            Boolean mask array. It has to be the same shape as the index.
+        value : array-like
+            Value to put into the index where mask is True.
+
+        Returns
+        -------
+        Index
+
+        Examples
+        --------
+        >>> psidx = ps.Index([1, 2, 3, 4, 5])
+        >>> psidx
+        Int64Index([1, 2, 3, 4, 5], dtype='int64')
+
+        >>> psidx.putmask(psidx > 3, 100).sort_values()
+        Int64Index([1, 2, 3, 100, 100], dtype='int64')
+
+        >>> psidx.putmask(psidx > 3, ps.Index([100, 200, 300, 400, 500])).sort_values()
+        Int64Index([1, 2, 3, 400, 500], dtype='int64')
+        """
+        scol_name = self._internal.index_spark_column_names[0]
+        sdf = self._internal.spark_frame.select(self.spark.column)
+
+        dist_sequence_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
+        sdf = InternalFrame.attach_distributed_sequence_column(
+            sdf, column_name=dist_sequence_col_name
+        )
+
+        replace_col = verify_temp_column_name(sdf, "__replace_column__")
+        masking_col = verify_temp_column_name(sdf, "__masking_column__")
+
+        if isinstance(value, (list, tuple, Index, Series)):
+            if isinstance(value, (list, tuple)):
+                pd_value = pd.Series(value)
+            elif isinstance(value, (Series, Index)):
+                pd_value = value.to_pandas()
+
+            if self.size != pd_value.size:
+                # TODO: We can't support values of a different size for now.
+                raise ValueError("value and data must be the same size")
+
+            replace_return_type = infer_pd_series_spark_type(pd_value, pd_value.dtype)
+
+            @pandas_udf(
+                returnType=replace_return_type if replace_return_type else StringType()
+            )  # type: ignore
+            def replace_pandas_udf(sequence: pd.Series) -> pd.Series:
+                return pd_value[sequence]
+
+            sdf = sdf.withColumn(replace_col, replace_pandas_udf(dist_sequence_col_name))
+        else:
+            sdf = sdf.withColumn(replace_col, F.lit(value))
+
+        if isinstance(mask, (list, tuple)):
+            pandas_mask = pd.Series(mask)
+        elif isinstance(mask, (Index, Series)):
+            pandas_mask = mask.to_pandas()
+        else:
+            raise TypeError("Mask data doesn't support type {0}".format(type(mask).__name__))
+
+        if self.size != pandas_mask.size:
+            raise ValueError("mask and data must be the same size")
+
+        @pandas_udf(returnType=BooleanType())  # type: ignore
+        def masking_pandas_udf(sequence: pd.Series) -> pd.Series:
+            return pandas_mask[sequence]
+
+        sdf = sdf.withColumn(masking_col, masking_pandas_udf(dist_sequence_col_name))
+
+        # spark_frame here looks like below
+        # +-------------------------------+-----------------+------------------+------------------+
+        # |__distributed_sequence_column__|__index_level_0__|__replace_column__|__masking_column__|
+        # +-------------------------------+-----------------+------------------+------------------+
+        # |                              0|                a|               100|              true|
+        # |                              3|                d|               400|             false|
+        # |                              1|                b|               200|              true|
+        # |                              2|                c|               300|             false|
+        # |                              4|                e|               500|             false|
+        # +-------------------------------+-----------------+------------------+------------------+
+
+        cond = F.when(scol_for(sdf, masking_col), scol_for(sdf, replace_col)).otherwise(
+            scol_for(sdf, scol_name)
+        )
+        sdf = sdf.select(cond.alias(scol_name))
+
+        if sdf.schema[scol_name].nullable != self._internal.index_fields[0].nullable:
+            sdf.schema[scol_name].nullable = self._internal.index_fields[0].nullable
+            sdf = sdf.sql_ctx.createDataFrame(sdf.rdd, sdf.schema)
+
+        internal = InternalFrame(
+            spark_frame=sdf,
+            index_spark_columns=[
+                scol_for(sdf, col) for col in self._internal.index_spark_column_names
+            ],
+            index_names=self._internal.index_names,
+            index_fields=self._internal.index_fields,
+        )
+
+        return DataFrame(internal).index
+
     def set_names(
         self,
         names: Union[Name, List[Name]],
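The core of the implementation is positional realignment: putmask collects the mask and replacement values to the driver as pandas Series, and each distributed row looks itself up by its attached global sequence number inside a pandas UDF. Below is a minimal pandas-only sketch of that lookup, using illustrative data mirroring the commented frame above; nothing in it is taken from the commit itself.

import pandas as pd

# Rows as Spark might hold them: arbitrary partition order, each row
# carrying its global position (__distributed_sequence_column__ above).
values = pd.Series(["a", "d", "b", "c", "e"])
sequence = pd.Series([0, 3, 1, 2, 4])

# Driver-side mask and replacement values in logical index order.
pandas_mask = pd.Series([True, True, False, False, False])
pd_value = pd.Series([100, 200, 300, 400, 500])

# What masking_pandas_udf / replace_pandas_udf compute per batch:
# lookup by sequence number realigns driver order to distributed row order.
aligned_mask = pandas_mask[sequence].reset_index(drop=True)
aligned_value = pd_value[sequence].reset_index(drop=True)

# The F.when(mask, replace).otherwise(original) projection:
result = values.where(~aligned_mask, aligned_value)
print(result.tolist())  # [100, 'd', 200, 'c', 'e']

This also makes the docstring's cost warning concrete: both mask and value are materialized on the driver (to_pandas) and captured in the UDF closures, so they are shipped to every task.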

python/pyspark/pandas/indexes/multi.py

Lines changed: 7 additions & 0 deletions
@@ -1236,6 +1236,13 @@ def map(
     ) -> "Index":
         return MissingPandasLikeMultiIndex.map(self, mapper, na_action)
 
+    def putmask(
+        self,
+        mask: Union[Series, "Index", List, Tuple] = None,
+        value: Union[Series, "Index", List, Tuple] = None,
+    ) -> "Index":
+        return MissingPandasLikeMultiIndex.putmask(self, mask, value)
+
 
 def _test() -> None:
     import os
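Since MissingPandasLikeMultiIndex still carries the putmask stub, the override above keeps the MultiIndex behavior explicit. A hedged sketch of the expected behavior, not part of the commit:

import pyspark.pandas as ps

psmidx = ps.MultiIndex.from_tuples([("a", "x"), ("b", "y")])
try:
    psmidx.putmask([True, False], [("c", "z"), ("d", "w")])
except Exception as exc:
    print(type(exc).__name__)  # expected: PandasNotImplementedError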

python/pyspark/pandas/missing/indexes.py

Lines changed: 0 additions & 1 deletion
@@ -53,7 +53,6 @@ class MissingPandasLikeIndex(object):
     groupby = _unsupported_function("groupby")
     is_ = _unsupported_function("is_")
     join = _unsupported_function("join")
-    putmask = _unsupported_function("putmask")
     ravel = _unsupported_function("ravel")
     reindex = _unsupported_function("reindex")
     searchsorted = _unsupported_function("searchsorted")

python/pyspark/pandas/tests/indexes/test_base.py

Lines changed: 35 additions & 0 deletions
@@ -2388,6 +2388,41 @@ def test_map(self):
             lambda: psidx.map({1: 1, 2: 2.0, 3: "three"}),
         )
 
+    def test_putmask(self):
+        pidx = pd.Index(["a", "b", "c", "d", "e"])
+        psidx = ps.from_pandas(pidx)
+
+        self.assert_eq(
+            psidx.putmask(psidx < "c", "k").sort_values(),
+            pidx.putmask(pidx < "c", "k").sort_values(),
+        )
+        self.assert_eq(
+            psidx.putmask(psidx < "c", ["g", "h", "i", "j", "k"]).sort_values(),
+            pidx.putmask(pidx < "c", ["g", "h", "i", "j", "k"]).sort_values(),
+        )
+        self.assert_eq(
+            psidx.putmask(psidx < "c", ("g", "h", "i", "j", "k")).sort_values(),
+            pidx.putmask(pidx < "c", ("g", "h", "i", "j", "k")).sort_values(),
+        )
+        self.assert_eq(
+            psidx.putmask(psidx < "c", ps.Index(["g", "h", "i", "j", "k"])).sort_values(),
+            pidx.putmask(pidx < "c", pd.Index(["g", "h", "i", "j", "k"])).sort_values(),
+        )
+        self.assert_eq(
+            psidx.putmask(psidx < "c", ps.Series(["g", "h", "i", "j", "k"])).sort_values(),
+            pidx.putmask(pidx < "c", pd.Series(["g", "h", "i", "j", "k"])).sort_values(),
+        )
+
+        self.assertRaises(
+            ValueError,
+            lambda: psidx.putmask(psidx < "c", ps.Series(["g", "h"])),
+        )
+
+        self.assertRaises(
+            ValueError,
+            lambda: psidx.putmask([True, False], ps.Series(["g", "h", "i", "j", "k"])),
+        )
+
     def test_multiindex_equal_levels(self):
         pmidx1 = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
         pmidx2 = pd.MultiIndex.from_tuples([("b", "y"), ("a", "x"), ("c", "z")])
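For reference, the pandas baseline these assertions compare against can be checked without Spark; a quick illustrative snippet:

import pandas as pd

pidx = pd.Index(["a", "b", "c", "d", "e"])
# Only "a" and "b" compare less than "c", so only they are replaced.
print(pidx.putmask(pidx < "c", "k").tolist())  # ['k', 'k', 'c', 'd', 'e']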

python/pyspark/pandas/typedef/typehints.py

Lines changed: 1 addition & 1 deletion
@@ -356,7 +356,7 @@ def infer_pd_series_spark_type(
     if dtype == np.dtype("object"):
         if len(pser) == 0 or pser.isnull().all():
             return types.NullType()
-        elif hasattr(pser.iloc[0], "__UDT__"):
+        elif hasattr(pser, "iloc") and hasattr(pser.iloc[0], "__UDT__"):
             return pser.iloc[0].__UDT__
         else:
             return from_arrow_type(pa.Array.from_pandas(pser).type, prefer_timestamp_ntz)
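The added hasattr(pser, "iloc") guard short-circuits before pser.iloc[0] is evaluated, so an object without .iloc no longer raises AttributeError in this branch. A standalone sketch of the effect; the helper name is illustrative, not from the codebase:

import numpy as np
import pandas as pd

def first_item_has_udt(pser):
    # Mirrors the patched condition: guard the .iloc access first.
    return hasattr(pser, "iloc") and hasattr(pser.iloc[0], "__UDT__")

print(first_item_has_udt(pd.Series(["x", "y"])))  # False: strings have no __UDT__
print(first_item_has_udt(np.array(["x", "y"])))   # False: no .iloc, guard short-circuits
# Without the guard, the ndarray case would raise AttributeError at pser.iloc[0].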
