Commit 1112fc6

huaxingao authored and dongjoon-hyun committed
[SPARK-29867][ML][PYTHON] Add __repr__ in Python ML Models
### What changes were proposed in this pull request?

Add ```__repr__``` in Python ML Models.

### Why are the changes needed?

In Python ML Models, some have ```__repr__``` and others don't. In the doctests, when calling Model.setXXX, some of the models print out ```XXXModel...``` correctly, while others can't because they lack the ```__repr__``` method. For example:

```
>>> gm = GaussianMixture(k=3, tol=0.0001, seed=10)
>>> model = gm.fit(df)
>>> model.setPredictionCol("newPrediction")
GaussianMixture...
```

After the change, the above code becomes:

```
>>> gm = GaussianMixture(k=3, tol=0.0001, seed=10)
>>> model = gm.fit(df)
>>> model.setPredictionCol("newPrediction")
GaussianMixtureModel...
```

### Does this PR introduce any user-facing change?

Yes.

### How was this patch tested?

doctest

Closes apache#26489 from huaxingao/spark-29876.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 6d6b233 commit 1112fc6
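
The whole fix hinges on one inherited method: the commit adds `__repr__` to the shared `JavaModel` base class in `python/pyspark/ml/wrapper.py` (see the last hunk below), delegating to the JVM model's `toString`, and drops the per-class copies that a few models carried. A minimal standalone sketch of that pattern; `JavaModel` and `_call_java` are mocked stand-ins here, since the real method goes through py4j to the underlying Java object:

```
# Sketch of the pattern this commit centralizes. JavaModel/_call_java are
# mocked stand-ins, not the real pyspark.ml.wrapper implementations.
class JavaModel:
    def _call_java(self, name):
        # The real method forwards the call (e.g. "toString") to the wrapped
        # Java model over py4j; here we just fabricate a plausible string.
        return "{}: uid=...".format(type(self).__name__)

    def __repr__(self):
        return self._call_java("toString")


class GaussianMixtureModel(JavaModel):
    pass  # inherits __repr__, no per-class definition needed


print(repr(GaussianMixtureModel()))  # -> GaussianMixtureModel: uid=...
```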

8 files changed (+44, -28 lines)

python/pyspark/ml/classification.py

Lines changed: 4 additions & 7 deletions
@@ -192,11 +192,11 @@ class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable
     0.01
     >>> model = svm.fit(df)
     >>> model.setPredictionCol("newPrediction")
-    LinearSVC...
+    LinearSVCModel...
     >>> model.getPredictionCol()
     'newPrediction'
     >>> model.setThreshold(0.5)
-    LinearSVC...
+    LinearSVCModel...
     >>> model.getThreshold()
     0.5
     >>> model.coefficients
@@ -812,9 +812,6 @@ def evaluate(self, dataset):
         java_blr_summary = self._call_java("evaluate", dataset)
         return BinaryLogisticRegressionSummary(java_blr_summary)
 
-    def __repr__(self):
-        return self._call_java("toString")
-
 
 class LogisticRegressionSummary(JavaWrapper):
     """
@@ -1921,7 +1918,7 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
     >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
     >>> model = nb.fit(df)
     >>> model.setFeaturesCol("features")
-    NaiveBayes_...
+    NaiveBayesModel...
     >>> model.getSmoothing()
     1.0
     >>> model.pi
@@ -2114,7 +2111,7 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer
     100
     >>> model = mlp.fit(df)
     >>> model.setFeaturesCol("features")
-    MultilayerPerceptronClassifier...
+    MultilayerPerceptronClassificationModel...
     >>> model.layers
     [2, 2, 2]
     >>> model.weights.size

python/pyspark/ml/clustering.py

Lines changed: 5 additions & 3 deletions
@@ -234,7 +234,7 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav
     >>> model.getFeaturesCol()
     'features'
     >>> model.setPredictionCol("newPrediction")
-    GaussianMixture...
+    GaussianMixtureModel...
     >>> model.predict(df.head().features)
     2
     >>> model.predictProbability(df.head().features)
@@ -532,7 +532,7 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
     >>> model.getDistanceMeasure()
     'euclidean'
     >>> model.setPredictionCol("newPrediction")
-    KMeans...
+    KMeansModel...
     >>> model.predict(df.head().features)
     0
     >>> centers = model.clusterCenters()
@@ -794,7 +794,7 @@ class BisectingKMeans(JavaEstimator, _BisectingKMeansParams, JavaMLWritable, Jav
     >>> model.getMaxIter()
     20
     >>> model.setPredictionCol("newPrediction")
-    BisectingKMeans...
+    BisectingKMeansModel...
     >>> model.predict(df.head().features)
     0
     >>> centers = model.clusterCenters()
@@ -1265,6 +1265,8 @@ class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable):
     10
     >>> lda.clear(lda.maxIter)
     >>> model = lda.fit(df)
+    >>> model.setSeed(1)
+    DistributedLDAModel...
     >>> model.getTopicDistributionCol()
     'topicDistribution'
     >>> model.isDistributed()

python/pyspark/ml/feature.py

Lines changed: 24 additions & 6 deletions
@@ -337,6 +337,8 @@ class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
     >>> model = brp.fit(df)
     >>> model.getBucketLength()
     1.0
+    >>> model.setOutputCol("hashes")
+    BucketedRandomProjectionLSHModel...
     >>> model.transform(df).head()
     Row(id=0, features=DenseVector([-1.0, -1.0]), hashes=[DenseVector([-1.0])])
     >>> data2 = [(4, Vectors.dense([2.0, 2.0 ]),),
@@ -733,6 +735,8 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav
     >>> cv.setOutputCol("vectors")
     CountVectorizer...
     >>> model = cv.fit(df)
+    >>> model.setInputCol("raw")
+    CountVectorizerModel...
     >>> model.transform(df).show(truncate=False)
     +-----+---------------+-------------------------+
     |label|raw            |vectors                  |
@@ -1345,6 +1349,8 @@ class IDF(JavaEstimator, _IDFParams, JavaMLReadable, JavaMLWritable):
     >>> idf.setOutputCol("idf")
     IDF...
     >>> model = idf.fit(df)
+    >>> model.setOutputCol("idf")
+    IDFModel...
     >>> model.getMinDocFreq()
     3
     >>> model.idf
@@ -1519,6 +1525,8 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
     >>> imputer.getRelativeError()
     0.001
     >>> model = imputer.fit(df)
+    >>> model.setInputCols(["a", "b"])
+    ImputerModel...
     >>> model.getStrategy()
     'mean'
     >>> model.surrogateDF.show()
@@ -1810,7 +1818,7 @@ class MaxAbsScaler(JavaEstimator, _MaxAbsScalerParams, JavaMLReadable, JavaMLWri
     MaxAbsScaler...
     >>> model = maScaler.fit(df)
     >>> model.setOutputCol("scaledOutput")
-    MaxAbsScaler...
+    MaxAbsScalerModel...
     >>> model.transform(df).show()
     +-----+------------+
     |    a|scaledOutput|
@@ -1928,6 +1936,8 @@ class MinHashLSH(_LSH, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, JavaM
     >>> mh.setSeed(12345)
     MinHashLSH...
     >>> model = mh.fit(df)
+    >>> model.setInputCol("features")
+    MinHashLSHModel...
     >>> model.transform(df).head()
     Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668...
     >>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),),
@@ -2056,7 +2066,7 @@ class MinMaxScaler(JavaEstimator, _MinMaxScalerParams, JavaMLReadable, JavaMLWri
     MinMaxScaler...
     >>> model = mmScaler.fit(df)
     >>> model.setOutputCol("scaledOutput")
-    MinMaxScaler...
+    MinMaxScalerModel...
     >>> model.originalMin
     DenseVector([0.0])
     >>> model.originalMax
@@ -2421,6 +2431,8 @@ class OneHotEncoder(JavaEstimator, _OneHotEncoderParams, JavaMLReadable, JavaMLW
     >>> ohe.setOutputCols(["output"])
     OneHotEncoder...
     >>> model = ohe.fit(df)
+    >>> model.setOutputCols(["output"])
+    OneHotEncoderModel...
     >>> model.getHandleInvalid()
     'error'
     >>> model.transform(df).head().output
@@ -2935,7 +2947,7 @@ class RobustScaler(JavaEstimator, _RobustScalerParams, JavaMLReadable, JavaMLWri
     RobustScaler...
     >>> model = scaler.fit(df)
     >>> model.setOutputCol("output")
-    RobustScaler...
+    RobustScalerModel...
     >>> model.median
     DenseVector([2.0, -2.0])
     >>> model.range
@@ -3330,7 +3342,7 @@ class StandardScaler(JavaEstimator, _StandardScalerParams, JavaMLReadable, JavaM
     >>> model.getInputCol()
     'a'
     >>> model.setOutputCol("output")
-    StandardScaler...
+    StandardScalerModel...
     >>> model.mean
     DenseVector([1.0])
     >>> model.std
@@ -3490,6 +3502,8 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
     >>> stringIndexer.setHandleInvalid("error")
     StringIndexer...
     >>> model = stringIndexer.fit(stringIndDf)
+    >>> model.setHandleInvalid("error")
+    StringIndexerModel...
     >>> td = model.transform(stringIndDf)
     >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
     ... key=lambda x: x[0])
@@ -4166,7 +4180,7 @@ class VectorIndexer(JavaEstimator, _VectorIndexerParams, JavaMLReadable, JavaMLW
     >>> indexer.getHandleInvalid()
     'error'
     >>> model.setOutputCol("output")
-    VectorIndexer...
+    VectorIndexerModel...
     >>> model.transform(df).head().output
     DenseVector([1.0, 0.0])
     >>> model.numFeatures
@@ -4487,6 +4501,8 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable):
     >>> model = word2Vec.fit(doc)
     >>> model.getMinCount()
     5
+    >>> model.setInputCol("sentence")
+    Word2VecModel...
     >>> model.getVectors().show()
     +----+--------------------+
     |word|              vector|
@@ -4714,7 +4730,7 @@ class PCA(JavaEstimator, _PCAParams, JavaMLReadable, JavaMLWritable):
     >>> model.getK()
     2
     >>> model.setOutputCol("output")
-    PCA...
+    PCAModel...
     >>> model.transform(df).collect()[0].output
     DenseVector([1.648..., -4.013...])
     >>> model.explainedVariance
@@ -5139,6 +5155,8 @@ class ChiSqSelector(JavaEstimator, _ChiSqSelectorParams, JavaMLReadable, JavaMLW
     >>> model = selector.fit(df)
     >>> model.getFeaturesCol()
     'features'
+    >>> model.setFeaturesCol("features")
+    ChiSqSelectorModel...
     >>> model.transform(df).head().selectedFeatures
     DenseVector([18.0])
     >>> model.selectedFeatures

python/pyspark/ml/fpm.py

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
     >>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
     >>> fpm = fp.fit(data)
     >>> fpm.setPredictionCol("newPrediction")
-    FPGrowth...
+    FPGrowthModel...
     >>> fpm.freqItemsets.show(5)
     +---------+----+
     |    items|freq|

python/pyspark/ml/recommendation.py

Lines changed: 2 additions & 0 deletions
@@ -225,6 +225,8 @@ class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable):
     >>> model = als.fit(df)
     >>> model.getUserCol()
     'user'
+    >>> model.setUserCol("user")
+    ALSModel...
     >>> model.getItemCol()
     'item'
     >>> model.setPredictionCol("newPrediction")

python/pyspark/ml/regression.py

Lines changed: 5 additions & 5 deletions
@@ -105,9 +105,9 @@ class LinearRegression(JavaPredictor, _LinearRegressionParams, JavaMLWritable, J
     LinearRegression...
     >>> model = lr.fit(df)
     >>> model.setFeaturesCol("features")
-    LinearRegression...
+    LinearRegressionModel...
     >>> model.setPredictionCol("newPrediction")
-    LinearRegression...
+    LinearRegressionModel...
     >>> model.getMaxIter()
     5
     >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -591,7 +591,7 @@ class IsotonicRegression(JavaEstimator, _IsotonicRegressionParams, HasWeightCol,
     >>> ir = IsotonicRegression()
     >>> model = ir.fit(df)
     >>> model.setFeaturesCol("features")
-    IsotonicRegression...
+    IsotonicRegressionModel...
     >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
@@ -1546,7 +1546,7 @@ class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams,
     >>> aftsr.clear(aftsr.maxIter)
     >>> model = aftsr.fit(df)
     >>> model.setFeaturesCol("features")
-    AFTSurvivalRegression...
+    AFTSurvivalRegressionModel...
     >>> model.predict(Vectors.dense(6.3))
     1.0
     >>> model.predictQuantiles(Vectors.dense(6.3))
@@ -1881,7 +1881,7 @@ class GeneralizedLinearRegression(JavaPredictor, _GeneralizedLinearRegressionPar
     >>> glr.clear(glr.maxIter)
     >>> model = glr.fit(df)
     >>> model.setFeaturesCol("features")
-    GeneralizedLinearRegression...
+    GeneralizedLinearRegressionModel...
     >>> model.getMaxIter()
     25
     >>> model.getAggregationDepth()

python/pyspark/ml/tree.py

Lines changed: 0 additions & 6 deletions
@@ -56,9 +56,6 @@ def predictLeaf(self, value):
         """
         return self._call_java("predictLeaf", value)
 
-    def __repr__(self):
-        return self._call_java("toString")
-
 
 class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
     """
@@ -208,9 +205,6 @@ def predictLeaf(self, value):
         """
         return self._call_java("predictLeaf", value)
 
-    def __repr__(self):
-        return self._call_java("toString")
-
 
 class _TreeEnsembleParams(_DecisionTreeParams):
     """

python/pyspark/ml/wrapper.py

Lines changed: 3 additions & 0 deletions
@@ -372,6 +372,9 @@ def __init__(self, java_model=None):
 
             self._resetUid(java_model.uid())
 
+    def __repr__(self):
+        return self._call_java("toString")
+
 
 @inherit_doc
 class _JavaPredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
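
With `__repr__` now defined on `JavaModel`, any fitted Java-backed model prints a string that starts with its model class name, which is what the relaxed `XxxModel...` doctest expectations above match. A quick interactive check, hedged: it assumes a working PySpark installation that includes this change and a live SparkSession, and the exact text after the class name comes from the JVM model's `toString`, so it may differ across Spark versions:

```
# Indicative check only; output in the comments is illustrative.
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.clustering import KMeans

spark = SparkSession.builder.appName("repr-check").getOrCreate()
df = spark.createDataFrame(
    [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
     (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)],
    ["features"])

model = KMeans(k=2, seed=1).fit(df)
print(repr(model))   # starts with "KMeansModel", e.g. KMeansModel: uid=KMeans_..., k=2, ...
# Before this change, models without their own __repr__ fell back to the
# inherited repr that shows only the uid (roughly "KMeans_<random suffix>"),
# which is why the old doctests expected "KMeans..." rather than "KMeansModel...".

spark.stop()
```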
