|
36 | 36 | from pandas.api.types import CategoricalDtype, is_hashable |
37 | 37 | from pandas._libs import lib |
38 | 38 |
|
| 39 | +from pyspark.sql.functions import pandas_udf |
39 | 40 | from pyspark.sql import functions as F, Column |
40 | | -from pyspark.sql.types import FractionalType, IntegralType, TimestampType, TimestampNTZType |
| 41 | +from pyspark.sql.types import ( |
| 42 | + FractionalType, |
| 43 | + IntegralType, |
| 44 | + TimestampType, |
| 45 | + TimestampNTZType, |
| 46 | + BooleanType, |
| 47 | + StringType, |
| 48 | +) |
41 | 49 |
|
42 | 50 | from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. |
43 | 51 | from pyspark.pandas._typing import Dtype, Label, Name, Scalar |
|
46 | 54 | from pyspark.pandas.frame import DataFrame |
47 | 55 | from pyspark.pandas.missing.indexes import MissingPandasLikeIndex |
48 | 56 | from pyspark.pandas.series import Series, first_series |
| 57 | +from pyspark.pandas.typedef import infer_pd_series_spark_type |
49 | 58 | from pyspark.pandas.spark import functions as SF |
50 | 59 | from pyspark.pandas.spark.accessors import SparkIndexMethods |
51 | 60 | from pyspark.pandas.utils import ( |
@@ -1927,6 +1936,112 @@ def argmin(self) -> int: |
1927 | 1936 | .first()[0] |
1928 | 1937 | ) |
1929 | 1938 |
|
| 1939 | + def putmask( |
| 1940 | + self, mask: Union[Series, "Index", List, Tuple], value: Union[Series, "Index", List, Tuple] |
| 1941 | + ) -> "Index": |
| 1942 | + """ |
| 1943 | + Return a new Index of the values set with the mask. |
| 1944 | + .. note:: this API can be pretty expensive since it is based on |
| 1945 | + a global sequence internally. |
| 1946 | + Parameters |
| 1947 | + ---------- |
| 1948 | + mask : array-like |
| 1949 | + Boolean mask array. It has to be the same shape as the index. |
| 1950 | + value : array-like |
| 1951 | + Value to put into the index where mask is True. |
| 1952 | + Returns |
| 1953 | + ------- |
| 1954 | + Index |
| 1955 | + Examples |
| 1956 | + ------- |
| 1957 | + >>> psidx = ps.Index([1, 2, 3, 4, 5]) |
| 1958 | + >>> psidx |
| 1959 | + Int64Index([1, 2, 3, 4, 5], dtype='int64') |
| 1960 | + >>> psidx.putmask(psidx > 3, 100).sort_values() |
| 1961 | + Int64Index([1, 2, 3, 100, 100], dtype='int64') |
| 1962 | + >>> psidx.putmask(psidx > 3, ps.Index([100, 200, 300, 400, 500])).sort_values() |
| 1963 | + Int64Index([1, 2, 3, 400, 500], dtype='int64') |
| 1964 | + """ |
| 1965 | + scol_name = self._internal.index_spark_column_names[0] |
| 1966 | + sdf = self._internal.spark_frame.select(self.spark.column) |
| 1967 | + |
| 1968 | + dist_sequence_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__") |
| 1969 | + sdf = InternalFrame.attach_distributed_sequence_column( |
| 1970 | + sdf, column_name=dist_sequence_col_name |
| 1971 | + ) |
| 1972 | + |
| 1973 | + replace_col = verify_temp_column_name(sdf, "__replace_column__") |
| 1974 | + masking_col = verify_temp_column_name(sdf, "__masking_column__") |
| 1975 | + |
| 1976 | + if isinstance(value, (list, tuple, Index, Series)): |
| 1977 | + if isinstance(value, (list, tuple)): |
| 1978 | + pd_value = pd.Series(value) |
| 1979 | + elif isinstance(value, (Series, Index)): |
| 1980 | + pd_value = value.to_pandas() |
| 1981 | + |
| 1982 | + if self.size != pd_value.size: |
| 1983 | + # TODO: We can't support different size of value for now. |
| 1984 | + raise ValueError("value and data must be the same size") |
| 1985 | + |
| 1986 | + replace_return_type = infer_pd_series_spark_type(pd_value, pd_value.dtype) |
| 1987 | + |
| 1988 | + @pandas_udf( |
| 1989 | + returnType=replace_return_type if replace_return_type else StringType() |
| 1990 | + ) # type: ignore |
| 1991 | + def replace_pandas_udf(sequence: pd.Series) -> pd.Series: |
| 1992 | + return pd_value[sequence] |
| 1993 | + |
| 1994 | + sdf = sdf.withColumn(replace_col, replace_pandas_udf(dist_sequence_col_name)) |
| 1995 | + else: |
| 1996 | + sdf = sdf.withColumn(replace_col, F.lit(value)) |
| 1997 | + |
| 1998 | + if isinstance(mask, (list, tuple)): |
| 1999 | + pandas_mask = pd.Series(mask) |
| 2000 | + elif isinstance(mask, (Index, Series)): |
| 2001 | + pandas_mask = mask.to_pandas() |
| 2002 | + else: |
| 2003 | + raise TypeError("Mask data doesn't support type " "{0}".format(type(mask).__name__)) |
| 2004 | + |
| 2005 | + if self.size != pandas_mask.size: |
| 2006 | + raise ValueError("mask and data must be the same size") |
| 2007 | + |
| 2008 | + @pandas_udf(returnType=BooleanType()) # type: ignore |
| 2009 | + def masking_pandas_udf(sequence: pd.Series) -> pd.Series: |
| 2010 | + return pandas_mask[sequence] |
| 2011 | + |
| 2012 | + sdf = sdf.withColumn(masking_col, masking_pandas_udf(dist_sequence_col_name)) |
| 2013 | + |
| 2014 | + # spark_frame here looks like below |
| 2015 | + # +-------------------------------+-----------------+------------------+------------------+ |
| 2016 | + # |__distributed_sequence_column__|__index_level_0__|__replace_column__|__masking_column__| |
| 2017 | + # +-------------------------------+-----------------+------------------+------------------+ |
| 2018 | + # | 0| a| 100| true| |
| 2019 | + # | 3| d| 400| false| |
| 2020 | + # | 1| b| 200| true| |
| 2021 | + # | 2| c| 300| false| |
| 2022 | + # | 4| e| 500| false| |
| 2023 | + # +-------------------------------+-----------------+------------------+------------------+ |
| 2024 | + |
| 2025 | + cond = F.when(scol_for(sdf, masking_col), scol_for(sdf, replace_col)).otherwise( |
| 2026 | + scol_for(sdf, scol_name) |
| 2027 | + ) |
| 2028 | + sdf = sdf.select(cond.alias(scol_name)) |
| 2029 | + |
| 2030 | + if sdf.schema[scol_name].nullable != self._internal.index_fields[0].nullable: |
| 2031 | + sdf.schema[scol_name].nullable = self._internal.index_fields[0].nullable |
| 2032 | + sdf = sdf.sql_ctx.createDataFrame(sdf.rdd, sdf.schema) |
| 2033 | + |
| 2034 | + internal = InternalFrame( |
| 2035 | + spark_frame=sdf, |
| 2036 | + index_spark_columns=[ |
| 2037 | + scol_for(sdf, col) for col in self._internal.index_spark_column_names |
| 2038 | + ], |
| 2039 | + index_names=self._internal.index_names, |
| 2040 | + index_fields=self._internal.index_fields, |
| 2041 | + ) |
| 2042 | + |
| 2043 | + return DataFrame(internal).index |
| 2044 | + |
1930 | 2045 | def set_names( |
1931 | 2046 | self, |
1932 | 2047 | names: Union[Name, List[Name]], |
|
0 commit comments