Skip to content

Commit 091144d

Browse files
committed
[SYSTEMDS-3798] Fix generality of loop vectorization rewrite
1 parent cef8a6f commit 091144d

File tree

2 files changed

+35
-27
lines changed

2 files changed

+35
-27
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,11 @@ private static StatementBlock vectorizeScalarAggregate( StatementBlock sb, State
138138
&& right.getInput(0) instanceof IndexingOp )
139139
{
140140
IndexingOp ix = (IndexingOp)right.getInput(0);
141-
if( ix.isRowLowerEqualsUpper() && ix.getInput(1) instanceof DataOp
142-
&& ix.getInput(1).getName().equals(itervar) ){
141+
if( checkItervarIndexing(ix, itervar, true) ){
143142
leftScalar = true;
144143
rowIx = true;
145144
}
146-
else if( ix.isColLowerEqualsUpper() && ix.getInput(3) instanceof DataOp
147-
&& ix.getInput(3).getName().equals(itervar) ){
145+
else if( checkItervarIndexing(ix, itervar, false) ){
148146
leftScalar = true;
149147
rowIx = false;
150148
}
@@ -157,13 +155,11 @@ else if( HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
157155
&& left.getInput(0) instanceof IndexingOp )
158156
{
159157
IndexingOp ix = (IndexingOp)left.getInput(0);
160-
if( ix.isRowLowerEqualsUpper() && ix.getInput(1) instanceof DataOp
161-
&& ix.getInput(1).getName().equals(itervar) ){
158+
if( checkItervarIndexing(ix, itervar, true) ){
162159
rightScalar = true;
163160
rowIx = true;
164161
}
165-
else if( ix.isColLowerEqualsUpper() && ix.getInput(3) instanceof DataOp
166-
&& ix.getInput(3).getName().equals(itervar) ){
162+
else if( checkItervarIndexing(ix, itervar, false) ){
167163
rightScalar = true;
168164
rowIx = false;
169165
}
@@ -231,17 +227,29 @@ private static StatementBlock vectorizeScalarAggregate2( StatementBlock sb, Stat
231227
&& root.getName().equals(left.getName())
232228
&& right instanceof IndexingOp && right.isScalar())
233229
{
234-
leftScalar = true;
235-
rowIx = true; //row and col
230+
if( checkItervarIndexing((IndexingOp)right, itervar, true) ){
231+
leftScalar = true;
232+
rowIx = true;
233+
}
234+
else if( checkItervarIndexing((IndexingOp)right, itervar, false) ){
235+
leftScalar = true;
236+
rowIx = false;
237+
}
236238
}
237239
//check for right scalar plus
238240
else if( HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
239241
&& right instanceof DataOp && right.getDataType() == DataType.SCALAR
240242
&& root.getName().equals(right.getName())
241243
&& left instanceof IndexingOp && left.isScalar())
242244
{
243-
rightScalar = true;
244-
rowIx = true; //row and col
245+
if( checkItervarIndexing((IndexingOp)left, itervar, true) ){
246+
rightScalar = true;
247+
rowIx = true;
248+
}
249+
else if( checkItervarIndexing((IndexingOp)left, itervar, false) ){
250+
rightScalar = true;
251+
rowIx = false;
252+
}
245253
}
246254
}
247255
}
@@ -461,6 +469,12 @@ private static StatementBlock vectorizeIndexedCopy( StatementBlock sb, Statement
461469
return ret;
462470
}
463471

472+
private static boolean checkItervarIndexing(IndexingOp ix, String itervar, boolean row) {
473+
return ix.isRowLowerEqualsUpper()
474+
&& ix.getInput(row?1:3) instanceof DataOp
475+
&& ix.getInput(row?1:3).getName().equals(itervar);
476+
}
477+
464478
private static boolean[] checkLeftAndRightIndexing(LeftIndexingOp lix, IndexingOp rix, String itervar) {
465479
boolean[] ret = new boolean[2]; //apply, rowIx
466480

src/test/java/org/apache/sysds/test/functions/vect/AutoVectorizationTest.java

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -213,42 +213,36 @@ public void testVectorizeForLoopBinaryColNeg() {
213213
runVectorizationTest( TEST_NAME24 );
214214
}
215215

216-
/**
217-
*
218-
* @param cfc
219-
* @param vt
220-
*/
221216
private void runVectorizationTest( String testName )
222217
{
223218
String TEST_NAME = testName;
224219

225220
try
226-
{
221+
{
227222
TestConfiguration config = getTestConfiguration(TEST_NAME);
228223
loadTestConfiguration(config);
229224

230-
String HOME = SCRIPT_DIR + TEST_DIR;
225+
String HOME = SCRIPT_DIR + TEST_DIR;
231226
fullDMLScriptName = HOME + TEST_NAME + ".dml";
232227
programArgs = new String[]{"-explain","-args", input("A"), output("R") };
233228

234229
fullRScriptName = HOME + TEST_NAME + ".R";
235-
rCmd = getRCmd(inputDir(), expectedDir());
230+
rCmd = getRCmd(inputDir(), expectedDir());
236231

237232
//generate input
238233
double[][] A = getRandomMatrix(rows, cols, 0, 1, 1.0, 7);
239234
writeInputMatrixWithMTD("A", A, true);
240235

241236
//run tests
242-
runTest(true, false, null, -1);
243-
runRScript(true);
244-
245-
//compare results
246-
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
237+
runTest(true, false, null, -1);
238+
runRScript(true);
239+
240+
//compare results
241+
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
247242
HashMap<CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
248243
TestUtils.compareMatrices(dmlfile, rfile, 1e-14, "DML", "R");
249244
}
250-
catch(Exception ex)
251-
{
245+
catch(Exception ex) {
252246
throw new RuntimeException(ex);
253247
}
254248
}

0 commit comments

Comments
 (0)