Skip to content

Commit 293a0f2

Browse files
tengpengdbtsai
authored and committed
[Spark-24024][ML] Fix poisson deviance calculations in GLM to handle y = 0
## What changes were proposed in this pull request? It is reported by Spark users that the deviance calculation for poisson regression does not handle y = 0. Thus, the correct model summary cannot be obtained. The user has confirmed that the issue is in ``` override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (y * math.log(y / mu) - (y - mu)) } ``` when y = 0: there `y * math.log(y / mu)` evaluates to `0 * -Infinity`, which is NaN, whereas the limit of y * log(y / mu) as y -> 0 is 0. The user also mentioned there are many other places he believes we should check for the same thing. However, no other changes are needed, including for the Gamma distribution. ## How was this patch tested? Add a comparison with R deviance calculation to the existing unit test. Author: Teng Peng <[email protected]> Closes apache#21125 from tengpeng/Spark24024GLM.
1 parent afbdf42 commit 293a0f2

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
471471

472472
private[regression] val epsilon: Double = 1E-16
473473

474+
private[regression] def ylogy(y: Double, mu: Double): Double = {
475+
if (y == 0) 0.0 else y * math.log(y / mu)
476+
}
477+
474478
/**
475479
* Wrapper of family and link combination used in the model.
476480
*/
@@ -725,10 +729,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
725729

726730
override def variance(mu: Double): Double = mu * (1.0 - mu)
727731

728-
private def ylogy(y: Double, mu: Double): Double = {
729-
if (y == 0) 0.0 else y * math.log(y / mu)
730-
}
731-
732732
override def deviance(y: Double, mu: Double, weight: Double): Double = {
733733
2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu))
734734
}
@@ -783,7 +783,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
783783
override def variance(mu: Double): Double = mu
784784

785785
override def deviance(y: Double, mu: Double, weight: Double): Double = {
786-
2.0 * weight * (y * math.log(y / mu) - (y - mu))
786+
2.0 * weight * (ylogy(y, mu) - (y - mu))
787787
}
788788

789789
override def aic(

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,11 +493,20 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
493493
}
494494
[1] -0.0457441 -0.6833928
495495
[1] 1.8121235 -0.1747493 -0.5815417
496+
497+
R code for deviance calculation:
498+
data = cbind(y=c(0,1,0,0,0,1), x1=c(18, 12, 15, 13, 15, 16), x2=c(1,0,0,2,1,1))
499+
summary(glm(y~x1+x2, family=poisson, data=data.frame(data)))$deviance
500+
[1] 3.70055
501+
summary(glm(y~x1+x2-1, family=poisson, data=data.frame(data)))$deviance
502+
[1] 3.809296
496503
*/
497504
val expected = Seq(
498505
Vectors.dense(0.0, -0.0457441, -0.6833928),
499506
Vectors.dense(1.8121235, -0.1747493, -0.5815417))
500507

508+
val residualDeviancesR = Array(3.809296, 3.70055)
509+
501510
import GeneralizedLinearRegression._
502511

503512
var idx = 0
@@ -510,6 +519,7 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
510519
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
511520
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
512521
s"$link link and fitIntercept = $fitIntercept (with zero values).")
522+
assert(model.summary.deviance ~== residualDeviancesR(idx) absTol 1E-3)
513523
idx += 1
514524
}
515525
}

0 commit comments

Comments
 (0)