@@ -381,30 +381,28 @@ private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos)
381381
382382 return hi ;
383383 }
384-
385- private static Hop removeUnnecessaryReorgOperation (Hop parent , Hop hi , int pos )
386- {
387- if ( hi instanceof ReorgOp )
388- {
384+
385+ private static Hop removeUnnecessaryReorgOperation (Hop parent , Hop hi , int pos ) {
386+ if ( hi instanceof ReorgOp ) {
389387 ReorgOp rop = (ReorgOp ) hi ;
390- Hop input = hi .getInput (0 );
388+ Hop input = hi .getInput (0 );
391389 boolean apply = false ;
392-
393- //equal dims of reshape input and output -> no need for reshape because
390+
391+ //equal dims of reshape input and output -> no need for reshape because
394392 //byrow always refers to both input/output and hence gives the same result
395393 apply |= (rop .getOp ()==ReOrgOp .RESHAPE && HopRewriteUtils .isEqualSize (hi , input ));
396-
397- //1x1 dimensions of transpose/reshape -> no need for reorg
398- apply |= ((rop .getOp ()==ReOrgOp .TRANS || rop .getOp ()==ReOrgOp .RESHAPE )
399- && rop .getDim1 ()==1 && rop .getDim2 ()==1 );
400-
394+
395+ //1x1 dimensions of transpose/reshape/roll -> no need for reorg
396+ apply |= ((rop .getOp ()==ReOrgOp .TRANS || rop .getOp ()==ReOrgOp .RESHAPE
397+ || rop . getOp ()== ReOrgOp . ROLL ) && rop .getDim1 ()==1 && rop .getDim2 ()==1 );
398+
401399 if ( apply ) {
402400 HopRewriteUtils .replaceChildReference (parent , hi , input , pos );
403401 hi = input ;
404402 LOG .debug ("Applied removeUnnecessaryReorg." );
405403 }
406404 }
407-
405+
408406 return hi ;
409407 }
410408
@@ -1356,44 +1354,78 @@ else if ( applyRight ) {
13561354 * @param pos position
13571355 * @return high-level operator
13581356 */
1359- private static Hop pushdownSumOnAdditiveBinary (Hop parent , Hop hi , int pos )
1357+ private static Hop pushdownSumOnAdditiveBinary (Hop parent , Hop hi , int pos )
13601358 {
13611359 //all patterns headed by full sum over binary operation
13621360 if ( hi instanceof AggUnaryOp //full sum root over binaryop
1363- && ((AggUnaryOp )hi ).getDirection ()==Direction .RowCol
1364- && ((AggUnaryOp )hi ).getOp () == AggOp .SUM
1365- && hi .getInput (0 ) instanceof BinaryOp
1366- && hi .getInput (0 ).getParent ().size ()==1 ) //single parent
1361+ && ((AggUnaryOp )hi ).getDirection ()==Direction .RowCol
1362+ && ((AggUnaryOp )hi ).getOp () == AggOp .SUM
1363+ && hi .getInput (0 ) instanceof BinaryOp
1364+ && hi .getInput (0 ).getParent ().size ()==1 ) //single parent
13671365 {
13681366 BinaryOp bop = (BinaryOp ) hi .getInput (0 );
13691367 Hop left = bop .getInput (0 );
13701368 Hop right = bop .getInput (1 );
1371-
1372- if ( HopRewriteUtils .isEqualSize (left , right ) //dims(A) == dims(B)
1373- && left .getDataType () == DataType .MATRIX
1374- && right .getDataType () == DataType .MATRIX )
1369+
1370+ if ( left .getDataType () == DataType .MATRIX
1371+ && right .getDataType () == DataType .MATRIX )
13751372 {
13761373 OpOp2 applyOp = ( bop .getOp () == OpOp2 .PLUS //pattern a: sum(A+B)->sum(A)+sum(B)
13771374 || bop .getOp () == OpOp2 .MINUS ) //pattern b: sum(A-B)->sum(A)-sum(B)
13781375 ? bop .getOp () : null ;
1379-
1376+
13801377 if ( applyOp != null ) {
1381- //create new subdag sum(A) bop sum(B)
1382- AggUnaryOp sum1 = HopRewriteUtils .createSum (left );
1383- AggUnaryOp sum2 = HopRewriteUtils .createSum (right );
1384- BinaryOp newBin = HopRewriteUtils .createBinary (sum1 , sum2 , applyOp );
1385-
1386- //rewire new subdag
1387- HopRewriteUtils .replaceChildReference (parent , hi , newBin , pos );
1388- HopRewriteUtils .cleanupUnreferenced (hi , bop );
1389-
1390- hi = newBin ;
1391-
1392- LOG .debug ("Applied pushdownSumOnAdditiveBinary (line " +hi .getBeginLine ()+")." );
1378+ if (HopRewriteUtils .isEqualSize (left , right )) {
1379+ //create new subdag sum(A) bop sum(B) for equal-sized matrices
1380+ AggUnaryOp sum1 = HopRewriteUtils .createSum (left );
1381+ AggUnaryOp sum2 = HopRewriteUtils .createSum (right );
1382+ BinaryOp newBin = HopRewriteUtils .createBinary (sum1 , sum2 , applyOp );
1383+ //rewire new subdag
1384+ HopRewriteUtils .replaceChildReference (parent , hi , newBin , pos );
1385+ HopRewriteUtils .cleanupUnreferenced (hi , bop );
1386+
1387+ hi = newBin ;
1388+
1389+ LOG .debug ("Applied pushdownSumOnAdditiveBinary (line " +hi .getBeginLine ()+")." );
1390+ }
1391+ // Check if right operand is a vector (has dimension of 1 in either rows or columns)
1392+ else if (right .getDim1 () == 1 || right .getDim2 () == 1 ) {
1393+ AggUnaryOp sum1 = HopRewriteUtils .createSum (left );
1394+ AggUnaryOp sum2 = HopRewriteUtils .createSum (right );
1395+
1396+ // Row vector case (1 x n)
1397+ if (right .getDim1 () == 1 ) {
1398+ // Create nrow(A) operation using dimensions
1399+ UnaryOp nRows = HopRewriteUtils .createUnary (left , OpOp1 .NROW );
1400+ BinaryOp scaledSum = HopRewriteUtils .createBinary (nRows , sum2 , OpOp2 .MULT );
1401+ BinaryOp newBin = HopRewriteUtils .createBinary (sum1 , scaledSum , applyOp );
1402+ //rewire new subdag
1403+ HopRewriteUtils .replaceChildReference (parent , hi , newBin , pos );
1404+ HopRewriteUtils .cleanupUnreferenced (hi , bop );
1405+
1406+ hi = newBin ;
1407+
1408+ LOG .debug ("Applied pushdownSumOnAdditiveBinary with row vector (line " +hi .getBeginLine ()+")." );
1409+ }
1410+ // Column vector case (n x 1)
1411+ else if (right .getDim2 () == 1 ) {
1412+ // Create ncol(A) operation using dimensions
1413+ UnaryOp nCols = HopRewriteUtils .createUnary (left , OpOp1 .NCOL );
1414+ BinaryOp scaledSum = HopRewriteUtils .createBinary (nCols , sum2 , OpOp2 .MULT );
1415+ BinaryOp newBin = HopRewriteUtils .createBinary (sum1 , scaledSum , applyOp );
1416+ //rewire new subdag
1417+ HopRewriteUtils .replaceChildReference (parent , hi , newBin , pos );
1418+ HopRewriteUtils .cleanupUnreferenced (hi , bop );
1419+
1420+ hi = newBin ;
1421+
1422+ LOG .debug ("Applied pushdownSumOnAdditiveBinary with column vector (line " +hi .getBeginLine ()+")." );
1423+ }
1424+ }
13931425 }
13941426 }
13951427 }
1396-
1428+
13971429 return hi ;
13981430 }
13991431
0 commit comments