Skip to content

Commit e0437e0

Browse files
committed
[SPARK-50920][ML][PYTHON][CONNECT] Support NaiveBayes on Connect
### What changes were proposed in this pull request? Support NaiveBayes on Connect ### Why are the changes needed? feature parity ### Does this PR introduce _any_ user-facing change? yes, new algorithm supported on connect ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #49672 from zhengruifeng/ml_connect_nb. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent acdda8a commit e0437e0

File tree

6 files changed

+71
-1
lines changed

6 files changed

+71
-1
lines changed

mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,8 @@ object SparseMatrix {
10481048
@Since("2.0.0")
10491049
object Matrices {
10501050

1051+
private[ml] val empty = new DenseMatrix(0, 0, Array.emptyDoubleArray)
1052+
10511053
private[ml] def fromVectors(vectors: Seq[Vector]): Matrix = {
10521054
val numRows = vectors.length
10531055
val numCols = vectors.head.size

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# So register the supported estimator here if you're trying to add a new one.
2020

2121
# classification
22+
org.apache.spark.ml.classification.NaiveBayes
2223
org.apache.spark.ml.classification.LinearSVC
2324
org.apache.spark.ml.classification.LogisticRegression
2425
org.apache.spark.ml.classification.DecisionTreeClassifier

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ org.apache.spark.ml.feature.HashingTF
3535

3636
########### Model for loading
3737
# classification
38+
org.apache.spark.ml.classification.NaiveBayesModel
3839
org.apache.spark.ml.classification.LinearSVCModel
3940
org.apache.spark.ml.classification.LogisticRegressionModel
4041
org.apache.spark.ml.classification.DecisionTreeClassificationModel

mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,9 @@ class NaiveBayesModel private[ml] (
401401

402402
import NaiveBayes._
403403

404+
private[ml] def this() = this(Identifiable.randomUID("nb"),
405+
Vectors.empty, Matrices.empty, Matrices.empty)
406+
404407
/**
405408
* mllib NaiveBayes is a wrapper of ml implementation currently.
406409
* Input labels of mllib could be {-1, +1} and mllib NaiveBayesModel exposes labels,

python/pyspark/ml/tests/test_classification.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import numpy as np
2323

2424
from pyspark.ml.linalg import Vectors, Matrices
25-
from pyspark.sql import SparkSession, DataFrame
25+
from pyspark.sql import SparkSession, DataFrame, Row
2626
from pyspark.ml.classification import (
27+
NaiveBayes,
28+
NaiveBayesModel,
2729
LinearSVC,
2830
LinearSVCModel,
2931
LinearSVCSummary,
@@ -46,6 +48,66 @@
4648

4749

4850
class ClassificationTestsMixin:
51+
def test_naive_bayes(self):
52+
spark = self.spark
53+
df = spark.createDataFrame(
54+
[
55+
Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
56+
Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
57+
Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
58+
]
59+
)
60+
61+
nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
62+
self.assertEqual(nb.getSmoothing(), 1.0)
63+
self.assertEqual(nb.getModelType(), "multinomial")
64+
self.assertEqual(nb.getWeightCol(), "weight")
65+
66+
model = nb.fit(df)
67+
self.assertEqual(model.numClasses, 2)
68+
self.assertEqual(model.numFeatures, 2)
69+
self.assertTrue(
70+
np.allclose(model.pi.toArray(), [-0.81093022, -0.58778666], atol=1e-4), model.pi
71+
)
72+
self.assertTrue(
73+
np.allclose(
74+
model.theta.toArray(),
75+
[[-0.91629073, -0.51082562], [-0.40546511, -1.09861229]],
76+
atol=1e-4,
77+
),
78+
model.theta,
79+
)
80+
self.assertTrue(np.allclose(model.sigma.toArray(), [], atol=1e-4), model.sigma)
81+
82+
vec = Vectors.dense(0.0, 5.0)
83+
self.assertEqual(model.predict(vec), 0.0)
84+
pred = model.predictRaw(vec)
85+
self.assertTrue(np.allclose(pred.toArray(), [-3.36505834, -6.08084811], atol=1e-4), pred)
86+
pred = model.predictProbability(vec)
87+
self.assertTrue(np.allclose(pred.toArray(), [0.93795196, 0.06204804], atol=1e-4), pred)
88+
89+
output = model.transform(df)
90+
expected_cols = [
91+
"label",
92+
"weight",
93+
"features",
94+
"rawPrediction",
95+
"probability",
96+
"prediction",
97+
]
98+
self.assertEqual(output.columns, expected_cols)
99+
self.assertEqual(output.count(), 3)
100+
101+
# Model save & load
102+
with tempfile.TemporaryDirectory(prefix="naive_bayes") as d:
103+
nb.write().overwrite().save(d)
104+
nb2 = NaiveBayes.load(d)
105+
self.assertEqual(str(nb), str(nb2))
106+
107+
model.write().overwrite().save(d)
108+
model2 = NaiveBayesModel.load(d)
109+
self.assertEqual(str(model), str(model2))
110+
49111
def test_binomial_logistic_regression_with_bound(self):
50112
df = self.spark.createDataFrame(
51113
[

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ private[ml] object MLUtils {
523523
(classOf[GBTRegressionModel], Set("featureImportances", "evaluateEachIteration")),
524524

525525
// Classification Models
526+
(classOf[NaiveBayesModel], Set("pi", "theta", "sigma")),
526527
(classOf[LinearSVCModel], Set("intercept", "coefficients", "evaluate")),
527528
(
528529
classOf[LogisticRegressionModel],

0 commit comments

Comments
 (0)