Skip to content

Commit 147519e

Browse files
aarnatymboehm7
authored andcommitted
[SYSTEMDS-3664] New simplification rewrite rev(seq())
This patch introduces a new simplification rewrite for reversing a sequence rev(seq(1,n)) --> seq(n,1). Closes #2242.
1 parent 686be39 commit 147519e

File tree

3 files changed

+198
-0
lines changed

3 files changed

+198
-0
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
156156
hi = simplifyConstantConjunction(hop, hi, i); //e.g., a & !a -> FALSE
157157
hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
158158
hi = simplifyReverseSequence(hop, hi, i); //e.g., rev(seq(1,n)) -> seq(n,1)
159+
hi = simplifyReverseSequenceStep(hop, hi, i); //e.g., rev(seq(1,n,2)) -> rev(n,1,-2)
159160
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
160161
hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y
161162
hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
@@ -824,6 +825,59 @@ private static Hop simplifyReverseSequence( Hop parent, Hop hi, int pos )
824825

825826
return hi;
826827
}
828+
829+
private static Hop simplifyReverseSequenceStep(Hop parent, Hop hi, int pos) {
830+
if (HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
831+
&& hi.getInput(0) instanceof DataGenOp
832+
&& ((DataGenOp) hi.getInput(0)).getOp() == OpOpDG.SEQ
833+
&& hi.getInput(0).getParent().size() == 1) // only one consumer
834+
{
835+
DataGenOp seq = (DataGenOp) hi.getInput(0);
836+
Hop from = seq.getInput().get(seq.getParamIndex(Statement.SEQ_FROM));
837+
Hop to = seq.getInput().get(seq.getParamIndex(Statement.SEQ_TO));
838+
Hop incr = seq.getInput().get(seq.getParamIndex(Statement.SEQ_INCR));
839+
840+
if (from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp) {
841+
double fromVal = ((LiteralOp) from).getDoubleValue();
842+
double toVal = ((LiteralOp) to).getDoubleValue();
843+
double incrVal = ((LiteralOp) incr).getDoubleValue();
844+
845+
// Skip if increment is zero (invalid sequence)
846+
if (Math.abs(incrVal) < 1e-10)
847+
return hi;
848+
849+
boolean isValidDirection = false;
850+
851+
// Checking direction compatibility
852+
if ((incrVal > 0 && fromVal <= toVal) || (incrVal < 0 && fromVal >= toVal)) {
853+
isValidDirection = true;
854+
}
855+
856+
if (isValidDirection) {
857+
// Calculate the number of elements and the last element
858+
int numValues = (int)Math.floor(Math.abs((toVal - fromVal) / incrVal)) + 1;
859+
double lastVal = fromVal + (numValues - 1) * incrVal;
860+
861+
// Create a new sequence based on actual last value
862+
LiteralOp newFrom = new LiteralOp(lastVal);
863+
LiteralOp newTo = new LiteralOp(fromVal);
864+
LiteralOp newIncr = new LiteralOp(-incrVal);
865+
866+
// Replace the parameters
867+
seq.getInput().set(seq.getParamIndex(Statement.SEQ_FROM), newFrom);
868+
seq.getInput().set(seq.getParamIndex(Statement.SEQ_TO), newTo);
869+
seq.getInput().set(seq.getParamIndex(Statement.SEQ_INCR), newIncr);
870+
871+
// Replace the old sequence with the new one
872+
HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
873+
HopRewriteUtils.cleanupUnreferenced(hi, seq);
874+
hi = seq;
875+
LOG.debug("Applied simplifyReverseSequenceStep (line " + hi.getBeginLine() + ").");
876+
}
877+
}
878+
}
879+
return hi;
880+
}
827881

828882
private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi )
829883
{
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import org.apache.sysds.hops.OptimizerUtils;
23+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
24+
import org.apache.sysds.test.AutomatedTestBase;
25+
import org.apache.sysds.test.TestConfiguration;
26+
import org.apache.sysds.test.TestUtils;
27+
import org.junit.Assert;
28+
import org.junit.Test;
29+
30+
public class RewriteSimplifyReverseSequenceStepTest extends AutomatedTestBase {
31+
private static final String TEST_NAME1 = "RewriteSimplifyReverseSequenceStep";
32+
33+
private static final String TEST_DIR = "functions/rewrite/";
34+
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyReverseSequenceStepTest.class.getSimpleName() + "/";
35+
36+
@Override
37+
public void setUp() {
38+
TestUtils.clearAssertionInformation();
39+
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}));
40+
}
41+
42+
@Test
43+
public void testRewriteReverseSeqStep() {
44+
testRewriteReverseSeq(TEST_NAME1, true);
45+
}
46+
47+
@Test
48+
public void testNoRewriteReverseSeqStep() {
49+
testRewriteReverseSeq(TEST_NAME1, false);
50+
}
51+
52+
private void testRewriteReverseSeq(String testname, boolean rewrites) {
53+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
54+
int rows = 10;
55+
56+
try {
57+
TestConfiguration config = getTestConfiguration(testname);
58+
loadTestConfiguration(config);
59+
60+
String HOME = SCRIPT_DIR + TEST_DIR;
61+
fullDMLScriptName = HOME + testname + ".dml";
62+
programArgs = new String[]{"-stats", "-args", String.valueOf(rows), output("Scalar")};
63+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
64+
65+
runTest(true, false, null, -1);
66+
67+
// Calculate expected sums for each sequence
68+
double sum1 = calculateSum(0, rows-1, 1); // A1 = rev(seq(0, rows-1, 1))
69+
double sum2 = calculateSum(0, rows, 2); // A2 = rev(seq(0, rows, 2))
70+
double sum3 = calculateSum(2, rows, 2); // A3 = rev(seq(2, rows, 2))
71+
double sum4 = calculateSum(0, 100, 5); // A4 = rev(seq(0, 100, 5))
72+
double sum5 = calculateSum(15, 5, -0.5); // A5 = rev(seq(15, 5, -0.5))
73+
74+
double expected = sum1 + sum2 + sum3 + sum4 + sum5;
75+
76+
double ret = readDMLScalarFromOutputDir("Scalar").get(new MatrixValue.CellIndex(1, 1)).doubleValue();
77+
78+
Assert.assertEquals("Incorrect sum computed", expected, ret, 1e-10);
79+
80+
if (rewrites) {
81+
// With bidirectional rewrite, REV operations should be removed
82+
Assert.assertFalse("Rewrite should have removed REV operation!",
83+
heavyHittersContainsString("rev"));
84+
}
85+
}
86+
finally {
87+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
88+
}
89+
}
90+
91+
// Helper method to calculate sum of a sequence
92+
private double calculateSum(double from, double to, double incr) {
93+
double sum = 0;
94+
int n = 0;
95+
96+
if ((incr > 0 && from <= to) || (incr < 0 && from >= to)) {
97+
// Calculate number of elements in the sequence
98+
n = (int)Math.floor(Math.abs((to - from) / incr)) + 1;
99+
100+
// Calculate the last element in the sequence
101+
double last = from + (n - 1) * incr;
102+
103+
// Use arithmetic sequence sum formula: n * (first + last) / 2
104+
sum = n * (from + last) / 2;
105+
}
106+
107+
return sum;
108+
}
109+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
rows = as.integer($1)
23+
24+
# Original test sequences (positive increments)
25+
A1 = rev(seq(0, rows-1, 1)) # Should become seq(rows-1, 0, -1)
26+
A2 = rev(seq(0, rows, 2)) # Should become seq(rows, 0, -2)
27+
A3 = rev(seq(2, rows, 2)) # Should become seq(lastVal, 2, -2) where lastVal is the last value in the sequence
28+
A4 = rev(seq(0, 100, 5)) # Should become seq(100, 0, -5)
29+
A5 = rev(seq(15, 5, -0.5)) # Should become seq(5, 15, 0.5)
30+
31+
# Sum all sequences
32+
R = sum(A1) + sum(A2) + sum(A3) + sum(A4) + sum(A5)
33+
34+
# Output
35+
write(R, $2)

0 commit comments

Comments
 (0)