Skip to content

Commit cef8a6f

Browse files
committed
[SYSTEMDS-3798,3807] Improved loop vectorization rewrite, code coverage
1 parent 0de6fab commit cef8a6f

File tree

4 files changed

+79
-4
lines changed

4 files changed

+79
-4
lines changed

.github/workflows/javaTests.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ jobs:
8686
"**.functions.transform.**","**.functions.unique.**",
8787
"**.functions.unary.matrix.**,**.functions.linearization.**,**.functions.jmlc.**"
8888
]
89-
java: [11]
89+
java: ['11']
90+
javadist: ['adopt']
9091
name: ${{ matrix.tests }}
9192
steps:
9293
- name: Checkout Repository

src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewr
8888
//e.g., for(i in a:b){s = s + as.scalar(X[i,2])} -> s = sum(X[a:b,2])
8989
sb = vectorizeScalarAggregate(sb, csb, from, to, incr, iterVar);
9090

91+
//e.g., for(i in a:b){s = s + X[i,2]} -> s = sum(X[a:b,2])
92+
sb = vectorizeScalarAggregate2(sb, csb, from, to, incr, iterVar);
93+
9194
//e.g., for(i in a:b){X[i,2] = Y[i,1] + Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3];
9295
sb = vectorizeElementwiseBinary(sb, csb, from, to, incr, iterVar);
9396

@@ -205,6 +208,80 @@ else if( ix.isColLowerEqualsUpper() && ix.getInput(3) instanceof DataOp
205208
return ret;
206209
}
207210

211+
private static StatementBlock vectorizeScalarAggregate2( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
212+
{
213+
StatementBlock ret = sb;
214+
215+
//check for applicability
216+
boolean leftScalar = false;
217+
boolean rightScalar = false;
218+
boolean rowIx = false; //row or col
219+
220+
if( csb.getHops()!=null && csb.getHops().size()==1 ) {
221+
Hop root = csb.getHops().get(0);
222+
223+
if( root.getDataType()==DataType.SCALAR && root.getInput(0) instanceof BinaryOp ) {
224+
BinaryOp bop = (BinaryOp) root.getInput(0);
225+
Hop left = bop.getInput(0);
226+
Hop right = bop.getInput(1);
227+
228+
//check for left scalar plus
229+
if( HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
230+
&& left instanceof DataOp && left.getDataType() == DataType.SCALAR
231+
&& root.getName().equals(left.getName())
232+
&& right instanceof IndexingOp && right.isScalar())
233+
{
234+
leftScalar = true;
235+
rowIx = true; //row and col
236+
}
237+
//check for right scalar plus
238+
else if( HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
239+
&& right instanceof DataOp && right.getDataType() == DataType.SCALAR
240+
&& root.getName().equals(right.getName())
241+
&& left instanceof IndexingOp && left.isScalar())
242+
{
243+
rightScalar = true;
244+
rowIx = true; //row and col
245+
}
246+
}
247+
}
248+
249+
//apply rewrite if possible
250+
if( leftScalar || rightScalar ) {
251+
Hop root = csb.getHops().get(0);
252+
BinaryOp bop = (BinaryOp) root.getInput(0);
253+
Hop ix = bop.getInput().get( leftScalar?1:0 );
254+
int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
255+
AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
256+
257+
//replace cast with sum
258+
AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
259+
HopRewriteUtils.removeChildReference(bop, ix);
260+
HopRewriteUtils.addChildReference(bop, newSum, leftScalar?1:0 );
261+
262+
//modify indexing expression according to loop predicate from-to
263+
//NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
264+
int index1 = rowIx ? 1 : 3;
265+
int index2 = rowIx ? 2 : 4;
266+
HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index1), from, index1);
267+
HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index2), to, index2);
268+
269+
//update indexing size information
270+
if( rowIx )
271+
((IndexingOp)ix).setRowLowerEqualsUpper(false);
272+
else
273+
((IndexingOp)ix).setColLowerEqualsUpper(false);
274+
ix.setDataType(DataType.MATRIX);
275+
ix.refreshSizeInformation();
276+
Hop.resetVisitStatus(csb.getHops(), true);
277+
278+
ret = csb;
279+
LOG.debug("Applied vectorizeScalarSumForLoop2.");
280+
}
281+
282+
return ret;
283+
}
284+
208285
private static StatementBlock vectorizeElementwiseBinary( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
209286
{
210287
StatementBlock ret = sb;

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
import org.junit.Assert;
2525
import org.junit.Test;
26-
import org.apache.sysds.api.DMLScript;
2726
import org.apache.sysds.common.Types.ExecMode;
2827
import org.apache.sysds.hops.OptimizerUtils;
2928
import org.apache.sysds.common.Types.ExecType;

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.HashMap;
2323

2424
import org.junit.Assert;
25-
import org.junit.Ignore;
2625
import org.junit.Test;
2726
import org.apache.sysds.hops.OptimizerUtils;
2827
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -58,7 +57,6 @@ public void testLoopVectorizationSumNoRewrite() {
5857
}
5958

6059
@Test
61-
@Ignore //FIXME: extend loop vectorization rewrite
6260
public void testLoopVectorizationSumRewrite() {
6361
testRewriteLoopVectorizationSum( TEST_NAME1, true );
6462
}

0 commit comments

Comments
 (0)