Skip to content

Commit 0e4242d

Browse files
committed
[SPARK-50923][SPARK-50927][ML][PYTHON][CONNECT] Support FMClassifier and FMRegressor on Connect
### What changes were proposed in this pull request?
Support FMClassifier and FMRegressor on Connect.

### Why are the changes needed?
For parity.

### Does this PR introduce _any_ user-facing change?
Yes.

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49685 from zhengruifeng/ml_connect_fm.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 9d0e888 commit 0e4242d

File tree

7 files changed

+179
-0
lines changed

7 files changed

+179
-0
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
org.apache.spark.ml.classification.NaiveBayes
2323
org.apache.spark.ml.classification.LinearSVC
2424
org.apache.spark.ml.classification.LogisticRegression
25+
org.apache.spark.ml.classification.FMClassifier
2526
org.apache.spark.ml.classification.MultilayerPerceptronClassifier
2627
org.apache.spark.ml.classification.DecisionTreeClassifier
2728
org.apache.spark.ml.classification.RandomForestClassifier
@@ -32,6 +33,7 @@ org.apache.spark.ml.regression.AFTSurvivalRegression
3233
org.apache.spark.ml.regression.IsotonicRegression
3334
org.apache.spark.ml.regression.LinearRegression
3435
org.apache.spark.ml.regression.GeneralizedLinearRegression
36+
org.apache.spark.ml.regression.FMRegressor
3537
org.apache.spark.ml.regression.DecisionTreeRegressor
3638
org.apache.spark.ml.regression.RandomForestRegressor
3739
org.apache.spark.ml.regression.GBTRegressor

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ org.apache.spark.ml.feature.HashingTF
3838
org.apache.spark.ml.classification.NaiveBayesModel
3939
org.apache.spark.ml.classification.LinearSVCModel
4040
org.apache.spark.ml.classification.LogisticRegressionModel
41+
org.apache.spark.ml.classification.FMClassificationModel
4142
org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel
4243
org.apache.spark.ml.classification.DecisionTreeClassificationModel
4344
org.apache.spark.ml.classification.RandomForestClassificationModel
@@ -48,6 +49,7 @@ org.apache.spark.ml.regression.AFTSurvivalRegressionModel
4849
org.apache.spark.ml.regression.IsotonicRegressionModel
4950
org.apache.spark.ml.regression.LinearRegressionModel
5051
org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
52+
org.apache.spark.ml.regression.FMRegressionModel
5153
org.apache.spark.ml.regression.DecisionTreeRegressionModel
5254
org.apache.spark.ml.regression.RandomForestRegressionModel
5355
org.apache.spark.ml.regression.GBTRegressionModel

mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,9 @@ class FMClassificationModel private[classification] (
259259
with FMClassifierParams with MLWritable
260260
with HasTrainingSummary[FMClassificationTrainingSummary]{
261261

262+
private[ml] def this() = this(Identifiable.randomUID("fmc"),
263+
Double.NaN, Vectors.empty, Matrices.empty)
264+
262265
@Since("3.0.0")
263266
override val numClasses: Int = 2
264267

mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,9 @@ class FMRegressionModel private[regression] (
461461
extends RegressionModel[Vector, FMRegressionModel]
462462
with FMRegressorParams with MLWritable {
463463

464+
private[ml] def this() = this(Identifiable.randomUID("fmr"),
465+
Double.NaN, Vectors.empty, Matrices.empty)
466+
464467
@Since("3.0.0")
465468
override val numFeatures: Int = linear.size
466469

python/pyspark/ml/tests/test_classification.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
LogisticRegressionModel,
3535
LogisticRegressionSummary,
3636
BinaryLogisticRegressionSummary,
37+
FMClassifier,
38+
FMClassificationModel,
39+
FMClassificationSummary,
40+
FMClassificationTrainingSummary,
3741
DecisionTreeClassifier,
3842
DecisionTreeClassificationModel,
3943
RandomForestClassifier,
@@ -447,6 +451,104 @@ def test_linear_svc(self):
447451
model2 = LinearSVCModel.load(d)
448452
self.assertEqual(str(model), str(model2))
449453

454+
def test_factorization_machine(self):
455+
spark = self.spark
456+
df = (
457+
spark.createDataFrame(
458+
[
459+
(1.0, 1.0, Vectors.dense(0.0, 5.0)),
460+
(0.0, 2.0, Vectors.dense(1.0, 2.0)),
461+
(1.0, 3.0, Vectors.dense(2.0, 1.0)),
462+
(0.0, 4.0, Vectors.dense(3.0, 3.0)),
463+
],
464+
["label", "weight", "features"],
465+
)
466+
.coalesce(1)
467+
.sortWithinPartitions("weight")
468+
)
469+
470+
fm = FMClassifier(factorSize=2, maxIter=1, regParam=1.0, seed=1)
471+
self.assertEqual(fm.getFactorSize(), 2)
472+
self.assertEqual(fm.getMaxIter(), 1)
473+
self.assertEqual(fm.getRegParam(), 1.0)
474+
self.assertEqual(fm.getSeed(), 1)
475+
476+
model = fm.fit(df)
477+
self.assertEqual(fm.uid, model.uid)
478+
self.assertEqual(model.numClasses, 2)
479+
self.assertEqual(model.numFeatures, 2)
480+
self.assertTrue(
481+
np.allclose(model.intercept, 0.9999070647126924, atol=1e-4), model.intercept
482+
)
483+
self.assertTrue(
484+
np.allclose(
485+
model.linear.toArray(), [-0.999999959956255, 0.9999999201744205], atol=1e-4
486+
),
487+
model.linear,
488+
)
489+
self.assertTrue(
490+
np.allclose(
491+
model.factors.toArray(),
492+
[[0.99999918, 0.99999858], [-0.99999943, 0.99999854]],
493+
atol=1e-4,
494+
),
495+
model.factors,
496+
)
497+
498+
vec = Vectors.dense(0.0, 5.0)
499+
pred = model.predict(vec)
500+
self.assertEqual(pred, 1.0)
501+
pred = model.predictRaw(vec)
502+
self.assertTrue(
503+
np.allclose(pred.toArray(), [-5.9999066655847955, 5.9999066655847955], atol=1e-4),
504+
pred,
505+
)
506+
pred = model.predictProbability(vec)
507+
self.assertTrue(
508+
np.allclose(pred.toArray(), [0.002472853377527451, 0.9975271466224725], atol=1e-4),
509+
pred,
510+
)
511+
512+
output = model.transform(df)
513+
expected_cols = [
514+
"label",
515+
"weight",
516+
"features",
517+
"rawPrediction",
518+
"probability",
519+
"prediction",
520+
]
521+
self.assertEqual(output.columns, expected_cols)
522+
self.assertEqual(output.count(), 4)
523+
524+
# model summary
525+
self.assertTrue(model.hasSummary)
526+
summary = model.summary()
527+
self.assertIsInstance(summary, FMClassificationSummary)
528+
self.assertIsInstance(summary, FMClassificationTrainingSummary)
529+
self.assertEqual(summary.labels, [0.0, 1.0])
530+
self.assertEqual(summary.accuracy, 0.25)
531+
self.assertEqual(summary.areaUnderROC, 0.5)
532+
self.assertEqual(summary.predictions.columns, expected_cols)
533+
534+
summary2 = model.evaluate(df)
535+
self.assertIsInstance(summary2, FMClassificationSummary)
536+
self.assertFalse(isinstance(summary2, FMClassificationTrainingSummary))
537+
self.assertEqual(summary2.labels, [0.0, 1.0])
538+
self.assertEqual(summary2.accuracy, 0.25)
539+
self.assertEqual(summary2.areaUnderROC, 0.5)
540+
self.assertEqual(summary2.predictions.columns, expected_cols)
541+
542+
# Model save & load
543+
with tempfile.TemporaryDirectory(prefix="factorization_machine") as d:
544+
fm.write().overwrite().save(d)
545+
fm2 = FMClassifier.load(d)
546+
self.assertEqual(str(fm), str(fm2))
547+
548+
model.write().overwrite().save(d)
549+
model2 = FMClassificationModel.load(d)
550+
self.assertEqual(str(model), str(model2))
551+
450552
def test_decision_tree_classifier(self):
451553
df = (
452554
self.spark.createDataFrame(

python/pyspark/ml/tests/test_regression.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
GeneralizedLinearRegressionTrainingSummary,
3636
LinearRegressionSummary,
3737
LinearRegressionTrainingSummary,
38+
FMRegressor,
39+
FMRegressionModel,
3840
DecisionTreeRegressor,
3941
DecisionTreeRegressionModel,
4042
RandomForestRegressor,
@@ -368,6 +370,69 @@ def test_generalized_linear_regression(self):
368370
model2 = GeneralizedLinearRegressionModel.load(d)
369371
self.assertEqual(str(model), str(model2))
370372

373+
def test_factorization_machine(self):
374+
spark = self.spark
375+
df = (
376+
spark.createDataFrame(
377+
[
378+
(1, 1.0, Vectors.dense(0.0, 0.0)),
379+
(2, 1.0, Vectors.dense(1.0, 2.0)),
380+
(3, 2.0, Vectors.dense(0.0, 0.0)),
381+
(4, 2.0, Vectors.dense(1.0, 1.0)),
382+
],
383+
["index", "label", "features"],
384+
)
385+
.coalesce(1)
386+
.sortWithinPartitions("index")
387+
.select("label", "features")
388+
)
389+
390+
fm = FMRegressor(factorSize=2, maxIter=1, regParam=1.0, seed=1)
391+
self.assertEqual(fm.getFactorSize(), 2)
392+
self.assertEqual(fm.getMaxIter(), 1)
393+
self.assertEqual(fm.getRegParam(), 1.0)
394+
self.assertEqual(fm.getSeed(), 1)
395+
396+
model = fm.fit(df)
397+
self.assertEqual(fm.uid, model.uid)
398+
self.assertEqual(model.numFeatures, 2)
399+
self.assertTrue(
400+
np.allclose(model.intercept, 0.9999999966668874, atol=1e-4), model.intercept
401+
)
402+
self.assertTrue(
403+
np.allclose(
404+
model.linear.toArray(), [0.9999999933342161, 0.9999999950008276], atol=1e-4
405+
),
406+
model.linear,
407+
)
408+
self.assertTrue(
409+
np.allclose(
410+
model.factors.toArray(),
411+
[[-0.99999954, -0.9999992], [0.99999968, -0.99999918]],
412+
atol=1e-4,
413+
),
414+
model.factors,
415+
)
416+
417+
vec = Vectors.dense(0.0, 5.0)
418+
pred = model.predict(vec)
419+
self.assertTrue(np.allclose(pred, 5.999999971671025, atol=1e-4), pred)
420+
421+
output = model.transform(df)
422+
expected_cols = ["label", "features", "prediction"]
423+
self.assertEqual(output.columns, expected_cols)
424+
self.assertEqual(output.count(), 4)
425+
426+
# Model save & load
427+
with tempfile.TemporaryDirectory(prefix="factorization_machine") as d:
428+
fm.write().overwrite().save(d)
429+
fm2 = FMRegressor.load(d)
430+
self.assertEqual(str(fm), str(fm2))
431+
432+
model.write().overwrite().save(d)
433+
model2 = FMRegressionModel.load(d)
434+
self.assertEqual(str(model), str(model2))
435+
371436
def test_decision_tree_regressor(self):
372437
df = self.df
373438

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ private[ml] object MLUtils {
530530
Set("intercept", "coefficients", "interceptVector", "coefficientMatrix", "evaluate")),
531531
(classOf[LogisticRegressionSummary], Set("probabilityCol", "featuresCol")),
532532
(classOf[BinaryLogisticRegressionSummary], Set("scoreCol")),
533+
(classOf[FMClassificationModel], Set("intercept", "linear", "factors", "evaluate")),
533534
(classOf[MultilayerPerceptronClassificationModel], Set("weights", "evaluate")),
534535

535536
// Regression Models
@@ -589,6 +590,7 @@ private[ml] object MLUtils {
589590
"tValues",
590591
"pValues")),
591592
(classOf[LinearRegressionTrainingSummary], Set("objectiveHistory", "totalIterations")),
593+
(classOf[FMRegressionModel], Set("intercept", "linear", "factors")),
592594

593595
// Clustering Models
594596
(classOf[KMeansModel], Set("predict", "numFeatures", "clusterCenterMatrix")),

0 commit comments

Comments
 (0)