Skip to content

Commit 9b940f7

Browse files
committed
[SYSTEMDS-3797] Fix rewrite for trace on reorg operations
This patch fixes the rewrite for removing unnecessary reorg operations such as sum(t(X)) or sum(rev(X)) for trace aggregations which only consume a subset of values. Furthermore, we generalize this rewrite to now eliminate all reorg operations that are guaranteed to preserve all values (e.g., transpose/reshape/rev/roll, but not for diagM2V and sort with index return). Thanks to Jannik Lindemann for catching this issue.
1 parent 9a318ee commit 9b940f7

File tree

5 files changed

+19
-16
lines changed

5 files changed

+19
-16
lines changed

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,10 @@ public enum ReOrgOp {
751751
DIAG, //DIAG_V2M and DIAG_M2V could not be distinguished if sizes unknown
752752
RESHAPE, REV, ROLL, SORT, TRANS;
753753

754+
public boolean preservesValues() {
755+
return this != DIAG && this != SORT;
756+
}
757+
754758
@Override
755759
public String toString() {
756760
switch(this) {

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -980,23 +980,21 @@ private static Hop simplifyBushyBinaryOperation( Hop parent, Hop hi, int pos )
980980

981981
private static Hop simplifyUnaryAggReorgOperation( Hop parent, Hop hi, int pos )
982982
{
983-
if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full uagg
984-
&& hi.getInput().get(0) instanceof ReorgOp ) //reorg operation
983+
if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
984+
&& ((AggUnaryOp)hi).getOp() != AggOp.TRACE //full uagg
985+
&& hi.getInput().get(0) instanceof ReorgOp ) //reorg operation
985986
{
986987
ReorgOp rop = (ReorgOp)hi.getInput().get(0);
987-
if( (rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE
988-
|| rop.getOp() == ReOrgOp.REV ) //valid reorg
989-
&& rop.getParent().size()==1 ) //uagg only reorg consumer
988+
if( rop.getOp().preservesValues() //valid reorg
989+
&& rop.getParent().size()==1 ) //uagg only reorg consumer
990990
{
991991
Hop input = rop.getInput().get(0);
992992
HopRewriteUtils.removeAllChildReferences(hi);
993993
HopRewriteUtils.removeAllChildReferences(rop);
994994
HopRewriteUtils.addChildReference(hi, input);
995-
996995
LOG.debug("Applied simplifyUnaryAggReorgOperation");
997996
}
998997
}
999-
1000998
return hi;
1001999
}
10021000

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceMatrixMultTest.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,12 @@ private void testRewriteTraceMatrixMult(String testname, boolean rewrites) {
8585
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
8686

8787
//check trace operator existence
88-
String uaktrace = "uaktrace";
89-
long numTrace = Statistics.getCPHeavyHitterCount(uaktrace);
90-
91-
if(rewrites)
92-
Assert.assertTrue(numTrace == 0);
93-
else
94-
Assert.assertTrue(numTrace == 1);
95-
88+
long numTrace = Statistics.getCPHeavyHitterCount("uaktrace");
89+
Assert.assertTrue(numTrace == (rewrites ? 1 : 2));
90+
Assert.assertTrue(heavyHittersContainsString("rev"));
9691
}
9792
finally {
9893
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
9994
}
100-
10195
}
10296
}

src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ B = as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
3636

3737
# Perform the matrix operation
3838
R = sum(diag(A %*% B))
39+
rA = A;
40+
for(i in 1:nrow(rA)) {
41+
rA[,i] = rev(rA[,i])
42+
}
43+
R = R + sum(diag(rA))
3944

4045
# Write the result scalar R
4146
write(R, paste(args[2], "R" ,sep=""))

src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.dml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ B = read($2)
2626

2727
# Perform the operation
2828
R = trace(A %*% B)
29+
R = R + trace(rev(A))
2930

3031
# Write the result R
3132
write(R, $3)
33+

0 commit comments

Comments
 (0)