@@ -2490,7 +2490,8 @@ def setParams(self, inputCols=None, outputCol=None):
 
 
 @inherit_doc
-class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
+                    JavaMLWritable):
     """
     Class for indexing categorical feature columns in a dataset of `Vector`.
 
@@ -2525,7 +2526,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
        do not recompute.
      - Specify certain features to not index, either via a parameter or via existing metadata.
      - Add warning if a categorical feature has only 1 category.
-     - Add option for allowing unknown categories.
 
     >>> from pyspark.ml.linalg import Vectors
     >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
@@ -2556,6 +2556,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
     True
     >>> loadedModel.categoryMaps == model.categoryMaps
     True
+    >>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])
+    >>> indexer.getHandleInvalid()
+    'error'
+    >>> model3 = indexer.setHandleInvalid("skip").fit(df)
+    >>> model3.transform(dfWithInvalid).count()
+    0
+    >>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
+    >>> model4.transform(dfWithInvalid).head().indexed
+    DenseVector([2.0, 1.0])
 
     .. versionadded:: 1.4.0
     """
@@ -2565,22 +2574,29 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
                           "(>= 2). If a feature is found to have > maxCategories values, then " +
                           "it is declared continuous.", typeConverter=TypeConverters.toInt)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " +
+                          "(unseen labels or NULL values). Options are 'skip' (filter out " +
+                          "rows with invalid data), 'error' (throw an error), or 'keep' (put " +
+                          "invalid data in a special additional bucket, at index of the number " +
+                          "of categories of the feature).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, maxCategories=20, inputCol=None, outputCol=None):
+    def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        __init__(self, maxCategories=20, inputCol=None, outputCol=None)
+        __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
         """
         super(VectorIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
-        self._setDefault(maxCategories=20)
+        self._setDefault(maxCategories=20, handleInvalid="error")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, maxCategories=20, inputCol=None, outputCol=None):
+    def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        setParams(self, maxCategories=20, inputCol=None, outputCol=None)
+        setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
         Sets params for this VectorIndexer.
         """
         kwargs = self._input_kwargs
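For readers who want to exercise the new parameter outside the doctest, the snippet below is a minimal sketch of the three handleInvalid modes, assuming a live SparkSession bound to `spark` and a PySpark build that includes this patch. The data and expected outputs mirror the doctest added above.

    from pyspark.ml.feature import VectorIndexer
    from pyspark.ml.linalg import Vectors

    # Same data as the doctest: feature 0 takes 2 distinct values (categorical
    # at maxCategories=2), feature 1 takes 3 values (declared continuous).
    df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
                                (Vectors.dense([0.0, 1.0]),),
                                (Vectors.dense([0.0, 2.0]),)], ["a"])
    # 3.0 was never seen in feature 0 during fitting, so this row is "invalid".
    dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])

    indexer = VectorIndexer(maxCategories=2, inputCol="a", outputCol="indexed")

    # Default is "error": transforming dfWithInvalid would raise an exception.
    model = indexer.fit(df)

    # "skip" silently filters out rows containing unseen categories.
    model3 = indexer.setHandleInvalid("skip").fit(df)
    assert model3.transform(dfWithInvalid).count() == 0

    # "keep" routes unseen values to an extra bucket at index numCategories,
    # so the unseen 3.0 in the two-category feature maps to index 2; the
    # continuous feature 1 passes through unchanged.
    model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
    print(model4.transform(dfWithInvalid).head().indexed)  # DenseVector([2.0, 1.0])

Placing the "keep" bucket at index numCategories keeps the indices of the seen categories stable, so a model fit once can absorb unseen values without reshuffling the existing mapping.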