From e79db2ce60ae3a8dd1a7e55bfbccb25a8288967d Mon Sep 17 00:00:00 2001 From: anandexplore <131127991+anandexplore@users.noreply.github.com> Date: Sun, 5 Oct 2025 03:46:11 -0500 Subject: [PATCH] [SPARK-53803][ML][Feature] Add ArimaRegression for time series forecasting in MLlib Add ArimaRegression for time series forecasting in MLlib --- .../examples/ml/ArimaRegressionExample.scala | 38 ++++++++++ .../spark/ml/regression/ArimaParams.scala | 31 ++++++++ .../spark/ml/regression/ArimaRegression.scala | 58 +++++++++++++++ .../ml/regression/ArimaRegressionModel.scala | 50 +++++++++++++ .../ml/regression/ArimaRegressionSuite.scala | 38 ++++++++++ python/docs/source/reference/pyspark.ml.rst | 2 + python/pyspark/ml/regression.py | 70 +++++++++++++++++++ python/pyspark/ml/tests/test_regression.py | 50 +++++++++++++ 8 files changed, 337 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ArimaRegressionExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ArimaRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ArimaRegressionExample.scala new file mode 100644 index 0000000000000..24952bc647c30 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ArimaRegressionExample.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.ml.regression.ArimaRegression +import org.apache.spark.sql.SparkSession + +object ArimaRegressionExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder.appName("ARIMA Example").getOrCreate() + import spark.implicits._ + + val tsData = Seq(1.2, 2.3, 3.1, 4.0, 5.5).toDF("y") + + val arima = new ArimaRegression().setP(1).setD(0).setQ(1) + val model = arima.fit(tsData) + + val result = model.transform(tsData) + result.show() + + spark.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala new file mode 100644 index 0000000000000..05aea9b40b712 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaParams.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.regression + +import org.apache.spark.ml.param._ + +private[regression] trait ArimaParams extends Params { + final val p = new IntParam(this, "p", "AR order") + final val d = new IntParam(this, "d", "Differencing order") + final val q = new IntParam(this, "q", "MA order") + + setDefault(p -> 1, d -> 0, q -> 1) + + def getP: Int = $(p) + def getD: Int = $(d) + def getQ: Int = $(q) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala new file mode 100644 index 0000000000000..d65816c229932 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegression.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.regression + +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.Model +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.DefaultParamsWritable +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class ArimaRegression(override val uid: String) + extends Estimator[ArimaRegressionModel] + with ArimaParams + with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("arimaReg")) + + def setP(value: Int): this.type = set(p, value) + def setD(value: Int): this.type = set(d, value) + def setQ(value: Int): this.type = set(q, value) + + override def fit(dataset: Dataset[_]): ArimaRegressionModel = { + // Dummy: assumes data is ordered with one feature column "y" + val ts = dataset.select("y").rdd.map(_.getDouble(0)).collect() + + // [TO DO]: Replace with actual ARIMA fitting logic + val model = new ArimaRegressionModel(uid) + .setParent(this) + model + } + + override def copy(extra: ParamMap): ArimaRegression = defaultCopy(extra) + + override def transformSchema(schema: StructType): StructType = { + require(schema.fieldNames.contains("y"), "Dataset must contain 'y' column.") + schema.add(StructField("prediction", DoubleType, false)) + } +} + +object ArimaRegression extends DefaultParamsReadable[ArimaRegression] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala new file mode 100644 index 0000000000000..cf7b42403f068 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/ArimaRegressionModel.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.ml._ +import org.apache.spark.ml.util._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class ArimaRegressionModel(override val uid: String) + extends Model[ArimaRegressionModel] + with ArimaParams + with MLWritable { + + override def copy(extra: ParamMap): ArimaRegressionModel = { + val copied = new ArimaRegressionModel(uid) + copyValues(copied, extra).setParent(parent) + } + + override def transform(dataset: Dataset[_]): DataFrame = { + // Dummy prediction logic — just copy y as prediction + dataset.withColumn("prediction", col("y")) + } + + override def transformSchema(schema: StructType): StructType = { + schema.add(StructField("prediction", DoubleType, false)) + } +} + +object ArimaRegressionModel extends MLReadable[ArimaRegressionModel] { + override def read: MLReader[ArimaRegressionModel] = new DefaultParamsReader[ArimaRegressionModel] + override def load(path: String): ArimaRegressionModel = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala new file mode 100644 index 0000000000000..191fd3eb2e867 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/ArimaRegressionSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.sql.DataFrame + +class ArimaRegressionSuite extends SparkFunSuite { + + test("basic model fit and transform") { + val spark = sparkSession + import spark.implicits._ + + val df = Seq(1.0, 2.0, 3.0, 4.0).toDF("y") + val arima = new ArimaRegression().setP(1).setD(0).setQ(1) + val model = arima.fit(df) + + val transformed = model.transform(df) + assert(transformed.columns.contains("prediction")) + assert(transformed.count() == df.count()) + } +} diff --git a/python/docs/source/reference/pyspark.ml.rst b/python/docs/source/reference/pyspark.ml.rst index 1dfb63aa1dbd1..9d05d50d61137 100644 --- a/python/docs/source/reference/pyspark.ml.rst +++ b/python/docs/source/reference/pyspark.ml.rst @@ -265,6 +265,8 @@ Regression RandomForestRegressionModel FMRegressor FMRegressionModel + ArimaRegression + ArimaRegressionModel Statistics diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index ce97b98f6665c..dbf0be7b00656 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -103,6 +103,8 @@ "RandomForestRegressionModel", "FMRegressor", "FMRegressionModel", + "ArimaRegression", + "ArimaRegressionModel" ] @@ -146,6 +148,13 @@ class _JavaRegressionModel(RegressionModel, JavaPredictionModel[T], metaclass=AB pass +class ArimaRegressionModel(JavaModel): + """ + Model fitted by :py:class:`ArimaRegression`. + + This model supports `.transform()` and optional `.predict()`. + """ + pass class _LinearRegressionParams( _PredictorParams, @@ -208,6 +217,67 @@ def getEpsilon(self) -> float: return self.getOrDefault(self.epsilon) +@inherit_doc +class ArimaRegression(JavaEstimator): + """ + ArimaRegression(p=1, d=0, q=1) + + ARIMA time series regression model. + + Parameters + ---------- + p : int + Autoregressive order. + d : int + Differencing order. + q : int + Moving average order. + + Notes + ----- + Requires a column named "y" as the input time series column. + """ + + p = Param(Params._dummy(), "p", "Autoregressive order.") + d = Param(Params._dummy(), "d", "Differencing order.") + q = Param(Params._dummy(), "q", "Moving average order.") + + def __init__(self, p=1, d=0, q=1): + super(ArimaRegression, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.ArimaRegression", self.uid) + self._setDefault(p=1, d=0, q=1) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + def setParams(self, p=1, d=0, q=1): + """ + Set parameters for ArimaRegression. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + def setP(self, value): + return self._set(p=value) + + def getP(self): + return self.getOrDefault(self.p) + + def setD(self, value): + return self._set(d=value) + + def getD(self): + return self.getOrDefault(self.d) + + def setQ(self, value): + return self._set(q=value) + + def getQ(self): + return self.getOrDefault(self.q) + + def _create_model(self, java_model): + return ArimaRegressionModel(java_model) + + @inherit_doc class LinearRegression( _JavaRegressor["LinearRegressionModel"], diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py index 52688fdd63cf2..fd216a0e6cd2e 100644 --- a/python/pyspark/ml/tests/test_regression.py +++ b/python/pyspark/ml/tests/test_regression.py @@ -696,6 +696,56 @@ def test_random_forest_regressor(self): self.assertEqual(model.toDebugString, model2.toDebugString) +def test_arima_regression(self): + import numpy as np + import tempfile + from pyspark.ml.linalg import Vectors + from pyspark.ml.regression import ArimaRegression, ArimaRegressionModel + + spark = self.spark + + # Time series data in a single column named "y" + df = spark.createDataFrame( + [(1.2,), (2.3,), (3.1,), (4.0,), (5.5,)], + ["y"] + ) + + arima = ArimaRegression( + p=1, + d=0, + q=1, + ) + + self.assertEqual(arima.getP(), 1) + self.assertEqual(arima.getD(), 0) + self.assertEqual(arima.getQ(), 1) + + model = arima.fit(df) + self.assertEqual(model.uid, arima.uid) + + output = model.transform(df) + expected_cols = ["y", "prediction"] + self.assertEqual(output.columns, expected_cols) + self.assertEqual(output.count(), 5) + + # Predict a single value if API supports it + if hasattr(model, "predict"): + pred = model.predict(3.0) + self.assertIsInstance(pred, float) + + # Model save/load + with tempfile.TemporaryDirectory(prefix="arima_regression") as d: + arima_path = d + "/arima" + model_path = d + "/arima_model" + + arima.write().overwrite().save(arima_path) + loaded_arima = ArimaRegression.load(arima_path) + self.assertEqual(str(arima), str(loaded_arima)) + + model.write().overwrite().save(model_path) + loaded_model = ArimaRegressionModel.load(model_path) + self.assertEqual(str(model), str(loaded_model)) + class RegressionTests(RegressionTestsMixin, ReusedSQLTestCase): pass