Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 8d8641f

Browse files
Ming Jiangyanboliang
authored andcommitted
[SPARK-21854] Added LogisticRegressionTrainingSummary for MultinomialLogisticRegression in Python API
## What changes were proposed in this pull request? Added LogisticRegressionTrainingSummary for MultinomialLogisticRegression in Python API ## How was this patch tested? Added unit test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ming Jiang <[email protected]> Author: Ming Jiang <[email protected]> Author: jmwdpk <[email protected]> Closes apache#19185 from jmwdpk/SPARK-21854.
1 parent dcbb229 commit 8d8641f

File tree

3 files changed

+183
-4
lines changed

3 files changed

+183
-4
lines changed

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2416,6 +2416,18 @@ class LogisticRegressionSuite
24162416
blorSummary.recallByThreshold.collect() === sameBlorSummary.recallByThreshold.collect())
24172417
assert(
24182418
blorSummary.precisionByThreshold.collect() === sameBlorSummary.precisionByThreshold.collect())
2419+
assert(blorSummary.labels === sameBlorSummary.labels)
2420+
assert(blorSummary.truePositiveRateByLabel === sameBlorSummary.truePositiveRateByLabel)
2421+
assert(blorSummary.falsePositiveRateByLabel === sameBlorSummary.falsePositiveRateByLabel)
2422+
assert(blorSummary.precisionByLabel === sameBlorSummary.precisionByLabel)
2423+
assert(blorSummary.recallByLabel === sameBlorSummary.recallByLabel)
2424+
assert(blorSummary.fMeasureByLabel === sameBlorSummary.fMeasureByLabel)
2425+
assert(blorSummary.accuracy === sameBlorSummary.accuracy)
2426+
assert(blorSummary.weightedTruePositiveRate === sameBlorSummary.weightedTruePositiveRate)
2427+
assert(blorSummary.weightedFalsePositiveRate === sameBlorSummary.weightedFalsePositiveRate)
2428+
assert(blorSummary.weightedRecall === sameBlorSummary.weightedRecall)
2429+
assert(blorSummary.weightedPrecision === sameBlorSummary.weightedPrecision)
2430+
assert(blorSummary.weightedFMeasure === sameBlorSummary.weightedFMeasure)
24192431

24202432
lr.setFamily("multinomial")
24212433
val mlorModel = lr.fit(smallMultinomialDataset)

python/pyspark/ml/classification.py

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -529,9 +529,11 @@ def summary(self):
529529
trained on the training set. An exception is thrown if `trainingSummary is None`.
530530
"""
531531
if self.hasSummary:
532-
java_blrt_summary = self._call_java("summary")
533-
# Note: Once multiclass is added, update this to return correct summary
534-
return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
532+
java_lrt_summary = self._call_java("summary")
533+
if self.numClasses <= 2:
534+
return BinaryLogisticRegressionTrainingSummary(java_lrt_summary)
535+
else:
536+
return LogisticRegressionTrainingSummary(java_lrt_summary)
535537
else:
536538
raise RuntimeError("No training summary available for this %s" %
537539
self.__class__.__name__)
@@ -586,6 +588,14 @@ def probabilityCol(self):
586588
"""
587589
return self._call_java("probabilityCol")
588590

591+
@property
592+
@since("2.3.0")
593+
def predictionCol(self):
594+
"""
595+
Field in "predictions" which gives the prediction of each class.
596+
"""
597+
return self._call_java("predictionCol")
598+
589599
@property
590600
@since("2.0.0")
591601
def labelCol(self):
@@ -604,6 +614,110 @@ def featuresCol(self):
604614
"""
605615
return self._call_java("featuresCol")
606616

617+
@property
618+
@since("2.3.0")
619+
def labels(self):
620+
"""
621+
Returns the sequence of labels in ascending order. This order matches the order used
622+
in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
623+
624+
Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the
625+
training set is missing a label, then all of the arrays over labels
626+
(e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
627+
expected numClasses.
628+
"""
629+
return self._call_java("labels")
630+
631+
@property
632+
@since("2.3.0")
633+
def truePositiveRateByLabel(self):
634+
"""
635+
Returns true positive rate for each label (category).
636+
"""
637+
return self._call_java("truePositiveRateByLabel")
638+
639+
@property
640+
@since("2.3.0")
641+
def falsePositiveRateByLabel(self):
642+
"""
643+
Returns false positive rate for each label (category).
644+
"""
645+
return self._call_java("falsePositiveRateByLabel")
646+
647+
@property
648+
@since("2.3.0")
649+
def precisionByLabel(self):
650+
"""
651+
Returns precision for each label (category).
652+
"""
653+
return self._call_java("precisionByLabel")
654+
655+
@property
656+
@since("2.3.0")
657+
def recallByLabel(self):
658+
"""
659+
Returns recall for each label (category).
660+
"""
661+
return self._call_java("recallByLabel")
662+
663+
@since("2.3.0")
664+
def fMeasureByLabel(self, beta=1.0):
665+
"""
666+
Returns f-measure for each label (category).
667+
"""
668+
return self._call_java("fMeasureByLabel", beta)
669+
670+
@property
671+
@since("2.3.0")
672+
def accuracy(self):
673+
"""
674+
Returns accuracy.
675+
(equals to the total number of correctly classified instances
676+
out of the total number of instances.)
677+
"""
678+
return self._call_java("accuracy")
679+
680+
@property
681+
@since("2.3.0")
682+
def weightedTruePositiveRate(self):
683+
"""
684+
Returns weighted true positive rate.
685+
(equals to precision, recall and f-measure)
686+
"""
687+
return self._call_java("weightedTruePositiveRate")
688+
689+
@property
690+
@since("2.3.0")
691+
def weightedFalsePositiveRate(self):
692+
"""
693+
Returns weighted false positive rate.
694+
"""
695+
return self._call_java("weightedFalsePositiveRate")
696+
697+
@property
698+
@since("2.3.0")
699+
def weightedRecall(self):
700+
"""
701+
Returns weighted averaged recall.
702+
(equals to precision, recall and f-measure)
703+
"""
704+
return self._call_java("weightedRecall")
705+
706+
@property
707+
@since("2.3.0")
708+
def weightedPrecision(self):
709+
"""
710+
Returns weighted averaged precision.
711+
"""
712+
return self._call_java("weightedPrecision")
713+
714+
@since("2.3.0")
715+
def weightedFMeasure(self, beta=1.0):
716+
"""
717+
Returns weighted averaged f-measure.
718+
"""
719+
return self._call_java("weightedFMeasure", beta)
720+
607721

608722
@inherit_doc
609723
class LogisticRegressionTrainingSummary(LogisticRegressionSummary):

python/pyspark/ml/tests.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1451,7 +1451,7 @@ def test_glr_summary(self):
14511451
sameSummary = model.evaluate(df)
14521452
self.assertAlmostEqual(sameSummary.deviance, s.deviance)
14531453

1454-
def test_logistic_regression_summary(self):
1454+
def test_binary_logistic_regression_summary(self):
14551455
df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
14561456
(0.0, 2.0, Vectors.sparse(1, [], []))],
14571457
["label", "weight", "features"])
@@ -1464,20 +1464,73 @@ def test_logistic_regression_summary(self):
14641464
self.assertEqual(s.probabilityCol, "probability")
14651465
self.assertEqual(s.labelCol, "label")
14661466
self.assertEqual(s.featuresCol, "features")
1467+
self.assertEqual(s.predictionCol, "prediction")
14671468
objHist = s.objectiveHistory
14681469
self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
14691470
self.assertGreater(s.totalIterations, 0)
1471+
self.assertTrue(isinstance(s.labels, list))
1472+
self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
1473+
self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
1474+
self.assertTrue(isinstance(s.precisionByLabel, list))
1475+
self.assertTrue(isinstance(s.recallByLabel, list))
1476+
self.assertTrue(isinstance(s.fMeasureByLabel(), list))
1477+
self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
14701478
self.assertTrue(isinstance(s.roc, DataFrame))
14711479
self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
14721480
self.assertTrue(isinstance(s.pr, DataFrame))
14731481
self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
14741482
self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
14751483
self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
1484+
self.assertAlmostEqual(s.accuracy, 1.0, 2)
1485+
self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
1486+
self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
1487+
self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
1488+
self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
1489+
self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
1490+
self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
14761491
# test evaluation (with training dataset) produces a summary with same values
14771492
# one check is enough to verify a summary is returned, Scala version runs full test
14781493
sameSummary = model.evaluate(df)
14791494
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
14801495

1496+
def test_multiclass_logistic_regression_summary(self):
1497+
df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
1498+
(0.0, 2.0, Vectors.sparse(1, [], [])),
1499+
(2.0, 2.0, Vectors.dense(2.0)),
1500+
(2.0, 2.0, Vectors.dense(1.9))],
1501+
["label", "weight", "features"])
1502+
lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
1503+
model = lr.fit(df)
1504+
self.assertTrue(model.hasSummary)
1505+
s = model.summary
1506+
# test that api is callable and returns expected types
1507+
self.assertTrue(isinstance(s.predictions, DataFrame))
1508+
self.assertEqual(s.probabilityCol, "probability")
1509+
self.assertEqual(s.labelCol, "label")
1510+
self.assertEqual(s.featuresCol, "features")
1511+
self.assertEqual(s.predictionCol, "prediction")
1512+
objHist = s.objectiveHistory
1513+
self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
1514+
self.assertGreater(s.totalIterations, 0)
1515+
self.assertTrue(isinstance(s.labels, list))
1516+
self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
1517+
self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
1518+
self.assertTrue(isinstance(s.precisionByLabel, list))
1519+
self.assertTrue(isinstance(s.recallByLabel, list))
1520+
self.assertTrue(isinstance(s.fMeasureByLabel(), list))
1521+
self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
1522+
self.assertAlmostEqual(s.accuracy, 0.75, 2)
1523+
self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
1524+
self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
1525+
self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
1526+
self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
1527+
self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2)
1528+
self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2)
1529+
# test evaluation (with training dataset) produces a summary with same values
1530+
# one check is enough to verify a summary is returned, Scala version runs full test
1531+
sameSummary = model.evaluate(df)
1532+
self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
1533+
14811534
def test_gaussian_mixture_summary(self):
14821535
data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
14831536
(Vectors.sparse(1, [], []),)]

0 commit comments

Comments
 (0)