@@ -88,6 +88,9 @@ public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewr
8888 //e.g., for(i in a:b){s = s + as.scalar(X[i,2])} -> s = sum(X[a:b,2])
8989 sb = vectorizeScalarAggregate (sb , csb , from , to , incr , iterVar );
9090
91+ //e.g., for(i in a:b){s = s + X[i,2]} -> s = sum(X[a:b,2])
92+ sb = vectorizeScalarAggregate2 (sb , csb , from , to , incr , iterVar );
93+
9194 //e.g., for(i in a:b){X[i,2] = Y[i,1] + Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3];
9295 sb = vectorizeElementwiseBinary (sb , csb , from , to , incr , iterVar );
9396
@@ -205,6 +208,80 @@ else if( ix.isColLowerEqualsUpper() && ix.getInput(3) instanceof DataOp
205208 return ret ;
206209 }
207210
211+ private static StatementBlock vectorizeScalarAggregate2 ( StatementBlock sb , StatementBlock csb , Hop from , Hop to , Hop increment , String itervar )
212+ {
213+ StatementBlock ret = sb ;
214+
215+ //check for applicability
216+ boolean leftScalar = false ;
217+ boolean rightScalar = false ;
218+ boolean rowIx = false ; //row or col
219+
220+ if ( csb .getHops ()!=null && csb .getHops ().size ()==1 ) {
221+ Hop root = csb .getHops ().get (0 );
222+
223+ if ( root .getDataType ()==DataType .SCALAR && root .getInput (0 ) instanceof BinaryOp ) {
224+ BinaryOp bop = (BinaryOp ) root .getInput (0 );
225+ Hop left = bop .getInput (0 );
226+ Hop right = bop .getInput (1 );
227+
228+ //check for left scalar plus
229+ if ( HopRewriteUtils .isValidOp (bop .getOp (), MAP_SCALAR_AGGREGATE_SOURCE_OPS )
230+ && left instanceof DataOp && left .getDataType () == DataType .SCALAR
231+ && root .getName ().equals (left .getName ())
232+ && right instanceof IndexingOp && right .isScalar ())
233+ {
234+ leftScalar = true ;
235+ rowIx = true ; //row and col
236+ }
237+ //check for right scalar plus
238+ else if ( HopRewriteUtils .isValidOp (bop .getOp (), MAP_SCALAR_AGGREGATE_SOURCE_OPS )
239+ && right instanceof DataOp && right .getDataType () == DataType .SCALAR
240+ && root .getName ().equals (right .getName ())
241+ && left instanceof IndexingOp && left .isScalar ())
242+ {
243+ rightScalar = true ;
244+ rowIx = true ; //row and col
245+ }
246+ }
247+ }
248+
249+ //apply rewrite if possible
250+ if ( leftScalar || rightScalar ) {
251+ Hop root = csb .getHops ().get (0 );
252+ BinaryOp bop = (BinaryOp ) root .getInput (0 );
253+ Hop ix = bop .getInput ().get ( leftScalar ?1 :0 );
254+ int aggOpPos = HopRewriteUtils .getValidOpPos (bop .getOp (), MAP_SCALAR_AGGREGATE_SOURCE_OPS );
255+ AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS [aggOpPos ];
256+
257+ //replace cast with sum
258+ AggUnaryOp newSum = HopRewriteUtils .createAggUnaryOp (ix , aggOp , Direction .RowCol );
259+ HopRewriteUtils .removeChildReference (bop , ix );
260+ HopRewriteUtils .addChildReference (bop , newSum , leftScalar ?1 :0 );
261+
262+ //modify indexing expression according to loop predicate from-to
263+ //NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
264+ int index1 = rowIx ? 1 : 3 ;
265+ int index2 = rowIx ? 2 : 4 ;
266+ HopRewriteUtils .replaceChildReference (ix , ix .getInput ().get (index1 ), from , index1 );
267+ HopRewriteUtils .replaceChildReference (ix , ix .getInput ().get (index2 ), to , index2 );
268+
269+ //update indexing size information
270+ if ( rowIx )
271+ ((IndexingOp )ix ).setRowLowerEqualsUpper (false );
272+ else
273+ ((IndexingOp )ix ).setColLowerEqualsUpper (false );
274+ ix .setDataType (DataType .MATRIX );
275+ ix .refreshSizeInformation ();
276+ Hop .resetVisitStatus (csb .getHops (), true );
277+
278+ ret = csb ;
279+ LOG .debug ("Applied vectorizeScalarSumForLoop2." );
280+ }
281+
282+ return ret ;
283+ }
284+
208285 private static StatementBlock vectorizeElementwiseBinary ( StatementBlock sb , StatementBlock csb , Hop from , Hop to , Hop increment , String itervar )
209286 {
210287 StatementBlock ret = sb ;
0 commit comments