Skip to content

Commit 4d51d98

Browse files
committed
[SYSTEMDS-3884] Improved new simplification rewrites (generality)
This patch simplifies the new simplification rewrites, and simultaneously makes them more general (applicable to scalar variables, not just constants - in case of constants, we rely on constant folding for consistent outcomes especially regarding value types).
1 parent 9e649c8 commit 4d51d98

File tree

3 files changed

+38
-86
lines changed

3 files changed

+38
-86
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 34 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
200200
hi = simplifyCumsumColOrFullAggregates(hi); //e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
201201
hi = simplifyCumsumReverse(hop, hi, i); //e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
202202
hi = simplifyNegatedSubtraction(hop, hi, i); //e.g., -(B-A)->A-B
203-
hi = simplifyTransposeAddition(hop, hi, i); //e.g., t(A+1)+2 -> t(A)+1+2 -> t(A)+3
203+
hi = simplifyTransposeAddition(hop, hi, i); //e.g., t(A+s1)+s2 -> t(A)+(s1+s2) + potential constant folding
204204
hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B)
205205
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
206206

@@ -213,95 +213,46 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
213213
}
214214

215215
private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int pos) {
216-
if (!(hi instanceof BinaryOp)
217-
|| ((BinaryOp)hi).getOp() != OpOp2.PLUS
218-
|| hi.getDataType() != DataType.MATRIX)
219-
return hi;
220-
221-
BinaryOp bop = (BinaryOp)hi;
222-
223-
ReorgOp tSide = null;
224-
LiteralOp litSide = null;
225-
Hop in0 = bop.getInput().get(0), in1 = bop.getInput().get(1);
226-
if (in0 instanceof ReorgOp && ((ReorgOp)in0).getOp() == ReOrgOp.TRANS
227-
&& in1 instanceof LiteralOp) {
228-
tSide = (ReorgOp)in0;
229-
litSide = (LiteralOp)in1;
230-
}
231-
else if (in1 instanceof ReorgOp && ((ReorgOp)in1).getOp() == ReOrgOp.TRANS
232-
&& in0 instanceof LiteralOp) {
233-
tSide = (ReorgOp)in1;
234-
litSide = (LiteralOp)in0;
235-
}
236-
else
237-
return hi;
238-
239-
//check if only consumer
240-
if (tSide.getParent().size() > 1) {
241-
return hi;
242-
}
243-
244-
Hop inner = tSide.getInput().get(0);
245-
if (!(inner instanceof BinaryOp)
246-
|| ((BinaryOp)inner).getOp() != OpOp2.PLUS
247-
|| inner.getDataType() != DataType.MATRIX)
248-
return hi;
249-
250-
BinaryOp ib = (BinaryOp)inner;
251-
252-
Hop X = null;
253-
LiteralOp lit1 = null;
254-
Hop i0 = ib.getInput().get(0), i1 = ib.getInput().get(1);
255-
if (i0 instanceof LiteralOp) {
256-
lit1 = (LiteralOp)i0;
257-
X = i1;
258-
}
259-
else if (i1 instanceof LiteralOp) {
260-
lit1 = (LiteralOp)i1;
261-
X = i0;
216+
//pattern: t(A+s1)+s2 -> t(A)+(s1+s2), and subsequent constant folding
217+
if (HopRewriteUtils.isBinary(hi, OpOp2.PLUS)
218+
&& hi.isMatrix() && hi.getInput(1).isScalar()
219+
&& HopRewriteUtils.isReorg(hi.getInput(0), ReOrgOp.TRANS)
220+
&& hi.getInput(0).getParent().size() == 1
221+
&& HopRewriteUtils.isBinary(hi.getInput(0).getInput(0), OpOp2.PLUS)
222+
&& hi.getInput(0).getInput(0).getParent().size() == 1
223+
&& (hi.getInput(0).getInput(0).getInput(0).isScalar()
224+
|| hi.getInput(0).getInput(0).getInput(1).isScalar()))
225+
{
226+
int six = hi.getInput(0).getInput(0).getInput(0).isScalar() ? 0 : 1;
227+
Hop A = hi.getInput(0).getInput(0).getInput(six==0 ? 1 : 0);
228+
Hop s1 = hi.getInput(0).getInput(0).getInput(six);
229+
Hop s2 = hi.getInput(1);
230+
231+
Hop tA = HopRewriteUtils.createTranspose(A);
232+
Hop s12 = HopRewriteUtils.createBinary(s1, s2, OpOp2.PLUS);
233+
Hop newHop = HopRewriteUtils.createBinary(tA, s12, OpOp2.PLUS);
234+
235+
HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos);
236+
HopRewriteUtils.cleanupUnreferenced(hi);
237+
hi = newHop;
238+
239+
LOG.debug("Applied simplifyTransposeAddition (line " + hi.getBeginLine() + ").");
262240
}
263-
else
264-
return hi;
265-
266-
double c = lit1.getDoubleValue() + litSide.getDoubleValue();
267-
268-
ReorgOp newT = HopRewriteUtils.createTranspose(X);
269-
newT.setDim1(tSide.getDim1());
270-
newT.setDim2(tSide.getDim2());
271-
272-
LiteralOp newLit = new LiteralOp(c);
273-
newLit.setDim1(1);
274-
newLit.setDim2(1);
275-
276-
//creating new binaryOp
277-
BinaryOp newPlus = HopRewriteUtils.createBinary(newT, newLit, OpOp2.PLUS);
278-
newPlus.setDim1(bop.getDim1());
279-
newPlus.setDim2(bop.getDim2());
280-
281-
HopRewriteUtils.replaceChildReference(parent, bop, newPlus, pos);
282-
HopRewriteUtils.cleanupUnreferenced(bop, tSide, ib, litSide);
283-
284-
LOG.debug("Applied simplifyTransposeAddition (line " + hi.getBeginLine() + ").");
285-
286-
return newPlus;
241+
242+
return hi;
287243
}
288244

289245
private static Hop simplifyNegatedSubtraction(Hop parent, Hop hi, int pos) {
290-
if (hi instanceof BinaryOp
291-
&& ((BinaryOp) hi).getOp() == OpOp2.MINUS
292-
&& HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 0)
293-
&& hi.getParent().size() == 1
294-
&& hi.getInput().get(1) instanceof BinaryOp
295-
&& ((BinaryOp) hi.getInput().get(1)).getOp() == OpOp2.MINUS
296-
&& hi.getInput().get(1).getParent().size() == 1)
246+
//pattern: -(B-A)->A-B, but only of (B-A) consumed once
247+
if (HopRewriteUtils.isBinary(hi, OpOp2.MINUS)
248+
&& HopRewriteUtils.isLiteralOfValue(hi.getInput(0), 0)
249+
&& HopRewriteUtils.isBinary(hi.getInput(1), OpOp2.MINUS)
250+
&& hi.getInput().get(1).getParent().size() == 1)
297251
{
298-
Hop innerMinus = hi.getInput().get(1);
299-
Hop B = innerMinus.getInput().get(0);
300-
Hop A = innerMinus.getInput().get(1);
252+
Hop B = hi.getInput(1).getInput(0);
253+
Hop A = hi.getInput(1).getInput(1);
301254

302255
BinaryOp newHop = HopRewriteUtils.createBinary(A, B, OpOp2.MINUS);
303-
304-
HopRewriteUtils.copyLineNumbers(hi, newHop);
305256
HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos);
306257
HopRewriteUtils.cleanupUnreferenced(hi);
307258
hi = newHop;

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposeAdditionTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ private void runRewriteTest(boolean rewriteEnabled) {
8282
HashMap<CellIndex, Double> r = readRMatrixFromExpectedDir("R");
8383

8484
Assert.assertEquals("DML and R outputs do not match", r, dml);
85-
if( rewriteEnabled )
86-
Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("+"));
85+
if( rewriteEnabled ) //no rewrite: 4 (2x2), rewrite: 3 (constant folding, 2x1)
86+
Assert.assertEquals(3, Statistics.getCPHeavyHitterCount("+"));
8787
}
8888
finally {
8989
// Reset optimizer flags

src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.dml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
A = read($1);
2222

2323
# Compute t(A+1)+2 which should be rewritten to t(A)+3
24-
result = t(A+1)+2;
24+
for(i in 1:2)
25+
result = t(A+1)+2;
2526

2627
write(result, $2);

0 commit comments

Comments
 (0)