Skip to content

Commit 52ca491

Browse files
committed
[SYSTEMDS-3831] New builtin for vectorized simple exponential smoothing
This patch introduces a new vectorized builtin function for vectorized simple exponential smoothing which largely relies on cumsumprod.
1 parent bea9c96 commit 52ca491

File tree

6 files changed

+150
-2
lines changed

6 files changed

+150
-2
lines changed

scripts/builtin/ses.dml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Builtin function for simple exponential smoothing (SES).
23+
#
24+
# INPUT:
25+
# ------------------------------------------------------------------------------
26+
# x Time series vector [shape: n-by-1]
27+
# h Forecasting horizon
28+
# alpha Smoothing parameter yhat_t = alpha * x_y + (1-alpha) * yhat_t-1
29+
# ------------------------------------------------------------------------------
30+
#
31+
# OUTPUT:
32+
# ------------------------------------------------------------------------------
33+
# yhat Forecasts [shape: h-by-1]
34+
# ------------------------------------------------------------------------------
35+
36+
m_ses = function(Matrix[Double] x, Integer h = 1, Double alpha = 0.5)
37+
return (Matrix[Double] yhat)
38+
{
39+
# check and ensure valid parameters
40+
if(h < 1) {
41+
print("SES: forecasting horizon should be larger one.");
42+
h = 1;
43+
}
44+
if(alpha < 0 | alpha > 1) {
45+
print("SES: smooting parameter should be in [0,1].");
46+
alpha = 0.5;
47+
}
48+
49+
# vectorized forecasting
50+
# weights are 1 for first value and otherwise replicated alpha
51+
# but to compensate alpha*x for the first, we use 1/alpha
52+
w = rbind(as.matrix(1/alpha), matrix(1-alpha,nrow(x)-1,1));
53+
y = cumsumprod(cbind(alpha*x, w));
54+
yhat = matrix(as.scalar(y[nrow(x),1]), h, 1);
55+
}

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ public enum Builtins {
301301
SD("sd", false),
302302
SELVARTHRESH("selectByVarThresh", true),
303303
SEQ("seq", false),
304+
SES("ses", true),
304305
SYMMETRICDIFFERENCE("symmetricDifference", true),
305306
SHAPEXPLAINER("shapExplainer", true),
306307
SHERLOCK("sherlock", true),

src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import org.apache.sysds.lops.WeightedUnaryMM;
4040
import org.apache.sysds.lops.WeightedUnaryMMR;
4141
import org.apache.sysds.runtime.DMLRuntimeException;
42-
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
4342
import org.apache.sysds.runtime.instructions.cp.CPOperand;
4443
import org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction;
4544
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;

src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.ArrayList;
2323
import java.util.List;
2424

25-
import org.apache.commons.lang3.NotImplementedException;
2625
import org.apache.commons.lang3.tuple.Pair;
2726
import org.apache.commons.logging.Log;
2827
import org.apache.commons.logging.LogFactory;
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.builtin.part2;
21+
22+
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
23+
import org.apache.sysds.test.AutomatedTestBase;
24+
import org.apache.sysds.test.TestConfiguration;
25+
import org.junit.Assert;
26+
import org.junit.Test;
27+
28+
import java.util.HashMap;
29+
30+
public class BuiltinSESTest extends AutomatedTestBase {
31+
private final static String TEST_NAME = "ses";
32+
private final static String TEST_DIR = "functions/builtin/";
33+
private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinSESTest.class.getSimpleName() + "/";
34+
35+
private final static int rows = 200;
36+
37+
@Override
38+
public void setUp() {
39+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"y"}));
40+
}
41+
42+
@Test
43+
public void testSES05() {
44+
runSESTest(0.5, 199d);
45+
}
46+
47+
@Test
48+
public void testSES077() {
49+
runSESTest(0.77, 199.7013);
50+
}
51+
52+
@Test
53+
public void testSES10() {
54+
runSESTest(1.0, 200d);
55+
}
56+
57+
private void runSESTest(double alpha, double expected) {
58+
loadTestConfiguration(getTestConfiguration(TEST_NAME));
59+
String HOME = SCRIPT_DIR + TEST_DIR;
60+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
61+
programArgs = new String[] {"-args",
62+
String.valueOf(rows), String.valueOf(alpha), output("y")};
63+
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
64+
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("y");
65+
Assert.assertEquals(7, dmlfile.size()); //forecast horizon 7
66+
Assert.assertEquals(expected, dmlfile.get(new CellIndex(1,1)), 1e-3);
67+
}
68+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
23+
x = seq(1, $1);
24+
yhat = ses(x=x, alpha=$2, h=7)
25+
write(yhat, $3)
26+

0 commit comments

Comments
 (0)