Skip to content

Commit 5997700

Browse files
huaxingao authored and Robert Kruszewski committed
[SPARK-23828][ML][PYTHON] PySpark StringIndexerModel should have constructor from labels
## What changes were proposed in this pull request?

The Scala StringIndexerModel has an alternate constructor that will create the model from an array of label strings. Add the corresponding Python API:

    model = StringIndexerModel.from_labels(["a", "b", "c"])

## How was this patch tested?

Add doctest and unit test.

Author: Huaxin Gao <[email protected]>

Closes apache#20968 from huaxingao/spark-23828.
1 parent 64f86ac commit 5997700

File tree

2 files changed

+104
-25
lines changed

2 files changed

+104
-25
lines changed

python/pyspark/ml/feature.py

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2342,9 +2342,38 @@ def mean(self):
23422342
return self._call_java("mean")
23432343

23442344

2345+
class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol):
2346+
"""
2347+
Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`.
2348+
"""
2349+
2350+
stringOrderType = Param(Params._dummy(), "stringOrderType",
2351+
"How to order labels of string column. The first label after " +
2352+
"ordering is assigned an index of 0. Supported options: " +
2353+
"frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
2354+
typeConverter=TypeConverters.toString)
2355+
2356+
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
2357+
"or NULL values) in features and label column of string type. " +
2358+
"Options are 'skip' (filter out rows with invalid data), " +
2359+
"error (throw an error), or 'keep' (put invalid data " +
2360+
"in a special additional bucket, at index numLabels).",
2361+
typeConverter=TypeConverters.toString)
2362+
2363+
def __init__(self, *args):
2364+
super(_StringIndexerParams, self).__init__(*args)
2365+
self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc")
2366+
2367+
@since("2.3.0")
2368+
def getStringOrderType(self):
2369+
"""
2370+
Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
2371+
"""
2372+
return self.getOrDefault(self.stringOrderType)
2373+
2374+
23452375
@inherit_doc
2346-
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
2347-
JavaMLWritable):
2376+
class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLWritable):
23482377
"""
23492378
A label indexer that maps a string column of labels to an ML column of label indices.
23502379
If the input column is numeric, we cast it to string and index the string values.
@@ -2388,23 +2417,16 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
23882417
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
23892418
... key=lambda x: x[0])
23902419
[(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
2420+
>>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"],
2421+
... inputCol="label", outputCol="indexed", handleInvalid="error")
2422+
>>> result = fromlabelsModel.transform(stringIndDf)
2423+
>>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
2424+
... key=lambda x: x[0])
2425+
[(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
23912426
23922427
.. versionadded:: 1.4.0
23932428
"""
23942429

2395-
stringOrderType = Param(Params._dummy(), "stringOrderType",
2396-
"How to order labels of string column. The first label after " +
2397-
"ordering is assigned an index of 0. Supported options: " +
2398-
"frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
2399-
typeConverter=TypeConverters.toString)
2400-
2401-
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
2402-
"or NULL values) in features and label column of string type. " +
2403-
"Options are 'skip' (filter out rows with invalid data), " +
2404-
"error (throw an error), or 'keep' (put invalid data " +
2405-
"in a special additional bucket, at index numLabels).",
2406-
typeConverter=TypeConverters.toString)
2407-
24082430
@keyword_only
24092431
def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
24102432
stringOrderType="frequencyDesc"):
@@ -2414,7 +2436,6 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
24142436
"""
24152437
super(StringIndexer, self).__init__()
24162438
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
2417-
self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc")
24182439
kwargs = self._input_kwargs
24192440
self.setParams(**kwargs)
24202441

@@ -2440,21 +2461,33 @@ def setStringOrderType(self, value):
24402461
"""
24412462
return self._set(stringOrderType=value)
24422463

2443-
@since("2.3.0")
2444-
def getStringOrderType(self):
2445-
"""
2446-
Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
2447-
"""
2448-
return self.getOrDefault(self.stringOrderType)
2449-
24502464

2451-
class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
2465+
class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable):
24522466
"""
24532467
Model fitted by :py:class:`StringIndexer`.
24542468
24552469
.. versionadded:: 1.4.0
24562470
"""
24572471

2472+
@classmethod
2473+
@since("2.4.0")
2474+
def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None):
2475+
"""
2476+
Construct the model directly from an array of label strings,
2477+
requires an active SparkContext.
2478+
"""
2479+
sc = SparkContext._active_spark_context
2480+
java_class = sc._gateway.jvm.java.lang.String
2481+
jlabels = StringIndexerModel._new_java_array(labels, java_class)
2482+
model = StringIndexerModel._create_from_java_class(
2483+
"org.apache.spark.ml.feature.StringIndexerModel", jlabels)
2484+
model.setInputCol(inputCol)
2485+
if outputCol is not None:
2486+
model.setOutputCol(outputCol)
2487+
if handleInvalid is not None:
2488+
model.setHandleInvalid(handleInvalid)
2489+
return model
2490+
24582491
@property
24592492
@since("1.5.0")
24602493
def labels(self):
@@ -2463,6 +2496,13 @@ def labels(self):
24632496
"""
24642497
return self._call_java("labels")
24652498

2499+
@since("2.4.0")
2500+
def setHandleInvalid(self, value):
2501+
"""
2502+
Sets the value of :py:attr:`handleInvalid`.
2503+
"""
2504+
return self._set(handleInvalid=value)
2505+
24662506

24672507
@inherit_doc
24682508
class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):

python/pyspark/ml/tests.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,43 @@ def test_string_indexer_handle_invalid(self):
798798
expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
799799
self.assertEqual(actual2, expected2)
800800

801+
def test_string_indexer_from_labels(self):
802+
model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label",
803+
outputCol="indexed", handleInvalid="keep")
804+
self.assertEqual(model.labels, ["a", "b", "c"])
805+
806+
df1 = self.spark.createDataFrame([
807+
(0, "a"),
808+
(1, "c"),
809+
(2, None),
810+
(3, "b"),
811+
(4, "b")], ["id", "label"])
812+
813+
result1 = model.transform(df1)
814+
actual1 = result1.select("id", "indexed").collect()
815+
expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0),
816+
Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)]
817+
self.assertEqual(actual1, expected1)
818+
819+
model_empty_labels = StringIndexerModel.from_labels(
820+
[], inputCol="label", outputCol="indexed", handleInvalid="keep")
821+
actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect()
822+
expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0),
823+
Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)]
824+
self.assertEqual(actual2, expected2)
825+
826+
# Test model with default settings can transform
827+
model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label")
828+
df2 = self.spark.createDataFrame([
829+
(0, "a"),
830+
(1, "c"),
831+
(2, "b"),
832+
(3, "b"),
833+
(4, "b")], ["id", "label"])
834+
transformed_list = model_default.transform(df2)\
835+
.select(model_default.getOrDefault(model_default.outputCol)).collect()
836+
self.assertEqual(len(transformed_list), 5)
837+
801838

802839
class HasInducedError(Params):
803840

@@ -2095,9 +2132,11 @@ def test_java_params(self):
20952132
ParamTests.check_params(self, cls(), check_params_exist=False)
20962133

20972134
# Additional classes that need explicit construction
2098-
from pyspark.ml.feature import CountVectorizerModel
2135+
from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel
20992136
ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'),
21002137
check_params_exist=False)
2138+
ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'),
2139+
check_params_exist=False)
21012140

21022141

21032142
def _squared_distance(a, b):

0 commit comments

Comments (0)