Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml

import org.apache.spark.ml.regression.ArimaRegression
import org.apache.spark.sql.SparkSession

/**
 * Minimal end-to-end usage of [[ArimaRegression]]: builds a five-point series in a
 * single column named "y", fits the estimator on it and prints the transformed output.
 */
object ArimaRegressionExample {

  def main(args: Array[String]): Unit = {
    val session = SparkSession
      .builder
      .appName("ARIMA Example")
      .getOrCreate()
    import session.implicits._

    // The estimator reads its input from a column called "y".
    val observations = Seq(1.2, 2.3, 3.1, 4.0, 5.5)
    val series = observations.toDF("y")

    val estimator = new ArimaRegression()
      .setP(1)
      .setD(0)
      .setQ(1)

    // Fit on the series, score the same series and print the result.
    estimator.fit(series).transform(series).show()

    session.stop()
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.regression

import org.apache.spark.ml.param._

/**
 * Params shared by the ARIMA estimator and its model: the three integer orders
 * p (autoregressive), d (differencing) and q (moving average).
 *
 * Every order is validated to be non-negative when it is set, so an invalid value is
 * rejected with an `IllegalArgumentException` at configuration time.
 */
private[regression] trait ArimaParams extends Params {

  /**
   * Autoregressive (AR) order. Must be >= 0.
   * Default: 1
   */
  final val p = new IntParam(this, "p", "AR order", ParamValidators.gtEq(0))

  /**
   * Differencing order. Must be >= 0.
   * Default: 0
   */
  final val d = new IntParam(this, "d", "Differencing order", ParamValidators.gtEq(0))

  /**
   * Moving-average (MA) order. Must be >= 0.
   * Default: 1
   */
  final val q = new IntParam(this, "q", "MA order", ParamValidators.gtEq(0))

  setDefault(p -> 1, d -> 0, q -> 1)

  /** Returns the AR order (explicitly set value, or the default). */
  def getP: Int = $(p)

  /** Returns the differencing order (explicitly set value, or the default). */
  def getD: Int = $(d)

  /** Returns the MA order (explicitly set value, or the default). */
  def getQ: Int = $(q)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.regression

import org.apache.spark.ml.Estimator
import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.DefaultParamsWritable
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
 * Estimator skeleton for an ARIMA(p, d, q) model over a series stored in a column
 * named "y".
 *
 * The actual ARIMA fitting is not implemented yet: `fit` validates the input schema and
 * returns an [[ArimaRegressionModel]] that carries this estimator's parameter values.
 *
 * @param uid unique identifier of this estimator instance
 */
class ArimaRegression(override val uid: String)
  extends Estimator[ArimaRegressionModel]
  with ArimaParams
  with DefaultParamsWritable {

  /** Creates an estimator with a randomly generated uid prefixed with "arimaReg". */
  def this() = this(Identifiable.randomUID("arimaReg"))

  /** Sets the AR order. */
  def setP(value: Int): this.type = set(p, value)

  /** Sets the differencing order. */
  def setD(value: Int): this.type = set(d, value)

  /** Sets the MA order. */
  def setQ(value: Int): this.type = set(q, value)

  /**
   * Produces a model for `dataset`.
   *
   * @param dataset input data; must contain a column named "y"
   * @return a model whose parent is this estimator and whose p/d/q match this estimator
   * @throws IllegalArgumentException if the "y" column is missing
   */
  override def fit(dataset: Dataset[_]): ArimaRegressionModel = {
    // Validate up front. The previous placeholder collected the whole series to the
    // driver and then discarded it, which cost a full job and could exhaust driver memory.
    transformSchema(dataset.schema, logging = true)

    // [TODO]: Replace with actual ARIMA fitting logic.
    // copyValues propagates p/d/q (defaults and explicitly set values) onto the model.
    copyValues(new ArimaRegressionModel(uid).setParent(this))
  }

  override def copy(extra: ParamMap): ArimaRegression = defaultCopy(extra)

  /**
   * Checks that the "y" column is present and appends a non-nullable double column
   * named "prediction" to the schema.
   */
  override def transformSchema(schema: StructType): StructType = {
    require(schema.fieldNames.contains("y"), "Dataset must contain 'y' column.")
    schema.add(StructField("prediction", DoubleType, false))
  }
}

/**
 * Companion of [[ArimaRegression]] providing `read` / `load` for estimators that were
 * saved through the default params writer.
 */
// DefaultParamsReadable is referenced by its full name because this file's import list
// only brings in DefaultParamsWritable and Identifiable from org.apache.spark.ml.util.
object ArimaRegression
  extends org.apache.spark.ml.util.DefaultParamsReadable[ArimaRegression] {

  /** Loads an estimator previously saved at `path`. */
  override def load(path: String): ArimaRegression = super.load(path)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.regression

import org.apache.spark.ml._
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
 * Model returned by the ARIMA estimator in this package.
 *
 * Prediction is currently a placeholder: `transform` copies the "y" column into a new
 * "prediction" column.
 *
 * @param uid unique identifier of this model instance
 */
class ArimaRegressionModel(override val uid: String)
  extends Model[ArimaRegressionModel]
  with ArimaParams
  // DefaultParamsWritable supplies the `write` implementation that a concrete MLWritable
  // must provide; it is itself an MLWritable, so the public type is unchanged.
  with DefaultParamsWritable {

  /**
   * Creates a copy with the same uid, parameter values (plus `extra`) and parent.
   */
  // ParamMap is spelled out in full because the imports of this file do not bring
  // org.apache.spark.ml.param.ParamMap into scope.
  override def copy(extra: org.apache.spark.ml.param.ParamMap): ArimaRegressionModel = {
    val copied = new ArimaRegressionModel(uid)
    copyValues(copied, extra).setParent(parent)
  }

  /**
   * Adds a "prediction" column to `dataset`.
   *
   * @param dataset input data; must contain a column named "y"
   * @throws IllegalArgumentException if the "y" column is missing
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    // Dummy prediction logic — just copy y as prediction
    dataset.withColumn("prediction", col("y"))
  }

  /**
   * Checks that the "y" column is present and appends a non-nullable double column
   * named "prediction" to the schema.
   */
  override def transformSchema(schema: StructType): StructType = {
    require(schema.fieldNames.contains("y"), "Dataset must contain 'y' column.")
    schema.add(StructField("prediction", DoubleType, false))
  }
}

/**
 * Companion of [[ArimaRegressionModel]] exposing the reader used to restore saved models.
 */
object ArimaRegressionModel extends MLReadable[ArimaRegressionModel] {

  /** Returns a params-based reader for [[ArimaRegressionModel]]. */
  override def read: MLReader[ArimaRegressionModel] = {
    val paramsReader = new DefaultParamsReader[ArimaRegressionModel]
    paramsReader
  }

  /** Loads a model previously saved at `path`. */
  override def load(path: String): ArimaRegressionModel = {
    super.load(path)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.regression

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame

/**
 * Smoke test for [[ArimaRegression]]: fitting succeeds and the fitted model adds a
 * "prediction" column without changing the number of rows.
 */
// MLlibTestSparkContext provides the shared SparkSession and `testImplicits`;
// SparkFunSuite on its own exposes no session to build DataFrames from.
class ArimaRegressionSuite
  extends SparkFunSuite
  with org.apache.spark.mllib.util.MLlibTestSparkContext {

  test("basic model fit and transform") {
    import testImplicits._

    val df = Seq(1.0, 2.0, 3.0, 4.0).toDF("y")
    val arima = new ArimaRegression().setP(1).setD(0).setQ(1)
    val model = arima.fit(df)

    val transformed = model.transform(df)
    assert(transformed.columns.contains("prediction"))
    assert(transformed.count() === df.count())
  }
}
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.ml.rst
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ Regression
RandomForestRegressionModel
FMRegressor
FMRegressionModel
ArimaRegression
ArimaRegressionModel


Statistics
Expand Down
70 changes: 70 additions & 0 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@
"RandomForestRegressionModel",
"FMRegressor",
"FMRegressionModel",
"ArimaRegression",
"ArimaRegressionModel"
]


Expand Down Expand Up @@ -146,6 +148,13 @@ class _JavaRegressionModel(RegressionModel, JavaPredictionModel[T], metaclass=AB

pass

class ArimaRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable["ArimaRegressionModel"]):
    """
    Model fitted by :py:class:`ArimaRegression`.

    Supports ``transform()``, and ``write()`` / ``save()`` / ``load()`` for persistence
    through the ML read/write mixins.
    """

    pass

class _LinearRegressionParams(
_PredictorParams,
Expand Down Expand Up @@ -208,6 +217,67 @@ def getEpsilon(self) -> float:
return self.getOrDefault(self.epsilon)


@inherit_doc
class ArimaRegression(JavaEstimator, JavaMLWritable, JavaMLReadable["ArimaRegression"]):
    """
    ArimaRegression(p=1, d=0, q=1)

    ARIMA time series regression model.

    Parameters
    ----------
    p : int
        Autoregressive order.
    d : int
        Differencing order.
    q : int
        Moving average order.

    Notes
    -----
    Requires a column named "y" as the input time series column.
    Constructor and :py:meth:`setParams` arguments must be passed by keyword.
    """

    # Integer-typed params; the converter rejects values that are not integral.
    p = Param(Params._dummy(), "p", "Autoregressive order.", typeConverter=TypeConverters.toInt)
    d = Param(Params._dummy(), "d", "Differencing order.", typeConverter=TypeConverters.toInt)
    q = Param(Params._dummy(), "q", "Moving average order.", typeConverter=TypeConverters.toInt)

    # @keyword_only is what populates self._input_kwargs; without it the attribute
    # does not exist and construction fails with AttributeError.
    @keyword_only
    def __init__(self, p=1, d=0, q=1):
        """
        __init__(self, p=1, d=0, q=1)
        """
        super(ArimaRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.ArimaRegression", self.uid
        )
        self._setDefault(p=1, d=0, q=1)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, p=1, d=0, q=1):
        """
        setParams(self, p=1, d=0, q=1)
        Set parameters for ArimaRegression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setP(self, value):
        """Sets the value of :py:attr:`p`."""
        return self._set(p=value)

    def getP(self):
        """Gets the value of :py:attr:`p` or its default value."""
        return self.getOrDefault(self.p)

    def setD(self, value):
        """Sets the value of :py:attr:`d`."""
        return self._set(d=value)

    def getD(self):
        """Gets the value of :py:attr:`d` or its default value."""
        return self.getOrDefault(self.d)

    def setQ(self, value):
        """Sets the value of :py:attr:`q`."""
        return self._set(q=value)

    def getQ(self):
        """Gets the value of :py:attr:`q` or its default value."""
        return self.getOrDefault(self.q)

    def _create_model(self, java_model):
        """Wraps the fitted JVM model in an :py:class:`ArimaRegressionModel`."""
        return ArimaRegressionModel(java_model)


@inherit_doc
class LinearRegression(
_JavaRegressor["LinearRegressionModel"],
Expand Down
50 changes: 50 additions & 0 deletions python/pyspark/ml/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,56 @@ def test_random_forest_regressor(self):
self.assertEqual(model.toDebugString, model2.toDebugString)


def test_arima_regression(self):
    """Fit ArimaRegression on a small series; check params, output columns and save/load."""
    # Only the modules this test actually uses are imported here.
    import tempfile
    from pyspark.ml.regression import ArimaRegression, ArimaRegressionModel

    spark = self.spark

    # Time series data in a single column named "y"
    df = spark.createDataFrame(
        [(1.2,), (2.3,), (3.1,), (4.0,), (5.5,)],
        ["y"],
    )

    arima = ArimaRegression(
        p=1,
        d=0,
        q=1,
    )

    self.assertEqual(arima.getP(), 1)
    self.assertEqual(arima.getD(), 0)
    self.assertEqual(arima.getQ(), 1)

    model = arima.fit(df)
    self.assertEqual(model.uid, arima.uid)

    output = model.transform(df)
    expected_cols = ["y", "prediction"]
    self.assertEqual(output.columns, expected_cols)
    self.assertEqual(output.count(), 5)

    # Predict a single value if API supports it
    if hasattr(model, "predict"):
        pred = model.predict(3.0)
        self.assertIsInstance(pred, float)

    # Model save/load
    with tempfile.TemporaryDirectory(prefix="arima_regression") as d:
        arima_path = d + "/arima"
        model_path = d + "/arima_model"

        arima.write().overwrite().save(arima_path)
        loaded_arima = ArimaRegression.load(arima_path)
        self.assertEqual(str(arima), str(loaded_arima))

        model.write().overwrite().save(model_path)
        loaded_model = ArimaRegressionModel.load(model_path)
        self.assertEqual(str(model), str(loaded_model))

# Concrete test case: inherits its test methods from RegressionTestsMixin and uses
# ReusedSQLTestCase as the base test class; it adds no members of its own.
class RegressionTests(RegressionTestsMixin, ReusedSQLTestCase):
pass

Expand Down