Skip to content

Commit 2d868d9

Browse files
WeichenXu123 authored and holdenk committed
[SPARK-22521][ML] VectorIndexerModel support handle unseen categories via handleInvalid: Python API
## What changes were proposed in this pull request?

Add Python API for VectorIndexerModel to support handling unseen categories via handleInvalid.

## How was this patch tested?

Doctest added.

Author: WeichenXu <[email protected]>

Closes #19753 from WeichenXu123/vector_indexer_invalid_py.
1 parent 5855b5c commit 2d868d9

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,17 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
4747
* Options are:
4848
* 'skip': filter out rows with invalid data.
4949
* 'error': throw an error.
50-
* 'keep': put invalid data in a special additional bucket, at index numCategories.
50+
* 'keep': put invalid data in a special additional bucket, at index of the number of
51+
* categories of the feature.
5152
* Default value: "error"
5253
* @group param
5354
*/
5455
@Since("2.3.0")
5556
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
5657
"How to handle invalid data (unseen labels or NULL values). " +
5758
"Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " +
58-
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
59+
"or 'keep' (put invalid data in a special additional bucket, at index of the " +
60+
"number of categories of the feature).",
5961
ParamValidators.inArray(VectorIndexer.supportedHandleInvalids))
6062

6163
setDefault(handleInvalid, VectorIndexer.ERROR_INVALID)
@@ -112,7 +114,6 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
112114
* - Preserve metadata in transform; if a feature's metadata is already present, do not recompute.
113115
* - Specify certain features to not index, either via a parameter or via existing metadata.
114116
* - Add warning if a categorical feature has only 1 category.
115-
* - Add option for allowing unknown categories.
116117
*/
117118
@Since("1.4.0")
118119
class VectorIndexer @Since("1.4.0") (

python/pyspark/ml/feature.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2490,7 +2490,8 @@ def setParams(self, inputCols=None, outputCol=None):
24902490

24912491

24922492
@inherit_doc
2493-
class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
2493+
class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
2494+
JavaMLWritable):
24942495
"""
24952496
Class for indexing categorical feature columns in a dataset of `Vector`.
24962497
@@ -2525,7 +2526,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
25252526
do not recompute.
25262527
- Specify certain features to not index, either via a parameter or via existing metadata.
25272528
- Add warning if a categorical feature has only 1 category.
2528-
- Add option for allowing unknown categories.
25292529
25302530
>>> from pyspark.ml.linalg import Vectors
25312531
>>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
@@ -2556,6 +2556,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
25562556
True
25572557
>>> loadedModel.categoryMaps == model.categoryMaps
25582558
True
2559+
>>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])
2560+
>>> indexer.getHandleInvalid()
2561+
'error'
2562+
>>> model3 = indexer.setHandleInvalid("skip").fit(df)
2563+
>>> model3.transform(dfWithInvalid).count()
2564+
0
2565+
>>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
2566+
>>> model4.transform(dfWithInvalid).head().indexed
2567+
DenseVector([2.0, 1.0])
25592568
25602569
.. versionadded:: 1.4.0
25612570
"""
@@ -2565,22 +2574,29 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
25652574
"(>= 2). If a feature is found to have > maxCategories values, then " +
25662575
"it is declared continuous.", typeConverter=TypeConverters.toInt)
25672576

2577+
handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " +
2578+
"(unseen labels or NULL values). Options are 'skip' (filter out " +
2579+
"rows with invalid data), 'error' (throw an error), or 'keep' (put " +
2580+
"invalid data in a special additional bucket, at index of the number " +
2581+
"of categories of the feature).",
2582+
typeConverter=TypeConverters.toString)
2583+
25682584
@keyword_only
2569-
def __init__(self, maxCategories=20, inputCol=None, outputCol=None):
2585+
def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
25702586
"""
2571-
__init__(self, maxCategories=20, inputCol=None, outputCol=None)
2587+
__init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
25722588
"""
25732589
super(VectorIndexer, self).__init__()
25742590
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
2575-
self._setDefault(maxCategories=20)
2591+
self._setDefault(maxCategories=20, handleInvalid="error")
25762592
kwargs = self._input_kwargs
25772593
self.setParams(**kwargs)
25782594

25792595
@keyword_only
25802596
@since("1.4.0")
2581-
def setParams(self, maxCategories=20, inputCol=None, outputCol=None):
2597+
def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
25822598
"""
2583-
setParams(self, maxCategories=20, inputCol=None, outputCol=None)
2599+
setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
25842600
Sets params for this VectorIndexer.
25852601
"""
25862602
kwargs = self._input_kwargs

0 commit comments

Comments (0)