Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 4bab8f5

Browse files
chunshengjiyanboliang
authored andcommitted
[SPARK-21856] Add probability and rawPrediction to MLPC for Python
Probability and rawPrediction has been added to MultilayerPerceptronClassifier for Python Add unit test. Author: Chunsheng Ji <[email protected]> Closes apache#19172 from chunshengji/SPARK-21856.
1 parent 828fab0 commit 4bab8f5

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

python/pyspark/ml/classification.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,7 +1356,8 @@ def theta(self):
13561356
@inherit_doc
13571357
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
13581358
HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver,
1359-
JavaMLWritable, JavaMLReadable):
1359+
JavaMLWritable, JavaMLReadable, HasProbabilityCol,
1360+
HasRawPredictionCol):
13601361
"""
13611362
Classifier trainer based on the Multilayer Perceptron.
13621363
Each layer has sigmoid activation function, output layer has softmax.
@@ -1425,11 +1426,13 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
14251426
@keyword_only
14261427
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
14271428
maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
1428-
solver="l-bfgs", initialWeights=None):
1429+
solver="l-bfgs", initialWeights=None, probabilityCol="probability",
1430+
rawPredicitionCol="rawPrediction"):
14291431
"""
14301432
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
14311433
maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
1432-
solver="l-bfgs", initialWeights=None)
1434+
solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
1435+
rawPredicitionCol="rawPrediction")
14331436
"""
14341437
super(MultilayerPerceptronClassifier, self).__init__()
14351438
self._java_obj = self._new_java_obj(
@@ -1442,11 +1445,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
14421445
@since("1.6.0")
14431446
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
14441447
maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
1445-
solver="l-bfgs", initialWeights=None):
1448+
solver="l-bfgs", initialWeights=None, probabilityCol="probability",
1449+
rawPredicitionCol="rawPrediction"):
14461450
"""
14471451
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
14481452
maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
1449-
solver="l-bfgs", initialWeights=None)
1453+
solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
1454+
rawPredicitionCol="rawPrediction"):
14501455
Sets params for MultilayerPerceptronClassifier.
14511456
"""
14521457
kwargs = self._input_kwargs

python/pyspark/ml/tests.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,6 +1655,26 @@ def test_multinomial_logistic_regression_with_bound(self):
16551655
np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1E-4))
16561656

16571657

1658+
class MultilayerPerceptronClassifierTest(SparkSessionTestCase):
1659+
1660+
def test_raw_and_probability_prediction(self):
1661+
1662+
data_path = "data/mllib/sample_multiclass_classification_data.txt"
1663+
df = self.spark.read.format("libsvm").load(data_path)
1664+
1665+
mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[4, 5, 4, 3],
1666+
blockSize=128, seed=123)
1667+
model = mlp.fit(df)
1668+
test = self.sc.parallelize([Row(features=Vectors.dense(0.1, 0.1, 0.25, 0.25))]).toDF()
1669+
result = model.transform(test).head()
1670+
expected_prediction = 2.0
1671+
expected_probability = [0.0, 0.0, 1.0]
1672+
expected_rawPrediction = [57.3955, -124.5462, 67.9943]
1673+
self.assertTrue(result.prediction, expected_prediction)
1674+
self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4))
1675+
self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4))
1676+
1677+
16581678
class FPGrowthTests(SparkSessionTestCase):
16591679
def setUp(self):
16601680
super(FPGrowthTests, self).setUp()

0 commit comments

Comments
 (0)