Skip to content

Commit 082cf89

Browse files
committed
[SYSTEMDS-3804] New rewrite for reverse sequences
This patch adds a new rewrite rev(seq(1,n)) -> seq(n,1), a pattern we recently saw in a script on vectorized time series forecasting.
1 parent b5b6f37 commit 082cf89

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,9 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
153153
hi = canonicalizeMatrixMultScalarAdd(hi); //e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps)
154154
hi = simplifyCTableWithConstMatrixInputs(hi); //e.g., table(X, matrix(1,...)) -> table(X, 1)
155155
hi = removeUnnecessaryCTable(hop, hi, i); //e.g., sum(table(X, 1)) -> nrow(X) and sum(table(1, Y)) -> nrow(Y) and sum(table(X, Y)) -> nrow(X)
156-
hi = simplifyConstantConjunction(hop, hi, i); //e.g., a & !a -> FALSE
156+
hi = simplifyConstantConjunction(hop, hi, i); //e.g., a & !a -> FALSE
157157
hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
158+
hi = simplifyReverseSequence(hop, hi, i); //e.g., rev(seq(1,n)) -> seq(n,1)
158159
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
159160
hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y
160161
hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
@@ -798,6 +799,28 @@ private static Hop simplifyReverseOperation( Hop parent, Hop hi, int pos )
798799

799800
return hi;
800801
}
802+
803+
private static Hop simplifyReverseSequence( Hop parent, Hop hi, int pos )
804+
{
805+
if( HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
806+
&& HopRewriteUtils.isBasic1NSequence(hi.getInput(0))
807+
&& hi.getInput(0).getParent().size() == 1) //only consumer
808+
{
809+
DataGenOp seq = (DataGenOp) hi.getInput(0);
810+
Hop from = seq.getInput().get(seq.getParamIndex(Statement.SEQ_FROM));
811+
Hop to = seq.getInput().get(seq.getParamIndex(Statement.SEQ_TO));
812+
seq.getInput().set(seq.getParamIndex(Statement.SEQ_FROM), to);
813+
seq.getInput().set(seq.getParamIndex(Statement.SEQ_TO), from);
814+
seq.getInput().set(seq.getParamIndex(Statement.SEQ_INCR), new LiteralOp(-1));
815+
816+
HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
817+
HopRewriteUtils.cleanupUnreferenced(hi, seq);
818+
hi = seq;
819+
LOG.debug("Applied simplifyReverseSequence (line "+hi.getBeginLine()+").");
820+
}
821+
822+
return hi;
823+
}
801824

802825
private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi )
803826
{
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import org.junit.Assert;
23+
import org.junit.Test;
24+
25+
import org.apache.sysds.hops.OptimizerUtils;
26+
import org.apache.sysds.test.AutomatedTestBase;
27+
import org.apache.sysds.test.TestConfiguration;
28+
import org.apache.sysds.test.TestUtils;
29+
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
30+
31+
public class RewriteRemoveUnnecessaryRevTest extends AutomatedTestBase
32+
{
33+
private static final String TEST_NAME1 = "RewriteRemoveUnnecessaryRev";
34+
35+
private static final String TEST_DIR = "functions/rewrite/";
36+
private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownSumBinaryMult.class.getSimpleName() + "/";
37+
38+
@Override
39+
public void setUp() {
40+
TestUtils.clearAssertionInformation();
41+
addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
42+
}
43+
44+
@Test
45+
public void testRemoveSeqRevRewrite() {
46+
testRewriteRemoveSeqRev( TEST_NAME1, true );
47+
}
48+
49+
@Test
50+
public void testRemoveSeqRevNoRewrite() {
51+
testRewriteRemoveSeqRev( TEST_NAME1, false );
52+
}
53+
54+
private void testRewriteRemoveSeqRev( String testname, boolean rewrites )
55+
{
56+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
57+
int rows = 1001;
58+
59+
try
60+
{
61+
TestConfiguration config = getTestConfiguration(testname);
62+
loadTestConfiguration(config);
63+
64+
String HOME = SCRIPT_DIR + TEST_DIR;
65+
fullDMLScriptName = HOME + testname + ".dml";
66+
programArgs = new String[]{ "-stats","-args", String.valueOf(rows), output("Scalar") };
67+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
68+
69+
runTest(true, false, null, -1);
70+
71+
//compare scalars
72+
int ret = (int)readDMLScalarFromOutputDir("Scalar").get(new CellIndex(1,1)).doubleValue();
73+
Assert.assertEquals(ret, rows*(rows+1)/2);
74+
if( rewrites )
75+
Assert.assertFalse(heavyHittersContainsString("rev"));
76+
}
77+
finally {
78+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
79+
}
80+
}
81+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
rows = $1;
23+
24+
# to be rewritten to: seq(rows,1)
25+
X = rev(seq(1,rows))
26+
27+
while(FALSE){}
28+
29+
R = sum(X);
30+
write(R, $2)
31+

0 commit comments

Comments
 (0)