@@ -2342,9 +2342,38 @@ def mean(self):
2342
2342
return self ._call_java ("mean" )
2343
2343
2344
2344
2345
+ class _StringIndexerParams (JavaParams , HasHandleInvalid , HasInputCol , HasOutputCol ):
2346
+ """
2347
+ Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`.
2348
+ """
2349
+
2350
+ stringOrderType = Param (Params ._dummy (), "stringOrderType" ,
2351
+ "How to order labels of string column. The first label after " +
2352
+ "ordering is assigned an index of 0. Supported options: " +
2353
+ "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc." ,
2354
+ typeConverter = TypeConverters .toString )
2355
+
2356
+ handleInvalid = Param (Params ._dummy (), "handleInvalid" , "how to handle invalid data (unseen " +
2357
+ "or NULL values) in features and label column of string type. " +
2358
+ "Options are 'skip' (filter out rows with invalid data), " +
2359
+ "error (throw an error), or 'keep' (put invalid data " +
2360
+ "in a special additional bucket, at index numLabels)." ,
2361
+ typeConverter = TypeConverters .toString )
2362
+
2363
+ def __init__ (self , * args ):
2364
+ super (_StringIndexerParams , self ).__init__ (* args )
2365
+ self ._setDefault (handleInvalid = "error" , stringOrderType = "frequencyDesc" )
2366
+
2367
+ @since ("2.3.0" )
2368
+ def getStringOrderType (self ):
2369
+ """
2370
+ Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
2371
+ """
2372
+ return self .getOrDefault (self .stringOrderType )
2373
+
2374
+
2345
2375
@inherit_doc
2346
- class StringIndexer (JavaEstimator , HasInputCol , HasOutputCol , HasHandleInvalid , JavaMLReadable ,
2347
- JavaMLWritable ):
2376
+ class StringIndexer (JavaEstimator , _StringIndexerParams , JavaMLReadable , JavaMLWritable ):
2348
2377
"""
2349
2378
A label indexer that maps a string column of labels to an ML column of label indices.
2350
2379
If the input column is numeric, we cast it to string and index the string values.
@@ -2388,23 +2417,16 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
2388
2417
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
2389
2418
... key=lambda x: x[0])
2390
2419
[(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
2420
+ >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"],
2421
+ ... inputCol="label", outputCol="indexed", handleInvalid="error")
2422
+ >>> result = fromlabelsModel.transform(stringIndDf)
2423
+ >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
2424
+ ... key=lambda x: x[0])
2425
+ [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
2391
2426
2392
2427
.. versionadded:: 1.4.0
2393
2428
"""
2394
2429
2395
- stringOrderType = Param (Params ._dummy (), "stringOrderType" ,
2396
- "How to order labels of string column. The first label after " +
2397
- "ordering is assigned an index of 0. Supported options: " +
2398
- "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc." ,
2399
- typeConverter = TypeConverters .toString )
2400
-
2401
- handleInvalid = Param (Params ._dummy (), "handleInvalid" , "how to handle invalid data (unseen " +
2402
- "or NULL values) in features and label column of string type. " +
2403
- "Options are 'skip' (filter out rows with invalid data), " +
2404
- "error (throw an error), or 'keep' (put invalid data " +
2405
- "in a special additional bucket, at index numLabels)." ,
2406
- typeConverter = TypeConverters .toString )
2407
-
2408
2430
@keyword_only
2409
2431
def __init__ (self , inputCol = None , outputCol = None , handleInvalid = "error" ,
2410
2432
stringOrderType = "frequencyDesc" ):
@@ -2414,7 +2436,6 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
2414
2436
"""
2415
2437
super (StringIndexer , self ).__init__ ()
2416
2438
self ._java_obj = self ._new_java_obj ("org.apache.spark.ml.feature.StringIndexer" , self .uid )
2417
- self ._setDefault (handleInvalid = "error" , stringOrderType = "frequencyDesc" )
2418
2439
kwargs = self ._input_kwargs
2419
2440
self .setParams (** kwargs )
2420
2441
@@ -2440,21 +2461,33 @@ def setStringOrderType(self, value):
2440
2461
"""
2441
2462
return self ._set (stringOrderType = value )
2442
2463
2443
- @since ("2.3.0" )
2444
- def getStringOrderType (self ):
2445
- """
2446
- Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
2447
- """
2448
- return self .getOrDefault (self .stringOrderType )
2449
-
2450
2464
2451
- class StringIndexerModel (JavaModel , JavaMLReadable , JavaMLWritable ):
2465
+ class StringIndexerModel (JavaModel , _StringIndexerParams , JavaMLReadable , JavaMLWritable ):
2452
2466
"""
2453
2467
Model fitted by :py:class:`StringIndexer`.
2454
2468
2455
2469
.. versionadded:: 1.4.0
2456
2470
"""
2457
2471
2472
+ @classmethod
2473
+ @since ("2.4.0" )
2474
+ def from_labels (cls , labels , inputCol , outputCol = None , handleInvalid = None ):
2475
+ """
2476
+ Construct the model directly from an array of label strings,
2477
+ requires an active SparkContext.
2478
+ """
2479
+ sc = SparkContext ._active_spark_context
2480
+ java_class = sc ._gateway .jvm .java .lang .String
2481
+ jlabels = StringIndexerModel ._new_java_array (labels , java_class )
2482
+ model = StringIndexerModel ._create_from_java_class (
2483
+ "org.apache.spark.ml.feature.StringIndexerModel" , jlabels )
2484
+ model .setInputCol (inputCol )
2485
+ if outputCol is not None :
2486
+ model .setOutputCol (outputCol )
2487
+ if handleInvalid is not None :
2488
+ model .setHandleInvalid (handleInvalid )
2489
+ return model
2490
+
2458
2491
@property
2459
2492
@since ("1.5.0" )
2460
2493
def labels (self ):
@@ -2463,6 +2496,13 @@ def labels(self):
2463
2496
"""
2464
2497
return self ._call_java ("labels" )
2465
2498
2499
+ @since ("2.4.0" )
2500
+ def setHandleInvalid (self , value ):
2501
+ """
2502
+ Sets the value of :py:attr:`handleInvalid`.
2503
+ """
2504
+ return self ._set (handleInvalid = value )
2505
+
2466
2506
2467
2507
@inherit_doc
2468
2508
class IndexToString (JavaTransformer , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
0 commit comments