Skip to content

Commit 43785de

Browse files
committed
[SPARK-50925][ML][PYTHON][CONNECT] Support GeneralizedLinearRegression on Connect
### What changes were proposed in this pull request?
Support GeneralizedLinearRegression on Connect.

### Why are the changes needed?
For feature parity.

### Does this PR introduce _any_ user-facing change?
Yes, a new algorithm is supported on Connect.

### How was this patch tested?
Added a test.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49673 from zhengruifeng/ml_connect_glr.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent d0b1b0b commit 43785de

File tree

6 files changed

+137
-1
lines changed

6 files changed

+137
-1
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ org.apache.spark.ml.classification.GBTClassifier
2828

2929
# regression
3030
org.apache.spark.ml.regression.LinearRegression
31+
org.apache.spark.ml.regression.GeneralizedLinearRegression
3132
org.apache.spark.ml.regression.DecisionTreeRegressor
3233
org.apache.spark.ml.regression.RandomForestRegressor
3334
org.apache.spark.ml.regression.GBTRegressor

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ org.apache.spark.ml.classification.GBTClassificationModel
4444

4545
# regression
4646
org.apache.spark.ml.regression.LinearRegressionModel
47+
org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
4748
org.apache.spark.ml.regression.DecisionTreeRegressionModel
4849
org.apache.spark.ml.regression.RandomForestRegressionModel
4950
org.apache.spark.ml.regression.GBTRegressionModel

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,8 @@ class GeneralizedLinearRegressionModel private[ml] (
10091009
with GeneralizedLinearRegressionBase with MLWritable
10101010
with HasTrainingSummary[GeneralizedLinearRegressionTrainingSummary] {
10111011

1012+
private[ml] def this() = this(Identifiable.randomUID("glm"), Vectors.empty, Double.NaN)
1013+
10121014
/**
10131015
* Sets the link prediction (linear predictor) column name.
10141016
*
@@ -1182,7 +1184,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
11821184
@Since("2.0.0")
11831185
class GeneralizedLinearRegressionSummary private[regression] (
11841186
dataset: Dataset[_],
1185-
origModel: GeneralizedLinearRegressionModel) extends Serializable {
1187+
origModel: GeneralizedLinearRegressionModel) extends Summary with Serializable {
11861188

11871189
import GeneralizedLinearRegression._
11881190

python/pyspark/ml/regression.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2795,6 +2795,7 @@ class GeneralizedLinearRegressionSummary(JavaWrapper):
27952795

27962796
@property
27972797
@since("2.0.0")
2798+
@try_remote_attribute_relation
27982799
def predictions(self) -> DataFrame:
27992800
"""
28002801
Predictions output by the model's `transform` method.
@@ -2850,6 +2851,7 @@ def residualDegreeOfFreedomNull(self) -> int:
28502851
"""
28512852
return self._call_java("residualDegreeOfFreedomNull")
28522853

2854+
@try_remote_attribute_relation
28532855
def residuals(self, residualsType: str = "deviance") -> DataFrame:
28542856
"""
28552857
Get the residuals of the fitted model by type.

python/pyspark/ml/tests/test_regression.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
from pyspark.ml.regression import (
2626
LinearRegression,
2727
LinearRegressionModel,
28+
GeneralizedLinearRegression,
29+
GeneralizedLinearRegressionModel,
30+
GeneralizedLinearRegressionSummary,
31+
GeneralizedLinearRegressionTrainingSummary,
2832
LinearRegressionSummary,
2933
LinearRegressionTrainingSummary,
3034
DecisionTreeRegressor,
@@ -163,6 +167,104 @@ def test_linear_regression(self):
163167
model2 = LinearRegressionModel.load(d)
164168
self.assertEqual(str(model), str(model2))
165169

170+
def test_generalized_linear_regression(self):
    """Exercise GeneralizedLinearRegression end to end.

    Covers, in order: estimator param getters/setters, fitting, the fitted
    model's attributes, ``predict``/``transform``, the training summary
    (``model.summary``), the evaluation summary (``model.evaluate``), and a
    save/load round trip of both the estimator and the model.
    """
    spark = self.spark
    # All rows are placed in one partition ordered by "index" before the
    # helper column is dropped -- presumably so the fitted numbers asserted
    # below are reproducible across runs.
    df = (
        spark.createDataFrame(
            [
                (1, 1.0, Vectors.dense(0.0, 0.0)),
                (2, 1.0, Vectors.dense(1.0, 2.0)),
                (3, 2.0, Vectors.dense(0.0, 0.0)),
                (4, 2.0, Vectors.dense(1.0, 1.0)),
            ],
            ["index", "label", "features"],
        )
        .coalesce(1)
        .sortWithinPartitions("index")
        .select("label", "features")
    )

    glr = GeneralizedLinearRegression(
        family="gaussian",
        link="identity",
        linkPredictionCol="p",
    )
    glr.setRegParam(0.1)
    glr.setMaxIter(1)
    self.assertEqual(glr.getFamily(), "gaussian")
    self.assertEqual(glr.getLink(), "identity")
    self.assertEqual(glr.getLinkPredictionCol(), "p")
    self.assertEqual(glr.getRegParam(), 0.1)
    self.assertEqual(glr.getMaxIter(), 1)

    model = glr.fit(df)
    self.assertTrue(np.allclose(model.intercept, 1.543859649122807, atol=1e-4), model.intercept)
    self.assertTrue(
        np.allclose(model.coefficients.toArray(), [0.43859649, -0.35087719], atol=1e-4),
        model.coefficients,
    )
    self.assertEqual(model.numFeatures, 2)

    vec = Vectors.dense(1.0, 2.0)
    pred = model.predict(vec)
    self.assertTrue(np.allclose(pred, 1.280701754385965, atol=1e-4), pred)

    expected_cols = ["label", "features", "p", "prediction"]

    output = model.transform(df)
    self.assertEqual(output.columns, expected_cols)
    self.assertEqual(output.count(), 4)

    # Model summary
    self.assertTrue(model.hasSummary)

    summary = model.summary
    self.assertIsInstance(summary, GeneralizedLinearRegressionSummary)
    self.assertIsInstance(summary, GeneralizedLinearRegressionTrainingSummary)
    self.assertEqual(summary.numIterations, 1)
    self.assertEqual(summary.numInstances, 4)
    self.assertEqual(summary.rank, 3)
    self.assertTrue(
        np.allclose(
            summary.tValues,
            [0.3725037662281711, -0.49418209022924164, 2.6589353685797654],
            atol=1e-4,
        ),
        summary.tValues,
    )
    self.assertTrue(
        np.allclose(
            summary.pValues,
            [0.7729938686180984, 0.707802691825973, 0.22900885781807023],
            atol=1e-4,
        ),
        summary.pValues,
    )
    self.assertEqual(summary.predictions.columns, expected_cols)
    self.assertEqual(summary.predictions.count(), 4)
    self.assertEqual(summary.residuals().columns, ["devianceResiduals"])
    self.assertEqual(summary.residuals().count(), 4)

    # Evaluation summary: a plain summary, not a training summary.
    summary2 = model.evaluate(df)
    self.assertIsInstance(summary2, GeneralizedLinearRegressionSummary)
    self.assertNotIsInstance(summary2, GeneralizedLinearRegressionTrainingSummary)
    self.assertEqual(summary2.numInstances, 4)
    self.assertEqual(summary2.rank, 3)
    # Check the evaluation summary's own predictions; re-checking `summary`
    # here would leave `summary2.predictions` untested.
    self.assertEqual(summary2.predictions.columns, expected_cols)
    self.assertEqual(summary2.predictions.count(), 4)
    self.assertEqual(summary2.residuals().columns, ["devianceResiduals"])
    self.assertEqual(summary2.residuals().count(), 4)

    # Model save & load
    with tempfile.TemporaryDirectory(prefix="generalized_linear_regression") as d:
        glr.write().overwrite().save(d)
        glr2 = GeneralizedLinearRegression.load(d)
        self.assertEqual(str(glr), str(glr2))

        # The same directory is reused; overwrite() replaces the saved estimator.
        model.write().overwrite().save(d)
        model2 = GeneralizedLinearRegressionModel.load(d)
        self.assertEqual(str(model), str(model2))
267+
166268
def test_decision_tree_regressor(self):
167269
df = self.df
168270

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,34 @@ private[ml] object MLUtils {
532532
(classOf[BinaryLogisticRegressionSummary], Set("scoreCol")),
533533

534534
// Regression Models
535+
(
536+
classOf[GeneralizedLinearRegressionModel],
537+
Set("intercept", "coefficients", "numFeatures", "evaluate")),
538+
(
539+
classOf[GeneralizedLinearRegressionSummary],
540+
Set(
541+
"aic",
542+
"degreesOfFreedom",
543+
"deviance",
544+
"dispersion",
545+
"nullDeviance",
546+
"numInstances",
547+
"predictionCol",
548+
"predictions",
549+
"rank",
550+
"residualDegreeOfFreedom",
551+
"residualDegreeOfFreedomNull",
552+
"residuals")),
553+
(
554+
classOf[GeneralizedLinearRegressionTrainingSummary],
555+
Set(
556+
"numIterations",
557+
"solver",
558+
"tValues",
559+
"pValues",
560+
"coefficientStandardErrors",
561+
"coefficientsWithStatistics",
562+
"toString")),
535563
(classOf[LinearRegressionModel], Set("intercept", "coefficients", "scale", "evaluate")),
536564
(
537565
classOf[LinearRegressionSummary],

0 commit comments

Comments
 (0)