@@ -307,6 +307,71 @@ mlir::daphne::RandMatrixOp pushDownRandomIntoEwAdd(mlir::daphne::RandMatrixOp ra
307307 return rewriter.create <mlir::daphne::RandMatrixOp>(op.getLoc (), op.getResult ().getType (), width, height, newMin,
308308 newMax, sparsity, seed);
309309}
310+ mlir::daphne::RandMatrixOp pushDownRandomIntoEwSub (mlir::daphne::RandMatrixOp randOp, mlir::daphne::EwSubOp op,
311+ mlir::Value scalar, mlir::PatternRewriter &rewriter) {
312+ auto max = randOp.getMax ();
313+ auto min = randOp.getMin ();
314+ auto height = randOp.getNumRows ();
315+ auto width = randOp.getNumCols ();
316+ auto sparsity = randOp.getSparsity ();
317+ auto seed = randOp.getSeed ();
318+ mlir::daphne::EwSubOp newMax = rewriter.create <mlir::daphne::EwSubOp>(op.getLoc (), max, scalar);
319+ mlir::daphne::EwSubOp newMin = rewriter.create <mlir::daphne::EwSubOp>(op.getLoc (), min, scalar);
320+ return rewriter.create <mlir::daphne::RandMatrixOp>(op.getLoc (), op.getResult ().getType (), width, height, newMin,
321+ newMax, sparsity, seed);
322+ }
323+ mlir::daphne::RandMatrixOp pushDownRandomIntoEwMul (mlir::daphne::RandMatrixOp randOp, mlir::daphne::EwMulOp op,
324+ mlir::Value scalar, mlir::PatternRewriter &rewriter) {
325+ auto max = randOp.getMax ();
326+ auto min = randOp.getMin ();
327+ auto height = randOp.getNumRows ();
328+ auto width = randOp.getNumCols ();
329+ auto sparsity = randOp.getSparsity ();
330+ auto seed = randOp.getSeed ();
331+ mlir::daphne::EwMulOp newMax = rewriter.create <mlir::daphne::EwMulOp>(op.getLoc (), max, scalar);
332+ mlir::daphne::EwMulOp newMin = rewriter.create <mlir::daphne::EwMulOp>(op.getLoc (), min, scalar);
333+ return rewriter.create <mlir::daphne::RandMatrixOp>(op.getLoc (), op.getResult ().getType (), width, height, newMin,
334+ newMax, sparsity, seed);
335+ }
336+ mlir::daphne::RandMatrixOp pushDownRandomIntoEwDiv (mlir::daphne::RandMatrixOp randOp, mlir::daphne::EwDivOp op,
337+ mlir::Value scalar, mlir::PatternRewriter &rewriter) {
338+ auto max = randOp.getMax ();
339+ auto min = randOp.getMin ();
340+ auto height = randOp.getNumRows ();
341+ auto width = randOp.getNumCols ();
342+ auto sparsity = randOp.getSparsity ();
343+ auto seed = randOp.getSeed ();
344+ mlir::daphne::EwDivOp newMax = rewriter.create <mlir::daphne::EwDivOp>(op.getLoc (), max, scalar);
345+ mlir::daphne::EwDivOp newMin = rewriter.create <mlir::daphne::EwDivOp>(op.getLoc (), min, scalar);
346+ return rewriter.create <mlir::daphne::RandMatrixOp>(op.getLoc (), op.getResult ().getType (), width, height, newMin,
347+ newMax, sparsity, seed);
348+ }
349+ mlir::daphne::RandMatrixOp pushDownRandomIntoEwPow (mlir::daphne::RandMatrixOp randOp, mlir::daphne::EwPowOp op,
350+ mlir::Value scalar, mlir::PatternRewriter &rewriter) {
351+ auto max = randOp.getMax ();
352+ auto min = randOp.getMin ();
353+ auto height = randOp.getNumRows ();
354+ auto width = randOp.getNumCols ();
355+ auto sparsity = randOp.getSparsity ();
356+ auto seed = randOp.getSeed ();
357+ mlir::daphne::EwPowOp newMax = rewriter.create <mlir::daphne::EwPowOp>(op.getLoc (), max, scalar);
358+ mlir::daphne::EwPowOp newMin = rewriter.create <mlir::daphne::EwPowOp>(op.getLoc (), min, scalar);
359+ return rewriter.create <mlir::daphne::RandMatrixOp>(op.getLoc (), op.getResult ().getType (), width, height, newMin,
360+ newMax, sparsity, seed);
361+ }
362+ mlir::daphne::RandMatrixOp pushDownRandomIntoEwLog (mlir::daphne::RandMatrixOp randOp, mlir::daphne::EwLogOp op,
363+ mlir::Value scalar, mlir::PatternRewriter &rewriter) {
364+ auto max = randOp.getMax ();
365+ auto min = randOp.getMin ();
366+ auto height = randOp.getNumRows ();
367+ auto width = randOp.getNumCols ();
368+ auto sparsity = randOp.getSparsity ();
369+ auto seed = randOp.getSeed ();
370+ mlir::daphne::EwLogOp newMax = rewriter.create <mlir::daphne::EwLogOp>(op.getLoc (), max, scalar);
371+ mlir::daphne::EwLogOp newMin = rewriter.create <mlir::daphne::EwLogOp>(op.getLoc (), min, scalar);
372+ return rewriter.create <mlir::daphne::RandMatrixOp>(op.getLoc (), op.getResult ().getType (), width, height, newMin,
373+ newMax, sparsity, seed);
374+ }
310375
311376/* *
312377 * @brief Replaces (1) `a + b` by `a concat b`, if `a` or `b` is a string,
@@ -431,6 +496,20 @@ mlir::LogicalResult mlir::daphne::EwSubOp::canonicalize(mlir::daphne::EwSubOp op
431496 rewriter.replaceOp (op, {newFill});
432497 return mlir::success ();
433498 }
499+ // This will check for the rand operation to push down the arithmetic inside
500+ // of it
501+ mlir::daphne::RandMatrixOp lhsRand = lhs.getDefiningOp <mlir::daphne::RandMatrixOp>();
502+ mlir::daphne::RandMatrixOp rhsRand = rhs.getDefiningOp <mlir::daphne::RandMatrixOp>();
503+ if (lhsRand && rhsIsSca) {
504+ auto newRand = pushDownRandomIntoEwSub (lhsRand, op, rhs, rewriter);
505+ rewriter.replaceOp (op, {newRand});
506+ return mlir::success ();
507+ }
508+ if (rhsRand && lhsIsSca) {
509+ auto newRand = pushDownRandomIntoEwSub (rhsRand, op, lhs, rewriter);
510+ rewriter.replaceOp (op, {newRand});
511+ return mlir::success ();
512+ }
434513
435514 if (lhsIsSca && !rhsIsSca) {
436515 rewriter.replaceOpWithNewOp <mlir::daphne::EwAddOp>(
@@ -475,6 +554,21 @@ mlir::LogicalResult mlir::daphne::EwMulOp::canonicalize(mlir::daphne::EwMulOp op
475554 return mlir::success ();
476555 }
477556
557+ // This will check for the rand operation to push down the arithmetic inside
558+ // of it
559+ mlir::daphne::RandMatrixOp lhsRand = lhs.getDefiningOp <mlir::daphne::RandMatrixOp>();
560+ mlir::daphne::RandMatrixOp rhsRand = rhs.getDefiningOp <mlir::daphne::RandMatrixOp>();
561+ if (lhsRand && rhsIsSca) {
562+ auto newRand = pushDownRandomIntoEwMul (lhsRand, op, rhs, rewriter);
563+ rewriter.replaceOp (op, {newRand});
564+ return mlir::success ();
565+ }
566+ if (rhsRand && lhsIsSca) {
567+ auto newRand = pushDownRandomIntoEwMul (rhsRand, op, lhs, rewriter);
568+ rewriter.replaceOp (op, {newRand});
569+ return mlir::success ();
570+ }
571+
478572 if (lhsIsSca && !rhsIsSca) {
479573 rewriter.replaceOpWithNewOp <mlir::daphne::EwMulOp>(op, op.getResult ().getType (), rhs, lhs);
480574 return mlir::success ();
@@ -515,6 +609,21 @@ mlir::LogicalResult mlir::daphne::EwDivOp::canonicalize(mlir::daphne::EwDivOp op
515609 return mlir::success ();
516610 }
517611
612+ // This will check for the rand operation to push down the arithmetic inside
613+ // of it
614+ mlir::daphne::RandMatrixOp lhsRand = lhs.getDefiningOp <mlir::daphne::RandMatrixOp>();
615+ mlir::daphne::RandMatrixOp rhsRand = rhs.getDefiningOp <mlir::daphne::RandMatrixOp>();
616+ if (lhsRand && rhsIsSca) {
617+ auto newRand = pushDownRandomIntoEwDiv (lhsRand, op, rhs, rewriter);
618+ rewriter.replaceOp (op, {newRand});
619+ return mlir::success ();
620+ }
621+ if (rhsRand && lhsIsSca) {
622+ auto newRand = pushDownRandomIntoEwDiv (rhsRand, op, lhs, rewriter);
623+ rewriter.replaceOp (op, {newRand});
624+ return mlir::success ();
625+ }
626+
518627 const bool rhsIsFP = llvm::isa<mlir::FloatType>(CompilerUtils::getValueType (rhs.getType ()));
519628 if (lhsIsSca && !rhsIsSca && rhsIsFP) {
520629 rewriter.replaceOpWithNewOp <mlir::daphne::EwMulOp>(
@@ -594,6 +703,20 @@ mlir::LogicalResult mlir::daphne::EwPowOp::canonicalize(mlir::daphne::EwPowOp op
594703 rewriter.replaceOp (op, {newFill});
595704 return mlir::success ();
596705 }
706+ // This will check for the rand operation to push down the arithmetic inside
707+ // of it
708+ mlir::daphne::RandMatrixOp lhsRand = lhs.getDefiningOp <mlir::daphne::RandMatrixOp>();
709+ mlir::daphne::RandMatrixOp rhsRand = rhs.getDefiningOp <mlir::daphne::RandMatrixOp>();
710+ if (lhsRand && rhsIsSca) {
711+ auto newRand = pushDownRandomIntoEwPow (lhsRand, op, rhs, rewriter);
712+ rewriter.replaceOp (op, {newRand});
713+ return mlir::success ();
714+ }
715+ if (rhsRand && lhsIsSca) {
716+ auto newRand = pushDownRandomIntoEwPow (rhsRand, op, lhs, rewriter);
717+ rewriter.replaceOp (op, {newRand});
718+ return mlir::success ();
719+ }
597720 return mlir::failure ();
598721}
599722
0 commit comments