Skip to content

Commit 3771611

Browse files
committed
Fix mask and value to support Series and Index
1 parent 5c31eb3 commit 3771611

File tree

2 files changed

+50
-18
lines changed

2 files changed

+50
-18
lines changed

databricks/koalas/indexes.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1558,8 +1558,11 @@ def putmask(self, mask, value):
15581558
>>> kidx
15591559
Index(['a', 'b', 'c', 'd', 'e'], dtype='object')
15601560
1561-
>>> kidx.putmask([True if x < 2 else False for x in range(5)], "Koalas").sort_values()
1561+
>>> kidx.putmask(kidx < 'c', "Koalas").sort_values()
15621562
Index(['Koalas', 'Koalas', 'c', 'd', 'e'], dtype='object')
1563+
1564+
>>> kidx.putmask(kidx < 'c', ks.Index([100, 200, 300, 400, 500])).sort_values()
1565+
Index(['100', '200', 'c', 'd', 'e'], dtype='object')
15631566
"""
15641567
scol_name = self._internal.index_spark_column_names[0]
15651568
sdf = self._internal.spark_frame.select(self.spark.column)
@@ -1569,22 +1572,38 @@ def putmask(self, mask, value):
15691572
sdf, column_name=dist_sequence_col_name
15701573
)
15711574

1575+
replace_col = verify_temp_column_name(sdf, "__replace_column__")
15721576
masking_col = verify_temp_column_name(sdf, "__masking_column__")
1573-
masking_udf = udf(lambda x: mask[x], BooleanType())
15741577

1578+
if isinstance(value, (list, tuple)):
1579+
replace_udf = udf(lambda x: value[x])
1580+
sdf = sdf.withColumn(replace_col, replace_udf(dist_sequence_col_name))
1581+
elif isinstance(value, (Index, Series)):
1582+
value = value.to_numpy().tolist()
1583+
replace_udf = udf(lambda x: value[x])
1584+
sdf = sdf.withColumn(replace_col, replace_udf(dist_sequence_col_name))
1585+
else:
1586+
sdf = sdf.withColumn(replace_col, F.lit(value))
1587+
1588+
if isinstance(mask, (Index, Series)):
1589+
mask = mask.to_numpy().tolist()
1590+
elif not isinstance(mask, list) and not isinstance(mask, tuple):
1591+
raise TypeError("Mask data doesn't support type " "{0}".format(type(mask).__name__))
1592+
1593+
masking_udf = udf(lambda x: mask[x], BooleanType())
15751594
sdf = sdf.withColumn(masking_col, masking_udf(dist_sequence_col_name))
15761595
# spark_frame here looks like below
1577-
# +-------------------------------+-----------------+------------------+
1578-
# |__distributed_sequence_column__|__index_level_0__|__masking_column__|
1579-
# +-------------------------------+-----------------+------------------+
1580-
# | 0| a| true|
1581-
# | 3| d| false|
1582-
# | 1| b| true|
1583-
# | 2| c| false|
1584-
# | 4| e| false|
1585-
# +-------------------------------+-----------------+------------------+
1586-
1587-
cond = F.when(sdf[masking_col], value).otherwise(sdf[scol_name])
1596+
# +-------------------------------+-----------------+------------------+------------------+
1597+
# |__distributed_sequence_column__|__index_level_0__|__replace_column__|__masking_column__|
1598+
# +-------------------------------+-----------------+------------------+------------------+
1599+
# | 0| a| 100| true|
1600+
# | 3| d| 400| false|
1601+
# | 1| b| 200| true|
1602+
# | 2| c| 300| false|
1603+
# | 4| e| 500| false|
1604+
# +-------------------------------+-----------------+------------------+------------------+
1605+
1606+
cond = F.when(sdf[masking_col], sdf[replace_col]).otherwise(sdf[scol_name])
15881607
sdf = sdf.select(cond.alias(scol_name))
15891608

15901609
internal = InternalFrame(spark_frame=sdf, index_map=self._internal.index_map)

databricks/koalas/tests/test_indexes.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,27 @@ def test_dropna(self):
279279
self.assert_eq((kidx + 1).dropna(), (pidx + 1).dropna())
280280

281281
def test_putmask(self):
282-
pidx = pd.Index(["a", "b", "c", "d", "e"])
282+
pidx = pd.Index([1, 2, 3, 4, 5])
283283
kidx = ks.from_pandas(pidx)
284284

285-
mask = [True if x < 2 else False for x in range(5)]
286-
value = "Koalas"
287-
288285
self.assert_eq(
289-
kidx.putmask(mask, value).sort_values(), pidx.putmask(mask, value).sort_values()
286+
kidx.putmask(kidx < 3, 100).sort_values(), pidx.putmask(pidx < 3, 100).sort_values()
287+
)
288+
self.assert_eq(
289+
kidx.putmask(kidx < 3, [100, 200, 300, 400, 500]).sort_values(),
290+
pidx.putmask(pidx < 3, [100, 200, 300, 400, 500]).sort_values(),
291+
)
292+
self.assert_eq(
293+
kidx.putmask(kidx < 3, (100, 200, 300, 400, 500)).sort_values(),
294+
pidx.putmask(pidx < 3, (100, 200, 300, 400, 500)).sort_values(),
295+
)
296+
self.assert_eq(
297+
kidx.putmask(kidx < 3, ks.Index([100, 200, 300, 400, 500])).sort_values(),
298+
pidx.putmask(pidx < 3, pd.Index([100, 200, 300, 400, 500])).sort_values(),
299+
)
300+
self.assert_eq(
301+
kidx.putmask(kidx < 3, ks.Series([100, 200, 300, 400, 500])).sort_values(),
302+
pidx.putmask(pidx < 3, pd.Series([100, 200, 300, 400, 500])).sort_values(),
290303
)
291304

292305
def test_index_symmetric_difference(self):

0 commit comments

Comments
 (0)