@@ -200,7 +200,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
200200 hi = simplifyCumsumColOrFullAggregates (hi ); //e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
201201 hi = simplifyCumsumReverse (hop , hi , i ); //e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
202202 hi = simplifyNegatedSubtraction (hop , hi , i ); //e.g., -(B-A)->A-B
203- hi = simplifyTransposeAddition (hop , hi , i ); //e.g., t(A+1)+2 -> t(A)+1+2 -> t(A)+3
203+ hi = simplifyTransposeAddition (hop , hi , i ); //e.g., t(A+s1)+s2 -> t(A)+(s1+s2) + potential constant folding
204204 hi = simplifyNotOverComparisons (hop , hi , i ); //e.g., !(A>B) -> (A<=B)
205205 //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
206206
@@ -213,95 +213,46 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
213213 }
214214
215215 private static Hop simplifyTransposeAddition (Hop parent , Hop hi , int pos ) {
216- if (!(hi instanceof BinaryOp )
217- || ((BinaryOp )hi ).getOp () != OpOp2 .PLUS
218- || hi .getDataType () != DataType .MATRIX )
219- return hi ;
220-
221- BinaryOp bop = (BinaryOp )hi ;
222-
223- ReorgOp tSide = null ;
224- LiteralOp litSide = null ;
225- Hop in0 = bop .getInput ().get (0 ), in1 = bop .getInput ().get (1 );
226- if (in0 instanceof ReorgOp && ((ReorgOp )in0 ).getOp () == ReOrgOp .TRANS
227- && in1 instanceof LiteralOp ) {
228- tSide = (ReorgOp )in0 ;
229- litSide = (LiteralOp )in1 ;
230- }
231- else if (in1 instanceof ReorgOp && ((ReorgOp )in1 ).getOp () == ReOrgOp .TRANS
232- && in0 instanceof LiteralOp ) {
233- tSide = (ReorgOp )in1 ;
234- litSide = (LiteralOp )in0 ;
235- }
236- else
237- return hi ;
238-
239- //check if only consumer
240- if (tSide .getParent ().size () > 1 ) {
241- return hi ;
242- }
243-
244- Hop inner = tSide .getInput ().get (0 );
245- if (!(inner instanceof BinaryOp )
246- || ((BinaryOp )inner ).getOp () != OpOp2 .PLUS
247- || inner .getDataType () != DataType .MATRIX )
248- return hi ;
249-
250- BinaryOp ib = (BinaryOp )inner ;
251-
252- Hop X = null ;
253- LiteralOp lit1 = null ;
254- Hop i0 = ib .getInput ().get (0 ), i1 = ib .getInput ().get (1 );
255- if (i0 instanceof LiteralOp ) {
256- lit1 = (LiteralOp )i0 ;
257- X = i1 ;
258- }
259- else if (i1 instanceof LiteralOp ) {
260- lit1 = (LiteralOp )i1 ;
261- X = i0 ;
216+ //pattern: t(A+s1)+s2 -> t(A)+(s1+s2), and subsequent constant folding
217+ if (HopRewriteUtils .isBinary (hi , OpOp2 .PLUS )
218+ && hi .isMatrix () && hi .getInput (1 ).isScalar ()
219+ && HopRewriteUtils .isReorg (hi .getInput (0 ), ReOrgOp .TRANS )
220+ && hi .getInput (0 ).getParent ().size () == 1
221+ && HopRewriteUtils .isBinary (hi .getInput (0 ).getInput (0 ), OpOp2 .PLUS )
222+ && hi .getInput (0 ).getInput (0 ).getParent ().size () == 1
223+ && (hi .getInput (0 ).getInput (0 ).getInput (0 ).isScalar ()
224+ || hi .getInput (0 ).getInput (0 ).getInput (1 ).isScalar ()))
225+ {
226+ int six = hi .getInput (0 ).getInput (0 ).getInput (0 ).isScalar () ? 0 : 1 ;
227+ Hop A = hi .getInput (0 ).getInput (0 ).getInput (six ==0 ? 1 : 0 );
228+ Hop s1 = hi .getInput (0 ).getInput (0 ).getInput (six );
229+ Hop s2 = hi .getInput (1 );
230+
231+ Hop tA = HopRewriteUtils .createTranspose (A );
232+ Hop s12 = HopRewriteUtils .createBinary (s1 , s2 , OpOp2 .PLUS );
233+ Hop newHop = HopRewriteUtils .createBinary (tA , s12 , OpOp2 .PLUS );
234+
235+ HopRewriteUtils .replaceChildReference (parent , hi , newHop , pos );
236+ HopRewriteUtils .cleanupUnreferenced (hi );
237+ hi = newHop ;
238+
239+ LOG .debug ("Applied simplifyTransposeAddition (line " + hi .getBeginLine () + ")." );
262240 }
263- else
264- return hi ;
265-
266- double c = lit1 .getDoubleValue () + litSide .getDoubleValue ();
267-
268- ReorgOp newT = HopRewriteUtils .createTranspose (X );
269- newT .setDim1 (tSide .getDim1 ());
270- newT .setDim2 (tSide .getDim2 ());
271-
272- LiteralOp newLit = new LiteralOp (c );
273- newLit .setDim1 (1 );
274- newLit .setDim2 (1 );
275-
276- //creating new binaryOp
277- BinaryOp newPlus = HopRewriteUtils .createBinary (newT , newLit , OpOp2 .PLUS );
278- newPlus .setDim1 (bop .getDim1 ());
279- newPlus .setDim2 (bop .getDim2 ());
280-
281- HopRewriteUtils .replaceChildReference (parent , bop , newPlus , pos );
282- HopRewriteUtils .cleanupUnreferenced (bop , tSide , ib , litSide );
283-
284- LOG .debug ("Applied simplifyTransposeAddition (line " + hi .getBeginLine () + ")." );
285-
286- return newPlus ;
241+
242+ return hi ;
287243 }
288244
289245 private static Hop simplifyNegatedSubtraction (Hop parent , Hop hi , int pos ) {
290- if (hi instanceof BinaryOp
291- && ((BinaryOp ) hi ).getOp () == OpOp2 .MINUS
292- && HopRewriteUtils .isLiteralOfValue (hi .getInput ().get (0 ), 0 )
293- && hi .getParent ().size () == 1
294- && hi .getInput ().get (1 ) instanceof BinaryOp
295- && ((BinaryOp ) hi .getInput ().get (1 )).getOp () == OpOp2 .MINUS
296- && hi .getInput ().get (1 ).getParent ().size () == 1 )
246+ //pattern: -(B-A)->A-B, but only of (B-A) consumed once
247+ if (HopRewriteUtils .isBinary (hi , OpOp2 .MINUS )
248+ && HopRewriteUtils .isLiteralOfValue (hi .getInput (0 ), 0 )
249+ && HopRewriteUtils .isBinary (hi .getInput (1 ), OpOp2 .MINUS )
250+ && hi .getInput ().get (1 ).getParent ().size () == 1 )
297251 {
298- Hop innerMinus = hi .getInput ().get (1 );
299- Hop B = innerMinus .getInput ().get (0 );
300- Hop A = innerMinus .getInput ().get (1 );
252+ Hop B = hi .getInput (1 ).getInput (0 );
253+ Hop A = hi .getInput (1 ).getInput (1 );
301254
302255 BinaryOp newHop = HopRewriteUtils .createBinary (A , B , OpOp2 .MINUS );
303-
304- HopRewriteUtils .copyLineNumbers (hi , newHop );
305256 HopRewriteUtils .replaceChildReference (parent , hi , newHop , pos );
306257 HopRewriteUtils .cleanupUnreferenced (hi );
307258 hi = newHop ;
0 commit comments