Skip to content

Commit 4021354

Browse files
committed
[SPARK-30044][ML] MNB/CNB/BNB use empty sigma matrix instead of null
### What changes were proposed in this pull request? MNB/CNB/BNB use empty sigma matrix instead of null ### Why are the changes needed? 1,Using empty sigma matrix will simplify the impl 2,I am reviewing FM impl these days, FMModels have optional bias and linear part. It seems more reasonable to set optional part an empty vector/matrix or zero value than `null` ### Does this PR introduce any user-facing change? yes, sigma from `null` to empty matrix ### How was this patch tested? updated testsuites Closes apache#26679 from zhengruifeng/nb_use_empty_sigma. Authored-by: zhengruifeng <[email protected]> Signed-off-by: zhengruifeng <[email protected]>
1 parent 332e593 commit 4021354

File tree

3 files changed

+21
-29
lines changed

3 files changed

+21
-29
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package org.apache.spark.ml.classification
1919

2020
import org.apache.hadoop.fs.Path
2121
import org.json4s.DefaultFormats
22-
import org.json4s.jackson.JsonMethods._
2322

2423
import org.apache.spark.annotation.Since
2524
import org.apache.spark.ml.PredictorParams
@@ -243,12 +242,12 @@ class NaiveBayes @Since("1.5.0") (
243242
$(modelType) match {
244243
case Multinomial | Bernoulli =>
245244
val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
246-
new NaiveBayesModel(uid, pi.compressed, theta.compressed, null)
245+
new NaiveBayesModel(uid, pi.compressed, theta.compressed, Matrices.zeros(0, 0))
247246
.setOldLabels(labelArray)
248247
case Complement =>
249248
// Since the CNB compute the coefficient in a complement way.
250249
val theta = new DenseMatrix(numLabels, numFeatures, thetaArray.map(v => -v), true)
251-
new NaiveBayesModel(uid, pi.compressed, theta.compressed, null)
250+
new NaiveBayesModel(uid, pi.compressed, theta.compressed, Matrices.zeros(0, 0))
252251
}
253252
}
254253

@@ -575,8 +574,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
575574
private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {
576575
import NaiveBayes._
577576

578-
private case class Data(pi: Vector, theta: Matrix)
579-
private case class GaussianData(pi: Vector, theta: Matrix, sigma: Matrix)
577+
private case class Data(pi: Vector, theta: Matrix, sigma: Matrix)
580578

581579
override protected def saveImpl(path: String): Unit = {
582580
// Save metadata and Params
@@ -585,21 +583,17 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
585583

586584
instance.getModelType match {
587585
case Multinomial | Bernoulli | Complement =>
588-
// Save model data: pi, theta
589-
require(instance.sigma == null)
590-
val data = Data(instance.pi, instance.theta)
591-
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
592-
586+
require(instance.sigma.numRows == 0 && instance.sigma.numCols == 0)
593587
case Gaussian =>
594-
require(instance.sigma != null)
595-
val data = GaussianData(instance.pi, instance.theta, instance.sigma)
596-
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
588+
require(instance.sigma.numRows != 0 && instance.sigma.numCols != 0)
597589
}
590+
591+
val data = Data(instance.pi, instance.theta, instance.sigma)
592+
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
598593
}
599594
}
600595

601596
private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] {
602-
import NaiveBayes._
603597

604598
/** Checked against metadata when loading model */
605599
private val className = classOf[NaiveBayesModel].getName
@@ -608,19 +602,17 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
608602
implicit val format = DefaultFormats
609603
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
610604
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
611-
val modelTypeJson = metadata.getParamValue("modelType")
612-
val modelType = Param.jsonDecode[String](compact(render(modelTypeJson)))
613605

614606
val dataPath = new Path(path, "data").toString
615607
val data = sparkSession.read.parquet(dataPath)
616608
val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
617609

618-
val model = if (major.toInt < 3 || modelType != Gaussian) {
610+
val model = if (major.toInt < 3) {
619611
val Row(pi: Vector, theta: Matrix) =
620612
MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
621613
.select("pi", "theta")
622614
.head()
623-
new NaiveBayesModel(metadata.uid, pi, theta, null)
615+
new NaiveBayesModel(metadata.uid, pi, theta, Matrices.zeros(0, 0))
624616
} else {
625617
val Row(pi: Vector, theta: Matrix, sigma: Matrix) =
626618
MLUtils.convertMatrixColumnsToML(vecConverted, "theta", "sigma")

mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
9696
assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~==
9797
Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch")
9898
assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
99-
if (sigmaData == null) {
100-
assert(model.sigma == null, "sigma mismatch")
99+
if (sigmaData === Matrices.zeros(0, 0)) {
100+
assert(model.sigma === Matrices.zeros(0, 0), "sigma mismatch")
101101
} else {
102102
assert(model.sigma.map(math.exp) ~== sigmaData.map(math.exp) absTol 0.05,
103103
"sigma mismatch")
@@ -166,7 +166,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
166166
ParamsSuite.checkParams(new NaiveBayes)
167167
val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
168168
theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)),
169-
sigma = null)
169+
sigma = Matrices.zeros(0, 0))
170170
ParamsSuite.checkParams(model)
171171
}
172172

@@ -195,7 +195,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
195195
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
196196
val model = nb.fit(testDataset)
197197

198-
validateModelFit(pi, theta, null, model)
198+
validateModelFit(pi, theta, Matrices.zeros(0, 0), model)
199199
assert(model.hasParent)
200200
MLTestingUtils.checkCopyAndUids(nb, model)
201201

@@ -281,7 +281,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
281281
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
282282
val model = nb.fit(testDataset)
283283

284-
validateModelFit(pi, theta, null, model)
284+
validateModelFit(pi, theta, Matrices.zeros(0, 0), model)
285285
assert(model.hasParent)
286286

287287
val validationDataset =
@@ -512,7 +512,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
512512
if (model.getModelType == "gaussian") {
513513
assert(model.sigma === model2.sigma)
514514
} else {
515-
assert(model.sigma === null && model2.sigma === null)
515+
assert(model.sigma === Matrices.zeros(0, 0) && model2.sigma === Matrices.zeros(0, 0))
516516
}
517517
}
518518
val nb = new NaiveBayes()
@@ -531,7 +531,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
531531
nb, spark) { (expected, actual) =>
532532
assert(expected.pi === actual.pi)
533533
assert(expected.theta === actual.theta)
534-
assert(expected.sigma === null && actual.sigma === null)
534+
assert(expected.sigma === Matrices.zeros(0, 0) && actual.sigma === Matrices.zeros(0, 0))
535535
}
536536
}
537537
}

python/pyspark/ml/classification.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1934,8 +1934,8 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
19341934
DenseVector([-0.81..., -0.58...])
19351935
>>> model.theta
19361936
DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)
1937-
>>> model.sigma == None
1938-
True
1937+
>>> model.sigma
1938+
DenseMatrix(0, 0, [...], ...)
19391939
>>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
19401940
>>> model.predict(test0.head().features)
19411941
1.0
@@ -1978,8 +1978,8 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
19781978
'complement'
19791979
>>> model5.theta
19801980
DenseMatrix(2, 2, [...], 1)
1981-
>>> model5.sigma == None
1982-
True
1981+
>>> model5.sigma
1982+
DenseMatrix(0, 0, [...], ...)
19831983
19841984
.. versionadded:: 1.5.0
19851985
"""

0 commit comments

Comments
 (0)