1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.pandas/indexing.rst
@@ -83,6 +83,7 @@ Modifying and computations
Index.min
Index.max
Index.map
Index.putmask
Index.rename
Index.repeat
Index.take
117 changes: 116 additions & 1 deletion python/pyspark/pandas/indexes/base.py
@@ -36,8 +36,16 @@
from pandas.api.types import CategoricalDtype, is_hashable
from pandas._libs import lib

from pyspark.sql.functions import pandas_udf
from pyspark.sql import functions as F, Column
from pyspark.sql.types import FractionalType, IntegralType, TimestampType, TimestampNTZType
from pyspark.sql.types import (
FractionalType,
IntegralType,
TimestampType,
TimestampNTZType,
BooleanType,
StringType,
)

from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
from pyspark.pandas._typing import Dtype, Label, Name, Scalar
@@ -46,6 +54,7 @@
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.missing.indexes import MissingPandasLikeIndex
from pyspark.pandas.series import Series, first_series
from pyspark.pandas.typedef import infer_pd_series_spark_type
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.spark.accessors import SparkIndexMethods
from pyspark.pandas.utils import (
@@ -1927,6 +1936,112 @@ def argmin(self) -> int:
.first()[0]
)

def putmask(
self, mask: Union[Series, "Index", List, Tuple], value: Union[Series, "Index", List, Tuple]
) -> "Index":
"""
Return a new Index of the values set with the mask.
.. note:: this API can be pretty expensive since it is based on
a global sequence internally.
Parameters
----------
mask : array-like
Boolean mask array. It has to be the same shape as the index.
value : array-like
Value to put into the index where mask is True.
Returns
-------
Index
Examples
-------
>>> psidx = ps.Index([1, 2, 3, 4, 5])
>>> psidx
Int64Index([1, 2, 3, 4, 5], dtype='int64')
>>> psidx.putmask(psidx > 3, 100).sort_values()
Int64Index([1, 2, 3, 100, 100], dtype='int64')
>>> psidx.putmask(psidx > 3, ps.Index([100, 200, 300, 400, 500])).sort_values()
Int64Index([1, 2, 3, 400, 500], dtype='int64')
"""
scol_name = self._internal.index_spark_column_names[0]
sdf = self._internal.spark_frame.select(self.spark.column)

dist_sequence_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
sdf = InternalFrame.attach_distributed_sequence_column(
sdf, column_name=dist_sequence_col_name
)

replace_col = verify_temp_column_name(sdf, "__replace_column__")
masking_col = verify_temp_column_name(sdf, "__masking_column__")

if isinstance(value, (list, tuple, Index, Series)):
if isinstance(value, (list, tuple)):
pd_value = pd.Series(value)
elif isinstance(value, (Series, Index)):
pd_value = value.to_pandas()

if self.size != pd_value.size:
# TODO: We can't support a value of a different size for now.
raise ValueError("value and data must be the same size")

replace_return_type = infer_pd_series_spark_type(pd_value, pd_value.dtype)

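# Use the distributed sequence as a positional lookup into the pandas value series.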
@pandas_udf(
returnType=replace_return_type if replace_return_type else StringType()
) # type: ignore
def replace_pandas_udf(sequence: pd.Series) -> pd.Series:
return pd_value[sequence]

sdf = sdf.withColumn(replace_col, replace_pandas_udf(dist_sequence_col_name))
else:
sdf = sdf.withColumn(replace_col, F.lit(value))

if isinstance(mask, (list, tuple)):
pandas_mask = pd.Series(mask)
elif isinstance(mask, (Index, Series)):
pandas_mask = mask.to_pandas()
else:
raise TypeError("Mask data doesn't support type {0}".format(type(mask).__name__))

if self.size != pandas_mask.size:
raise ValueError("mask and data must be the same size")

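# Same idea for the mask: look up the boolean mask value by position.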
@pandas_udf(returnType=BooleanType()) # type: ignore
def masking_pandas_udf(sequence: pd.Series) -> pd.Series:
return pandas_mask[sequence]

sdf = sdf.withColumn(masking_col, masking_pandas_udf(dist_sequence_col_name))

# spark_frame here looks like below
# +-------------------------------+-----------------+------------------+------------------+
# |__distributed_sequence_column__|__index_level_0__|__replace_column__|__masking_column__|
# +-------------------------------+-----------------+------------------+------------------+
# | 0| a| 100| true|
# | 3| d| 400| false|
# | 1| b| 200| true|
# | 2| c| 300| false|
# | 4| e| 500| false|
# +-------------------------------+-----------------+------------------+------------------+

cond = F.when(scol_for(sdf, masking_col), scol_for(sdf, replace_col)).otherwise(
scol_for(sdf, scol_name)
)
sdf = sdf.select(cond.alias(scol_name))

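# Keep the nullability of the result column in line with the original index field
# (the conditional expression above may not preserve it).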
if sdf.schema[scol_name].nullable != self._internal.index_fields[0].nullable:
sdf.schema[scol_name].nullable = self._internal.index_fields[0].nullable
sdf = sdf.sql_ctx.createDataFrame(sdf.rdd, sdf.schema)

internal = InternalFrame(
spark_frame=sdf,
index_spark_columns=[
scol_for(sdf, col) for col in self._internal.index_spark_column_names
],
index_names=self._internal.index_names,
index_fields=self._internal.index_fields,
)

return DataFrame(internal).index

def set_names(
self,
names: Union[Name, List[Name]],
7 changes: 7 additions & 0 deletions python/pyspark/pandas/indexes/multi.py
@@ -1236,6 +1236,13 @@ def map(
) -> "Index":
return MissingPandasLikeMultiIndex.map(self, mapper, na_action)

def putmask(
self,
mask: Union[Series, "Index", List, Tuple] = None,
value: Union[Series, "Index", List, Tuple] = None,
) -> "Index":
return MissingPandasLikeMultiIndex.putmask(self, mask, value)


def _test() -> None:
import os
1 change: 0 additions & 1 deletion python/pyspark/pandas/missing/indexes.py
@@ -53,7 +53,6 @@ class MissingPandasLikeIndex(object):
groupby = _unsupported_function("groupby")
is_ = _unsupported_function("is_")
join = _unsupported_function("join")
putmask = _unsupported_function("putmask")
ravel = _unsupported_function("ravel")
reindex = _unsupported_function("reindex")
searchsorted = _unsupported_function("searchsorted")
35 changes: 35 additions & 0 deletions python/pyspark/pandas/tests/indexes/test_base.py
@@ -2388,6 +2388,41 @@ def test_map(self):
lambda: psidx.map({1: 1, 2: 2.0, 3: "three"}),
)

def test_putmask(self):
pidx = pd.Index(["a", "b", "c", "d", "e"])
psidx = ps.from_pandas(pidx)

self.assert_eq(
psidx.putmask(psidx < "c", "k").sort_values(),
pidx.putmask(pidx < "c", "k").sort_values(),
)
self.assert_eq(
psidx.putmask(psidx < "c", ["g", "h", "i", "j", "k"]).sort_values(),
pidx.putmask(pidx < "c", ["g", "h", "i", "j", "k"]).sort_values(),
)
self.assert_eq(
psidx.putmask(psidx < "c", ("g", "h", "i", "j", "k")).sort_values(),
pidx.putmask(pidx < "c", ("g", "h", "i", "j", "k")).sort_values(),
)
self.assert_eq(
psidx.putmask(psidx < "c", ps.Index(["g", "h", "i", "j", "k"])).sort_values(),
pidx.putmask(pidx < "c", pd.Index(["g", "h", "i", "j", "k"])).sort_values(),
)
self.assert_eq(
psidx.putmask(psidx < "c", ps.Series(["g", "h", "i", "j", "k"])).sort_values(),
pidx.putmask(pidx < "c", pd.Series(["g", "h", "i", "j", "k"])).sort_values(),
)

self.assertRaises(
ValueError,
lambda: psidx.putmask(psidx < "c", ps.Series(["g", "h"])),
)

self.assertRaises(
ValueError,
lambda: psidx.putmask([True, False], ps.Series(["g", "h", "i", "j", "k"])),
)

def test_multiindex_equal_levels(self):
pmidx1 = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
pmidx2 = pd.MultiIndex.from_tuples([("b", "y"), ("a", "x"), ("c", "z")])
2 changes: 1 addition & 1 deletion python/pyspark/pandas/typedef/typehints.py
@@ -356,7 +356,7 @@ def infer_pd_series_spark_type(
if dtype == np.dtype("object"):
if len(pser) == 0 or pser.isnull().all():
return types.NullType()
elif hasattr(pser.iloc[0], "__UDT__"):
elif hasattr(pser, "iloc") and hasattr(pser.iloc[0], "__UDT__"):
Contributor: Could you explain what is this change for?

Contributor Author: At this point (https://github.com/apache/spark/pull/33744/files#diff-c19199b1eb4ba73f00acb31a2c2c055be95b697fd08049ee6ba54655392adfa5R1984), if the input parameter value of this putmask is an Index, infer_pd_series_spark_type raises an exception, because the Index type doesn't have an iloc attribute. That is why I fixed this part. I believe it has no effect on the behavior of the existing Series path.

Contributor Author (@beobest2, Sep 6, 2021): Adding an infer_pd_index_spark_type function would make it cleaner.

Contributor (@itholic, Sep 7, 2021): I think we should not pass the Index to infer_pd_series_spark_type. At the least we should change the function name (e.g. infer_pd_indexops_spark_type) and the input and output types of the function, or add a new function and use it as you mentioned. BTW, I actually don't think we need to use pandas_udf here.
return pser.iloc[0].__UDT__
else:
return from_arrow_type(pa.Array.from_pandas(pser).type, prefer_timestamp_ntz)
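Regarding the last review comment above, one pandas_udf-free way to align the mask and value with the index would be to join on the distributed sequence column. A minimal sketch, where the helper name attach_aligned_column and the use of default_session are illustrative only (the PR as submitted keeps the pandas_udf approach):

import pandas as pd

from pyspark.pandas.utils import default_session


def attach_aligned_column(sdf, seq_col, new_col, values):
    # Build a small Spark DataFrame of (position, value) pairs and join it onto sdf
    # by the distributed sequence column, aligning the values by position.
    pdf = pd.DataFrame({seq_col: range(len(values)), new_col: list(values)})
    return sdf.join(default_session().createDataFrame(pdf), on=seq_col, how="left")


# Hypothetical usage inside putmask, after attaching the distributed sequence column:
# sdf = attach_aligned_column(sdf, dist_sequence_col_name, replace_col, pd_value)
# sdf = attach_aligned_column(sdf, dist_sequence_col_name, masking_col, pandas_mask)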