Skip to content

Commit f3b638a

Browse files
aarnaty authored and mboehm7 committed
[SYSTEMDS-3812] Improved rewrites pushdown-sum and rm-reorg
Closes #2176.
1 parent 704b6fb commit f3b638a

File tree

5 files changed: +160 -98 lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 69 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -381,30 +381,28 @@ private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos)
381381

382382
return hi;
383383
}
384-
385-
private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
386-
{
387-
if( hi instanceof ReorgOp )
388-
{
384+
385+
private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) {
386+
if( hi instanceof ReorgOp ) {
389387
ReorgOp rop = (ReorgOp) hi;
390-
Hop input = hi.getInput(0);
388+
Hop input = hi.getInput(0);
391389
boolean apply = false;
392-
393-
//equal dims of reshape input and output -> no need for reshape because
390+
391+
//equal dims of reshape input and output -> no need for reshape because
394392
//byrow always refers to both input/output and hence gives the same result
395393
apply |= (rop.getOp()==ReOrgOp.RESHAPE && HopRewriteUtils.isEqualSize(hi, input));
396-
397-
//1x1 dimensions of transpose/reshape -> no need for reorg
398-
apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE)
399-
&& rop.getDim1()==1 && rop.getDim2()==1);
400-
394+
395+
//1x1 dimensions of transpose/reshape/roll -> no need for reorg
396+
apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE
397+
|| rop.getOp()==ReOrgOp.ROLL) && rop.getDim1()==1 && rop.getDim2()==1);
398+
401399
if( apply ) {
402400
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
403401
hi = input;
404402
LOG.debug("Applied removeUnnecessaryReorg.");
405403
}
406404
}
407-
405+
408406
return hi;
409407
}
410408

@@ -1356,44 +1354,78 @@ else if ( applyRight ) {
13561354
* @param pos position
13571355
* @return high-level operator
13581356
*/
1359-
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
1357+
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
13601358
{
13611359
//all patterns headed by full sum over binary operation
13621360
if( hi instanceof AggUnaryOp //full sum root over binaryop
1363-
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
1364-
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM
1365-
&& hi.getInput(0) instanceof BinaryOp
1366-
&& hi.getInput(0).getParent().size()==1 ) //single parent
1361+
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
1362+
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM
1363+
&& hi.getInput(0) instanceof BinaryOp
1364+
&& hi.getInput(0).getParent().size()==1 ) //single parent
13671365
{
13681366
BinaryOp bop = (BinaryOp) hi.getInput(0);
13691367
Hop left = bop.getInput(0);
13701368
Hop right = bop.getInput(1);
1371-
1372-
if( HopRewriteUtils.isEqualSize(left, right) //dims(A) == dims(B)
1373-
&& left.getDataType() == DataType.MATRIX
1374-
&& right.getDataType() == DataType.MATRIX )
1369+
1370+
if( left.getDataType() == DataType.MATRIX
1371+
&& right.getDataType() == DataType.MATRIX )
13751372
{
13761373
OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS //pattern a: sum(A+B)->sum(A)+sum(B)
13771374
|| bop.getOp() == OpOp2.MINUS ) //pattern b: sum(A-B)->sum(A)-sum(B)
13781375
? bop.getOp() : null;
1379-
1376+
13801377
if( applyOp != null ) {
1381-
//create new subdag sum(A) bop sum(B)
1382-
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
1383-
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
1384-
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);
1385-
1386-
//rewire new subdag
1387-
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
1388-
HopRewriteUtils.cleanupUnreferenced(hi, bop);
1389-
1390-
hi = newBin;
1391-
1392-
LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
1378+
if (HopRewriteUtils.isEqualSize(left, right)) {
1379+
//create new subdag sum(A) bop sum(B) for equal-sized matrices
1380+
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
1381+
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
1382+
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);
1383+
//rewire new subdag
1384+
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
1385+
HopRewriteUtils.cleanupUnreferenced(hi, bop);
1386+
1387+
hi = newBin;
1388+
1389+
LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
1390+
}
1391+
// Check if right operand is a vector (has dimension of 1 in either rows or columns)
1392+
else if (right.getDim1() == 1 || right.getDim2() == 1) {
1393+
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
1394+
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
1395+
1396+
// Row vector case (1 x n)
1397+
if (right.getDim1() == 1) {
1398+
// Create nrow(A) operation using dimensions
1399+
LiteralOp nRows = new LiteralOp(left.getDim1());
1400+
BinaryOp scaledSum = HopRewriteUtils.createBinary(nRows, sum2, OpOp2.MULT);
1401+
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
1402+
//rewire new subdag
1403+
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
1404+
HopRewriteUtils.cleanupUnreferenced(hi, bop);
1405+
1406+
hi = newBin;
1407+
1408+
LOG.debug("Applied pushdownSumOnAdditiveBinary with row vector (line "+hi.getBeginLine()+").");
1409+
}
1410+
// Column vector case (n x 1)
1411+
else if (right.getDim2() == 1) {
1412+
// Create ncol(A) operation using dimensions
1413+
LiteralOp nCols = new LiteralOp(left.getDim2());
1414+
BinaryOp scaledSum = HopRewriteUtils.createBinary(nCols, sum2, OpOp2.MULT);
1415+
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
1416+
//rewire new subdag
1417+
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
1418+
HopRewriteUtils.cleanupUnreferenced(hi, bop);
1419+
1420+
hi = newBin;
1421+
1422+
LOG.debug("Applied pushdownSumOnAdditiveBinary with column vector (line "+hi.getBeginLine()+").");
1423+
}
1424+
}
13931425
}
13941426
}
13951427
}
1396-
1428+
13971429
return hi;
13981430
}
13991431

src/test/java/org/apache/sysds/test/functions/aggregate/PushdownSumBinaryTest.java

Lines changed: 8 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,8 @@
2525
import org.junit.Assert;
2626
import org.junit.BeforeClass;
2727
import org.junit.Test;
28-
import org.apache.sysds.api.DMLScript;
2928
import org.apache.sysds.common.Types.ExecMode;
3029
import org.apache.sysds.common.Types.ExecType;
31-
import org.apache.sysds.runtime.instructions.Instruction;
3230
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
3331
import org.apache.sysds.test.AutomatedTestBase;
3432
import org.apache.sysds.test.TestConfiguration;
@@ -89,39 +87,24 @@ public void testPushDownSumMinusSP() {
8987
}
9088

9189
@Test
92-
public void testPushDownSumPlusNoRewriteSP() {
90+
public void testPushDownSumPlusBroadcastSP() {
9391
runPushdownSumOnBinaryTest(TEST_NAME1, false, ExecType.SPARK);
9492
}
9593

9694
@Test
97-
public void testPushDownSumMinusNoRewriteSP() {
95+
public void testPushDownSumMinusBroadcastSP() {
9896
runPushdownSumOnBinaryTest(TEST_NAME2, false, ExecType.SPARK);
9997
}
100-
101-
/**
102-
*
103-
* @param testname
104-
* @param type
105-
* @param sparse
106-
* @param instType
107-
*/
98+
10899
private void runPushdownSumOnBinaryTest( String testname, boolean equiDims, ExecType instType)
109100
{
110101
//rtplatform for MR
111-
ExecMode platformOld = rtplatform;
112-
switch( instType ){
113-
case SPARK: rtplatform = ExecMode.SPARK; break;
114-
default: rtplatform = ExecMode.HYBRID; break;
115-
}
116-
117-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
118-
if( rtplatform == ExecMode.SPARK )
119-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
102+
ExecMode platformOld = setExecMode(instType);
120103

121104
try
122105
{
123106
//determine script and function name
124-
String TEST_NAME = testname;
107+
String TEST_NAME = testname;
125108
String TEST_CACHE_DIR = TEST_CACHE_ENABLED ? TEST_NAME + "_" + String.valueOf(equiDims) + "/" : "";
126109

127110
TestConfiguration config = getTestConfiguration(TEST_NAME);
@@ -150,13 +133,10 @@ private void runPushdownSumOnBinaryTest( String testname, boolean equiDims, Exec
150133
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
151134

152135
String lopcode = TEST_NAME.equals(TEST_NAME1) ? "+" : "-";
153-
String opcode = equiDims ? lopcode : Instruction.SP_INST_PREFIX+"map"+lopcode;
154-
Assert.assertTrue("Non-applied rewrite", Statistics.getCPHeavyHitterOpCodes().contains(opcode));
136+
Assert.assertTrue("Non-applied rewrite", Statistics.getCPHeavyHitterOpCodes().contains(lopcode));
155137
}
156-
finally
157-
{
158-
rtplatform = platformOld;
159-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
138+
finally {
139+
resetExecMode(platformOld);
160140
}
161141
}
162142
}

src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumBinaryMult.java

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -68,11 +68,6 @@ public void testPushdownSumBinaryMultRewrite2() {
6868
testRewritePushdownSumBinaryMult( TEST_NAME2, true );
6969
}
7070

71-
/**
72-
*
73-
* @param testname
74-
* @param rewrites
75-
*/
7671
private void testRewritePushdownSumBinaryMult( String testname, boolean rewrites )
7772
{
7873
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;

src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java

Lines changed: 62 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -29,54 +29,94 @@
2929
import org.apache.sysds.test.TestConfiguration;
3030
import org.apache.sysds.test.TestUtils;
3131

32-
public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase
32+
public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase
3333
{
3434
private static final String TEST_NAME1 = "RewritePushdownSumOnBinary";
3535
private static final String TEST_DIR = "functions/rewrite/";
3636
private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownSumOnBinaryTest.class.getSimpleName() + "/";
37-
37+
3838
private static final int rows = 1000;
3939
private static final int cols = 1;
40-
40+
private static final double eps = 1e-8;
41+
4142
@Override
4243
public void setUp() {
4344
TestUtils.clearAssertionInformation();
44-
addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R1", "R2" }) );
45+
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,
46+
new String[] { "R1", "R2", "R3", "R4" }));
47+
}
48+
49+
@Test
50+
public void testRewritePushdownSumOnBinaryNoRewrite() {
51+
testRewritePushdownSumOnBinary(TEST_NAME1, false);
52+
}
53+
54+
@Test
55+
public void testRewritePushdownSumOnBinary() {
56+
testRewritePushdownSumOnBinary(TEST_NAME1, true);
4557
}
4658

4759
@Test
48-
public void testRewritePushdownSumOnBinaryNoRewrite() {
49-
testRewritePushdownSumOnBinary( TEST_NAME1, false );
60+
public void testRewritePushdownSumOnBinaryRowVector() {
61+
testRewritePushdownSumOnBinaryVector(TEST_NAME1, true, true);
5062
}
51-
63+
5264
@Test
53-
public void testRewritePushdownSumOnBinary() {
54-
testRewritePushdownSumOnBinary( TEST_NAME1, true );
65+
public void testRewritePushdownSumOnBinaryColVector() {
66+
testRewritePushdownSumOnBinaryVector(TEST_NAME1, true, false);
5567
}
56-
57-
private void testRewritePushdownSumOnBinary( String testname, boolean rewrites )
58-
{
68+
69+
private void testRewritePushdownSumOnBinary(String testname, boolean rewrites) {
5970
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
60-
71+
6172
try {
6273
TestConfiguration config = getTestConfiguration(testname);
6374
loadTestConfiguration(config);
64-
75+
6576
String HOME = SCRIPT_DIR + TEST_DIR;
6677
fullDMLScriptName = HOME + testname + ".dml";
67-
programArgs = new String[]{ "-args", String.valueOf(rows),
68-
String.valueOf(cols), output("R1"), output("R2") };
78+
79+
programArgs = new String[]{ "-args", String.valueOf(rows),
80+
String.valueOf(cols), output("R1"), output("R2"),
81+
String.valueOf(rows), String.valueOf(cols) }; // Assuming row and col vectors
82+
6983
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
7084

71-
//run performance tests
85+
// Run performance tests
7286
runTest(true, false, null, -1);
73-
74-
//compare matrices
75-
long expect = Math.round(0.5*rows);
87+
88+
// Compare matrices
89+
long expect = Math.round(0.5 * rows);
7690
HashMap<CellIndex, Double> dmlfile1 = readDMLScalarFromOutputDir("R1");
77-
Assert.assertEquals(expect, dmlfile1.get(new CellIndex(1,1)), expect*0.01);
91+
Assert.assertEquals(expect, dmlfile1.get(new CellIndex(1, 1)), eps);
7892
HashMap<CellIndex, Double> dmlfile2 = readDMLScalarFromOutputDir("R2");
79-
Assert.assertEquals(expect, dmlfile2.get(new CellIndex(1,1)), expect*0.01);
93+
Assert.assertEquals(expect, dmlfile2.get(new CellIndex(1, 1)), eps);
94+
} finally {
95+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
96+
}
97+
}
98+
99+
100+
private void testRewritePushdownSumOnBinaryVector(String testname, boolean rewrites, boolean isRow) {
101+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
102+
try {
103+
TestConfiguration config = getTestConfiguration(testname);
104+
loadTestConfiguration(config);
105+
106+
String HOME = SCRIPT_DIR + TEST_DIR;
107+
fullDMLScriptName = HOME + testname + ".dml";
108+
programArgs = new String[]{ "-args", String.valueOf(rows),
109+
String.valueOf(cols), output("R3"), output("R4"),
110+
String.valueOf(isRow ? 1 : rows), String.valueOf(isRow ? cols : 1) };
111+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
112+
113+
runTest(true, false, null, -1);
114+
115+
long expect = Math.round(500); // Expected value for 0.5 + 0.5
116+
HashMap<CellIndex, Double> dmlfile3 = readDMLScalarFromOutputDir("R3");
117+
Assert.assertEquals(expect, dmlfile3.get(new CellIndex(1,1)), eps);
118+
HashMap<CellIndex, Double> dmlfile4 = readDMLScalarFromOutputDir("R4");
119+
Assert.assertEquals(expect, dmlfile4.get(new CellIndex(1,1)), eps);
80120
}
81121
finally {
82122
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;

src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,15 +19,30 @@
1919
#
2020
#-------------------------------------------------------------
2121

22-
A = rand(rows=$1, cols=$2, seed=1);
23-
B = rand(rows=$1, cols=$2, seed=2);
24-
C = rand(rows=$1, cols=$2, seed=3);
25-
D = rand(rows=$1, cols=$2, seed=4);
22+
# Required parameters
23+
A = matrix(0.5, rows=$1, cols=$2);
24+
B = matrix(0.5, rows=$1, cols=$2);
25+
C = matrix(0.5, rows=$1, cols=$2);
26+
D = matrix(0.5, rows=$1, cols=$2);
2627

28+
# Set defaults for optional parameters
29+
rowsV = ifdef($5, 0)
30+
colsV = ifdef($6, 0)
31+
32+
# Original matrix tests
2733
r1 = sum(A*B + C*D);
2834
r2 = r1;
2935

30-
print("r1="+r1+", r2="+r2);
36+
# Vector tests
37+
if (rowsV != 0 & colsV != 0) {
38+
V = matrix(0.5, rows=rowsV, cols=colsV);
39+
r3 = sum(A + V);
40+
r4 = r3;
41+
}
42+
3143
write(r1, $3);
3244
write(r2, $4);
33-
45+
if (rowsV != 0 & colsV != 0) {
46+
write(r3, $5);
47+
write(r4, $6);
48+
}

0 commit comments

Comments
 (0)