@@ -1558,8 +1558,11 @@ def putmask(self, mask, value):
15581558 >>> kidx
15591559 Index(['a', 'b', 'c', 'd', 'e'], dtype='object')
15601560
1561- >>> kidx.putmask([True if x < 2 else False for x in range(5)] , "Koalas").sort_values()
1561+ >>> kidx.putmask(kidx < 'c' , "Koalas").sort_values()
15621562 Index(['Koalas', 'Koalas', 'c', 'd', 'e'], dtype='object')
1563+
1564+ >>> kidx.putmask(kidx < 'c', ks.Index([100, 200, 300, 400, 500])).sort_values()
1565+ Index(['100', '200', 'c', 'd', 'e'], dtype='object')
15631566 """
15641567 scol_name = self ._internal .index_spark_column_names [0 ]
15651568 sdf = self ._internal .spark_frame .select (self .spark .column )
@@ -1569,22 +1572,38 @@ def putmask(self, mask, value):
15691572 sdf , column_name = dist_sequence_col_name
15701573 )
15711574
1575+ replace_col = verify_temp_column_name (sdf , "__replace_column__" )
15721576 masking_col = verify_temp_column_name (sdf , "__masking_column__" )
1573- masking_udf = udf (lambda x : mask [x ], BooleanType ())
15741577
1578+ if isinstance (value , (list , tuple )):
1579+ replace_udf = udf (lambda x : value [x ])
1580+ sdf = sdf .withColumn (replace_col , replace_udf (dist_sequence_col_name ))
1581+ elif isinstance (value , (Index , Series )):
1582+ value = value .to_numpy ().tolist ()
1583+ replace_udf = udf (lambda x : value [x ])
1584+ sdf = sdf .withColumn (replace_col , replace_udf (dist_sequence_col_name ))
1585+ else :
1586+ sdf = sdf .withColumn (replace_col , F .lit (value ))
1587+
1588+ if isinstance (mask , (Index , Series )):
1589+ mask = mask .to_numpy ().tolist ()
1590+ elif not isinstance (mask , list ) and not isinstance (mask , tuple ):
1591+ raise TypeError ("Mask data doesn't support type " "{0}" .format (type (mask ).__name__ ))
1592+
1593+ masking_udf = udf (lambda x : mask [x ], BooleanType ())
15751594 sdf = sdf .withColumn (masking_col , masking_udf (dist_sequence_col_name ))
15761595 # spark_frame here looks like below
1577- # +-------------------------------+-----------------+------------------+
1578- # |__distributed_sequence_column__|__index_level_0__|__masking_column__|
1579- # +-------------------------------+-----------------+------------------+
1580- # | 0| a| true|
1581- # | 3| d| false|
1582- # | 1| b| true|
1583- # | 2| c| false|
1584- # | 4| e| false|
1585- # +-------------------------------+-----------------+------------------+
1586-
1587- cond = F .when (sdf [masking_col ], value ).otherwise (sdf [scol_name ])
1596+ # +-------------------------------+-----------------+------------------+------------------+
1597+ # |__distributed_sequence_column__|__index_level_0__|__replace_column__| __masking_column__|
1598+ # +-------------------------------+-----------------+------------------+------------------+
1599+ # | 0| a| 100| true|
1600+ # | 3| d| 400| false|
1601+ # | 1| b| 200| true|
1602+ # | 2| c| 300| false|
1603+ # | 4| e| 500| false|
1604+ # +-------------------------------+-----------------+------------------+------------------+
1605+
1606+ cond = F .when (sdf [masking_col ], sdf [ replace_col ] ).otherwise (sdf [scol_name ])
15881607 sdf = sdf .select (cond .alias (scol_name ))
15891608
15901609 internal = InternalFrame (spark_frame = sdf , index_map = self ._internal .index_map )
0 commit comments