@@ -1451,7 +1451,7 @@ def test_glr_summary(self):
1451
1451
sameSummary = model .evaluate (df )
1452
1452
self .assertAlmostEqual (sameSummary .deviance , s .deviance )
1453
1453
1454
- def test_logistic_regression_summary (self ):
1454
+ def test_binary_logistic_regression_summary (self ):
1455
1455
df = self .spark .createDataFrame ([(1.0 , 2.0 , Vectors .dense (1.0 )),
1456
1456
(0.0 , 2.0 , Vectors .sparse (1 , [], []))],
1457
1457
["label" , "weight" , "features" ])
@@ -1464,20 +1464,73 @@ def test_logistic_regression_summary(self):
1464
1464
self .assertEqual (s .probabilityCol , "probability" )
1465
1465
self .assertEqual (s .labelCol , "label" )
1466
1466
self .assertEqual (s .featuresCol , "features" )
1467
+ self .assertEqual (s .predictionCol , "prediction" )
1467
1468
objHist = s .objectiveHistory
1468
1469
self .assertTrue (isinstance (objHist , list ) and isinstance (objHist [0 ], float ))
1469
1470
self .assertGreater (s .totalIterations , 0 )
1471
+ self .assertTrue (isinstance (s .labels , list ))
1472
+ self .assertTrue (isinstance (s .truePositiveRateByLabel , list ))
1473
+ self .assertTrue (isinstance (s .falsePositiveRateByLabel , list ))
1474
+ self .assertTrue (isinstance (s .precisionByLabel , list ))
1475
+ self .assertTrue (isinstance (s .recallByLabel , list ))
1476
+ self .assertTrue (isinstance (s .fMeasureByLabel (), list ))
1477
+ self .assertTrue (isinstance (s .fMeasureByLabel (1.0 ), list ))
1470
1478
self .assertTrue (isinstance (s .roc , DataFrame ))
1471
1479
self .assertAlmostEqual (s .areaUnderROC , 1.0 , 2 )
1472
1480
self .assertTrue (isinstance (s .pr , DataFrame ))
1473
1481
self .assertTrue (isinstance (s .fMeasureByThreshold , DataFrame ))
1474
1482
self .assertTrue (isinstance (s .precisionByThreshold , DataFrame ))
1475
1483
self .assertTrue (isinstance (s .recallByThreshold , DataFrame ))
1484
+ self .assertAlmostEqual (s .accuracy , 1.0 , 2 )
1485
+ self .assertAlmostEqual (s .weightedTruePositiveRate , 1.0 , 2 )
1486
+ self .assertAlmostEqual (s .weightedFalsePositiveRate , 0.0 , 2 )
1487
+ self .assertAlmostEqual (s .weightedRecall , 1.0 , 2 )
1488
+ self .assertAlmostEqual (s .weightedPrecision , 1.0 , 2 )
1489
+ self .assertAlmostEqual (s .weightedFMeasure (), 1.0 , 2 )
1490
+ self .assertAlmostEqual (s .weightedFMeasure (1.0 ), 1.0 , 2 )
1476
1491
# test evaluation (with training dataset) produces a summary with same values
1477
1492
# one check is enough to verify a summary is returned, Scala version runs full test
1478
1493
sameSummary = model .evaluate (df )
1479
1494
self .assertAlmostEqual (sameSummary .areaUnderROC , s .areaUnderROC )
1480
1495
1496
def test_multiclass_logistic_regression_summary(self):
    """Smoke-test the training summary of a multiclass LogisticRegression model.

    Fits a weighted, intercept-free model on a four-row DataFrame whose
    labels are 0.0, 1.0 and 2.0, then checks that every summary attribute
    is callable from Python, returns the expected type, and that the
    aggregate metrics match the expected values to two decimal places.
    """
    df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
                                     (0.0, 2.0, Vectors.sparse(1, [], [])),
                                     (2.0, 2.0, Vectors.dense(2.0)),
                                     (2.0, 2.0, Vectors.dense(1.9))],
                                    ["label", "weight", "features"])
    lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
    model = lr.fit(df)
    self.assertTrue(model.hasSummary)
    s = model.summary

    # test that api is callable and returns expected types
    # (assertIsInstance reports the offending type on failure, unlike
    # assertTrue(isinstance(...)) which only says "False is not true")
    self.assertIsInstance(s.predictions, DataFrame)
    self.assertEqual(s.probabilityCol, "probability")
    self.assertEqual(s.labelCol, "label")
    self.assertEqual(s.featuresCol, "features")
    self.assertEqual(s.predictionCol, "prediction")

    # Checked separately so a failure pinpoints whether the container or
    # its element type is wrong.
    objHist = s.objectiveHistory
    self.assertIsInstance(objHist, list)
    self.assertIsInstance(objHist[0], float)
    self.assertGreater(s.totalIterations, 0)

    # per-label metrics are returned as plain Python lists
    self.assertIsInstance(s.labels, list)
    self.assertIsInstance(s.truePositiveRateByLabel, list)
    self.assertIsInstance(s.falsePositiveRateByLabel, list)
    self.assertIsInstance(s.precisionByLabel, list)
    self.assertIsInstance(s.recallByLabel, list)
    # fMeasureByLabel is exercised both with its default and an explicit argument
    self.assertIsInstance(s.fMeasureByLabel(), list)
    self.assertIsInstance(s.fMeasureByLabel(1.0), list)

    # aggregate metrics, compared to two decimal places
    self.assertAlmostEqual(s.accuracy, 0.75, 2)
    self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
    self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
    self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
    self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
    self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2)
    self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2)

    # test evaluation (with training dataset) produces a summary with same values
    # one check is enough to verify a summary is returned, Scala version runs full test
    sameSummary = model.evaluate(df)
    self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
1533
+
1481
1534
def test_gaussian_mixture_summary (self ):
1482
1535
data = [(Vectors .dense (1.0 ),), (Vectors .dense (5.0 ),), (Vectors .dense (10.0 ),),
1483
1536
(Vectors .sparse (1 , [], []),)]
0 commit comments