Skip to content

Commit d2d2a5d

Browse files
zhengruifeng authored and yanboliang committed
[SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid
## What changes were proposed in this pull request? 1, HasHandleInvalid support override 2, Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid ## How was this patch tested? existing tests [JIRA](https://issues.apache.org/jira/browse/SPARK-18619) Author: Zheng RuiFeng <[email protected]> Closes apache#18582 from zhengruifeng/heritate_HasHandleInvalid.
1 parent aaad34d commit d2d2a5d

File tree

9 files changed

+53
-80
lines changed

9 files changed

+53
-80
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since
2424
import org.apache.spark.ml.Model
2525
import org.apache.spark.ml.attribute.NominalAttribute
2626
import org.apache.spark.ml.param._
27-
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
27+
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
2828
import org.apache.spark.ml.util._
2929
import org.apache.spark.sql._
3030
import org.apache.spark.sql.expressions.UserDefinedFunction
@@ -36,7 +36,8 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
3636
*/
3737
@Since("1.4.0")
3838
final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
39-
extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable {
39+
extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
40+
with DefaultParamsWritable {
4041

4142
@Since("1.4.0")
4243
def this() = this(Identifiable.randomUID("bucketizer"))
@@ -84,17 +85,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
8485
* Default: "error"
8586
* @group param
8687
*/
87-
// TODO: SPARK-18619 Make Bucketizer inherit from HasHandleInvalid.
8888
@Since("2.1.0")
89-
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
90-
"invalid entries. Options are skip (filter out rows with invalid values), " +
89+
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
90+
"how to handle invalid entries. Options are skip (filter out rows with invalid values), " +
9191
"error (throw an error), or keep (keep invalid values in a special additional bucket).",
9292
ParamValidators.inArray(Bucketizer.supportedHandleInvalids))
9393

94-
/** @group getParam */
95-
@Since("2.1.0")
96-
def getHandleInvalid: String = $(handleInvalid)
97-
9894
/** @group setParam */
9995
@Since("2.1.0")
10096
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging
2222
import org.apache.spark.ml._
2323
import org.apache.spark.ml.attribute.NominalAttribute
2424
import org.apache.spark.ml.param._
25-
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
25+
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
2626
import org.apache.spark.ml.util._
2727
import org.apache.spark.sql.Dataset
2828
import org.apache.spark.sql.types.StructType
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.StructType
3131
* Params for [[QuantileDiscretizer]].
3232
*/
3333
private[feature] trait QuantileDiscretizerBase extends Params
34-
with HasInputCol with HasOutputCol {
34+
with HasHandleInvalid with HasInputCol with HasOutputCol {
3535

3636
/**
3737
* Number of buckets (quantiles, or categories) into which data points are grouped. Must
@@ -72,18 +72,13 @@ private[feature] trait QuantileDiscretizerBase extends Params
7272
* Default: "error"
7373
* @group param
7474
*/
75-
// TODO: SPARK-18619 Make QuantileDiscretizer inherit from HasHandleInvalid.
7675
@Since("2.1.0")
77-
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
78-
"invalid entries. Options are skip (filter out rows with invalid values), " +
76+
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
77+
"how to handle invalid entries. Options are skip (filter out rows with invalid values), " +
7978
"error (throw an error), or keep (keep invalid values in a special additional bucket).",
8079
ParamValidators.inArray(Bucketizer.supportedHandleInvalids))
8180
setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
8281

83-
/** @group getParam */
84-
@Since("2.1.0")
85-
def getHandleInvalid: String = $(handleInvalid)
86-
8782
}
8883

8984
/**

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineS
2727
import org.apache.spark.ml.attribute.AttributeGroup
2828
import org.apache.spark.ml.linalg.VectorUDT
2929
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
30-
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
30+
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol}
3131
import org.apache.spark.ml.util._
3232
import org.apache.spark.sql.{DataFrame, Dataset}
3333
import org.apache.spark.sql.types._
@@ -108,7 +108,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
108108
@Experimental
109109
@Since("1.5.0")
110110
class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
111-
extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {
111+
extends Estimator[RFormulaModel] with RFormulaBase with HasHandleInvalid
112+
with DefaultParamsWritable {
112113

113114
@Since("1.5.0")
114115
def this() = this(Identifiable.randomUID("rFormula"))
@@ -141,8 +142,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
141142
* @group param
142143
*/
143144
@Since("2.3.0")
144-
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
145-
"invalid data (unseen labels or NULL values). " +
145+
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
146+
"How to handle invalid data (unseen labels or NULL values). " +
146147
"Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
147148
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
148149
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
@@ -152,10 +153,6 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
152153
@Since("2.3.0")
153154
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
154155

155-
/** @group getParam */
156-
@Since("2.3.0")
157-
def getHandleInvalid: String = $(handleInvalid)
158-
159156
/** @group setParam */
160157
@Since("1.5.0")
161158
def setFeaturesCol(value: String): this.type = set(featuresCol, value)

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.Since
2626
import org.apache.spark.ml.{Estimator, Model, Transformer}
2727
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
2828
import org.apache.spark.ml.param._
29-
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
29+
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
3030
import org.apache.spark.ml.util._
3131
import org.apache.spark.sql.{DataFrame, Dataset}
3232
import org.apache.spark.sql.functions._
@@ -36,7 +36,8 @@ import org.apache.spark.util.collection.OpenHashMap
3636
/**
3737
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
3838
*/
39-
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
39+
private[feature] trait StringIndexerBase extends Params with HasHandleInvalid with HasInputCol
40+
with HasOutputCol {
4041

4142
/**
4243
* Param for how to handle invalid data (unseen labels or NULL values).
@@ -47,18 +48,14 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
4748
* @group param
4849
*/
4950
@Since("1.6.0")
50-
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
51-
"invalid data (unseen labels or NULL values). " +
51+
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
52+
"How to handle invalid data (unseen labels or NULL values). " +
5253
"Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
5354
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
5455
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
5556

5657
setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
5758

58-
/** @group getParam */
59-
@Since("1.6.0")
60-
def getHandleInvalid: String = $(handleInvalid)
61-
6259
/**
6360
* Param for how to order labels of string column. The first label after ordering is assigned
6461
* an index of 0.

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ private[shared] object SharedParamsCodeGen {
6767
ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
6868
"will filter out rows with bad values), or error (which will throw an error). More " +
6969
"options may be added later",
70-
isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
70+
isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))", finalFields = false),
7171
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
7272
" before fitting the model", Some("true")),
7373
ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ private[ml] trait HasHandleInvalid extends Params {
273273
* Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
274274
* @group param
275275
*/
276-
final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later", ParamValidators.inArray(Array("skip", "error")))
276+
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later", ParamValidators.inArray(Array("skip", "error")))
277277

278278
/** @group getParam */
279279
final def getHandleInvalid: String = $(handleInvalid)

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
171171
*
172172
* @group param
173173
*/
174-
@Since("2.3.0")
174+
@Since("2.0.0")
175175
final override val solver: Param[String] = new Param[String](this, "solver",
176176
"The solver algorithm for optimization. Supported options: " +
177177
s"${supportedSolvers.mkString(", ")}. (Default irls)",

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
6464
*
6565
* @group param
6666
*/
67-
@Since("2.3.0")
67+
@Since("1.6.0")
6868
final override val solver: Param[String] = new Param[String](this, "solver",
6969
"The solver algorithm for optimization. Supported options: " +
7070
s"${supportedSolvers.mkString(", ")}. (Default auto)",
@@ -194,7 +194,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
194194
*/
195195
@Since("1.6.0")
196196
def setSolver(value: String): this.type = set(solver, value)
197-
setDefault(solver -> AUTO)
197+
setDefault(solver -> Auto)
198198

199199
/**
200200
* Suggested depth for treeAggregate (greater than or equal to 2).
@@ -224,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
224224
elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth)
225225
instr.logNumFeatures(numFeatures)
226226

227-
if (($(solver) == AUTO &&
228-
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) {
227+
if (($(solver) == Auto &&
228+
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal) {
229229
// For low dimensional data, WeightedLeastSquares is more efficient since the
230230
// training algorithm only requires one pass through the data. (SPARK-10668)
231231

@@ -460,16 +460,16 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
460460
val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES
461461

462462
/** String name for "auto". */
463-
private[regression] val AUTO = "auto"
463+
private[regression] val Auto = "auto"
464464

465465
/** String name for "normal". */
466-
private[regression] val NORMAL = "normal"
466+
private[regression] val Normal = "normal"
467467

468468
/** String name for "l-bfgs". */
469469
private[regression] val LBFGS = "l-bfgs"
470470

471471
/** Set of solvers that LinearRegression supports. */
472-
private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS)
472+
private[regression] val supportedSolvers = Array(Auto, Normal, LBFGS)
473473
}
474474

475475
/**

python/pyspark/ml/feature.py

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,8 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable)
314314

315315

316316
@inherit_doc
317-
class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
317+
class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
318+
JavaMLReadable, JavaMLWritable):
318319
"""
319320
Maps a column of continuous features to a column of feature buckets.
320321
@@ -398,20 +399,6 @@ def getSplits(self):
398399
"""
399400
return self.getOrDefault(self.splits)
400401

401-
@since("2.1.0")
402-
def setHandleInvalid(self, value):
403-
"""
404-
Sets the value of :py:attr:`handleInvalid`.
405-
"""
406-
return self._set(handleInvalid=value)
407-
408-
@since("2.1.0")
409-
def getHandleInvalid(self):
410-
"""
411-
Gets the value of :py:attr:`handleInvalid` or its default value.
412-
"""
413-
return self.getOrDefault(self.handleInvalid)
414-
415402

416403
@inherit_doc
417404
class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
@@ -1623,7 +1610,8 @@ def getDegree(self):
16231610

16241611

16251612
@inherit_doc
1626-
class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
1613+
class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
1614+
JavaMLReadable, JavaMLWritable):
16271615
"""
16281616
.. note:: Experimental
16291617
@@ -1743,20 +1731,6 @@ def getRelativeError(self):
17431731
"""
17441732
return self.getOrDefault(self.relativeError)
17451733

1746-
@since("2.1.0")
1747-
def setHandleInvalid(self, value):
1748-
"""
1749-
Sets the value of :py:attr:`handleInvalid`.
1750-
"""
1751-
return self._set(handleInvalid=value)
1752-
1753-
@since("2.1.0")
1754-
def getHandleInvalid(self):
1755-
"""
1756-
Gets the value of :py:attr:`handleInvalid` or its default value.
1757-
"""
1758-
return self.getOrDefault(self.handleInvalid)
1759-
17601734
def _create_model(self, java_model):
17611735
"""
17621736
Private method to convert the java_model to a Python model.
@@ -2977,7 +2951,8 @@ def explainedVariance(self):
29772951

29782952

29792953
@inherit_doc
2980-
class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable):
2954+
class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, HasHandleInvalid,
2955+
JavaMLReadable, JavaMLWritable):
29812956
"""
29822957
.. note:: Experimental
29832958
@@ -3020,6 +2995,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
30202995
True
30212996
>>> loadedRF.getLabelCol() == rf.getLabelCol()
30222997
True
2998+
>>> loadedRF.getHandleInvalid() == rf.getHandleInvalid()
2999+
True
30233000
>>> str(loadedRF)
30243001
'RFormula(y ~ x + s) (uid=...)'
30253002
>>> modelPath = temp_path + "/rFormulaModel"
@@ -3058,26 +3035,37 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
30583035
"RFormula drops the same category as R when encoding strings.",
30593036
typeConverter=TypeConverters.toString)
30603037

3038+
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
3039+
"Options are 'skip' (filter out rows with invalid values), " +
3040+
"'error' (throw an error), or 'keep' (put invalid data in a special " +
3041+
"additional bucket, at index numLabels).",
3042+
typeConverter=TypeConverters.toString)
3043+
30613044
@keyword_only
30623045
def __init__(self, formula=None, featuresCol="features", labelCol="label",
3063-
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
3046+
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
3047+
handleInvalid="error"):
30643048
"""
30653049
__init__(self, formula=None, featuresCol="features", labelCol="label", \
3066-
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
3050+
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
3051+
handleInvalid="error")
30673052
"""
30683053
super(RFormula, self).__init__()
30693054
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
3070-
self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
3055+
self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
3056+
handleInvalid="error")
30713057
kwargs = self._input_kwargs
30723058
self.setParams(**kwargs)
30733059

30743060
@keyword_only
30753061
@since("1.5.0")
30763062
def setParams(self, formula=None, featuresCol="features", labelCol="label",
3077-
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
3063+
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
3064+
handleInvalid="error"):
30783065
"""
30793066
setParams(self, formula=None, featuresCol="features", labelCol="label", \
3080-
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
3067+
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
3068+
handleInvalid="error")
30813069
Sets params for RFormula.
30823070
"""
30833071
kwargs = self._input_kwargs

0 commit comments

Comments (0)