Skip to content

Commit 40b8dfa

Browse files
committed
[SPARK-50924][SPARK-50926][ML][PYTHON][CONNECT] Support AFTSurvivalRegression and IsotonicRegression on Connect
### What changes were proposed in this pull request?
Support AFTSurvivalRegression and IsotonicRegression on Connect

### Why are the changes needed?
feature parity

### Does this PR introduce _any_ user-facing change?
yes

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49687 from zhengruifeng/ml_connect_aft.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
(cherry picked from commit 3ba76bf)
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 343269e commit 40b8dfa

File tree

6 files changed

+117
-0
lines changed

6 files changed

+117
-0
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ org.apache.spark.ml.classification.RandomForestClassifier
2828
org.apache.spark.ml.classification.GBTClassifier
2929

3030
# regression
31+
org.apache.spark.ml.regression.AFTSurvivalRegression
32+
org.apache.spark.ml.regression.IsotonicRegression
3133
org.apache.spark.ml.regression.LinearRegression
3234
org.apache.spark.ml.regression.GeneralizedLinearRegression
3335
org.apache.spark.ml.regression.DecisionTreeRegressor

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ org.apache.spark.ml.classification.RandomForestClassificationModel
4444
org.apache.spark.ml.classification.GBTClassificationModel
4545

4646
# regression
47+
org.apache.spark.ml.regression.AFTSurvivalRegressionModel
48+
org.apache.spark.ml.regression.IsotonicRegressionModel
4749
org.apache.spark.ml.regression.LinearRegressionModel
4850
org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
4951
org.apache.spark.ml.regression.DecisionTreeRegressionModel

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,9 @@ class AFTSurvivalRegressionModel private[ml] (
371371
extends RegressionModel[Vector, AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams
372372
with MLWritable {
373373

374+
private[ml] def this() = this(Identifiable.randomUID("aftSurvReg"),
375+
Vectors.empty, Double.NaN, Double.NaN)
376+
374377
@Since("3.0.0")
375378
override def numFeatures: Int = coefficients.size
376379

mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ class IsotonicRegressionModel private[ml] (
213213
private val oldModel: MLlibIsotonicRegressionModel)
214214
extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable {
215215

216+
private[ml] def this() = this(Identifiable.randomUID("isoReg"), null)
217+
216218
/** @group setParam */
217219
@Since("1.5.0")
218220
def setFeaturesCol(value: String): this.type = set(featuresCol, value)

python/pyspark/ml/tests/test_regression.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
from pyspark.ml.linalg import Vectors
2424
from pyspark.sql import SparkSession
2525
from pyspark.ml.regression import (
26+
AFTSurvivalRegression,
27+
AFTSurvivalRegressionModel,
28+
IsotonicRegression,
29+
IsotonicRegressionModel,
2630
LinearRegression,
2731
LinearRegressionModel,
2832
GeneralizedLinearRegression,
@@ -57,6 +61,104 @@ def df(self):
5761
.sortWithinPartitions("weight")
5862
)
5963

64+
def test_aft_survival(self):
65+
spark = self.spark
66+
df = spark.createDataFrame(
67+
[(1.0, Vectors.dense(1.0), 1.0), (1e-40, Vectors.sparse(1, [], []), 0.0)],
68+
["label", "features", "censor"],
69+
)
70+
71+
aft = AFTSurvivalRegression()
72+
aft.setMaxIter(1)
73+
self.assertEqual(aft.getMaxIter(), 1)
74+
75+
model = aft.fit(df)
76+
self.assertEqual(aft.uid, model.uid)
77+
self.assertEqual(model.numFeatures, 1)
78+
self.assertTrue(np.allclose(model.intercept, 0.0, atol=1e-4), model.intercept)
79+
self.assertTrue(
80+
np.allclose(model.coefficients.toArray(), [0.0], atol=1e-4), model.coefficients
81+
)
82+
self.assertTrue(np.allclose(model.scale, 1.0, atol=1e-4), model.scale)
83+
84+
vec = Vectors.dense(6.3)
85+
pred = model.predict(vec)
86+
self.assertEqual(pred, 1.0)
87+
pred = model.predictQuantiles(vec)
88+
self.assertTrue(
89+
np.allclose(
90+
pred,
91+
[
92+
0.010050335853501444,
93+
0.051293294387550536,
94+
0.1053605156578263,
95+
0.2876820724517809,
96+
0.6931471805599453,
97+
1.3862943611198906,
98+
2.302585092994046,
99+
2.9957322735539895,
100+
4.60517018598809,
101+
],
102+
atol=1e-4,
103+
),
104+
pred,
105+
)
106+
107+
output = model.transform(df)
108+
expected_cols = ["label", "features", "censor", "prediction"]
109+
self.assertEqual(output.columns, expected_cols)
110+
self.assertEqual(output.count(), 2)
111+
112+
# Model save & load
113+
with tempfile.TemporaryDirectory(prefix="aft_survival") as d:
114+
aft.write().overwrite().save(d)
115+
aft2 = AFTSurvivalRegression.load(d)
116+
self.assertEqual(str(aft), str(aft2))
117+
118+
model.write().overwrite().save(d)
119+
model2 = AFTSurvivalRegressionModel.load(d)
120+
self.assertEqual(str(model), str(model2))
121+
122+
def test_isotonic_regression(self):
123+
spark = self.spark
124+
df = spark.createDataFrame(
125+
[(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))], ["label", "features"]
126+
)
127+
128+
ir = IsotonicRegression(
129+
isotonic=True,
130+
featureIndex=0,
131+
)
132+
self.assertTrue(ir.getIsotonic())
133+
self.assertEqual(ir.getFeatureIndex(), 0)
134+
135+
model = ir.fit(df)
136+
self.assertEqual(model.numFeatures, 1)
137+
self.assertTrue(
138+
np.allclose(model.boundaries.toArray(), [0.0, 1.0], atol=1e-4), model.boundaries
139+
)
140+
self.assertTrue(
141+
np.allclose(model.predictions.toArray(), [0.0, 1.0], atol=1e-4), model.predictions
142+
)
143+
144+
pred = model.predict(1.0)
145+
self.assertTrue(np.allclose(pred, 1.0, atol=1e-4), pred)
146+
147+
output = model.transform(df)
148+
expected_cols = ["label", "features", "prediction"]
149+
self.assertEqual(output.columns, expected_cols)
150+
self.assertEqual(output.count(), 2)
151+
152+
# Model save & load
153+
with tempfile.TemporaryDirectory(prefix="isotonic_regression") as d:
154+
ir.write().overwrite().save(d)
155+
ir2 = IsotonicRegression.load(d)
156+
self.assertEqual(str(ir), str(ir2))
157+
158+
model.write().overwrite().save(d)
159+
model2 = IsotonicRegressionModel.load(d)
160+
self.assertEqual(str(model), str(model2))
161+
60162
def test_linear_regression(self):
61163
df = self.df
62164
lr = LinearRegression(

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,12 @@ private[ml] object MLUtils {
533533
(classOf[MultilayerPerceptronClassificationModel], Set("weights", "evaluate")),
534534

535535
// Regression Models
536+
(
537+
classOf[AFTSurvivalRegressionModel],
538+
Set("intercept", "coefficients", "scale", "predictQuantiles")),
539+
(
540+
classOf[IsotonicRegressionModel],
541+
Set("boundaries", "predictions", "numFeatures", "predict")),
536542
(
537543
classOf[GeneralizedLinearRegressionModel],
538544
Set("intercept", "coefficients", "numFeatures", "evaluate")),

0 commit comments

Comments
 (0)