@@ -323,6 +323,22 @@ mlir::LogicalResult mlir::daphne::EwAddOp::canonicalize(mlir::daphne::EwAddOp op
323323mlir::LogicalResult mlir::daphne::EwSubOp::canonicalize (mlir::daphne::EwSubOp op, PatternRewriter &rewriter) {
324324 mlir::Value lhs = op.getLhs ();
325325 mlir::Value rhs = op.getRhs ();
326+ // This will check for the fill operation on the left hand side to push down the arithmetic inside
327+ // of it
328+ mlir::daphne::FillOp lhsFill = lhs.getDefiningOp <mlir::daphne::FillOp>();
329+ if (lhsFill) {
330+ auto fillValue = lhsFill.getArg ();
331+ auto height = lhsFill.getNumRows ();
332+ auto width = lhsFill.getNumCols ();
333+ const bool rhsIsSca = CompilerUtils::isScaType (rhs.getType ());
334+ if (rhsIsSca) {
335+ mlir::daphne::EwSubOp newSub = rewriter.create <mlir::daphne::EwSubOp>(op.getLoc (), fillValue, rhs);
336+ mlir::daphne::FillOp newFill =
337+ rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newSub, width, height);
338+ rewriter.replaceOp (op, {newFill});
339+ return mlir::success ();
340+ }
341+ }
326342 const bool lhsIsSca = CompilerUtils::isScaType (lhs.getType ());
327343 const bool rhsIsSca = CompilerUtils::isScaType (rhs.getType ());
328344 if (lhsIsSca && !rhsIsSca) {
@@ -351,6 +367,23 @@ mlir::LogicalResult mlir::daphne::EwSubOp::canonicalize(mlir::daphne::EwSubOp op
351367mlir::LogicalResult mlir::daphne::EwMulOp::canonicalize (mlir::daphne::EwMulOp op, PatternRewriter &rewriter) {
352368 mlir::Value lhs = op.getLhs ();
353369 mlir::Value rhs = op.getRhs ();
370+ // This will check for the fill operation on the left hand side to push down the arithmetic inside
371+ // of it
372+ mlir::daphne::FillOp lhsFill = lhs.getDefiningOp <mlir::daphne::FillOp>();
373+ if (lhsFill) {
374+ auto fillValue = lhsFill.getArg ();
375+ auto height = lhsFill.getNumRows ();
376+ auto width = lhsFill.getNumCols ();
377+ const bool rhsIsSca = CompilerUtils::isScaType (rhs.getType ());
378+ if (rhsIsSca) {
379+ mlir::daphne::EwMulOp newMul = rewriter.create <mlir::daphne::EwMulOp>(op.getLoc (), fillValue, rhs);
380+ mlir::daphne::FillOp newFill =
381+ rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newMul , width, height);
382+ rewriter.replaceOp (op, {newFill});
383+ return mlir::success ();
384+ }
385+ }
386+
354387 const bool lhsIsSca = CompilerUtils::isScaType (lhs.getType ());
355388 const bool rhsIsSca = CompilerUtils::isScaType (rhs.getType ());
356389 if (lhsIsSca && !rhsIsSca) {
@@ -376,6 +409,22 @@ mlir::LogicalResult mlir::daphne::EwMulOp::canonicalize(mlir::daphne::EwMulOp op
376409mlir::LogicalResult mlir::daphne::EwDivOp::canonicalize (mlir::daphne::EwDivOp op, PatternRewriter &rewriter) {
377410 mlir::Value lhs = op.getLhs ();
378411 mlir::Value rhs = op.getRhs ();
412+ // This will check for the fill operation on the left hand side to push down the arithmetic inside
413+ // of it
414+ mlir::daphne::FillOp lhsFill = lhs.getDefiningOp <mlir::daphne::FillOp>();
415+ if (lhsFill) {
416+ auto fillValue = lhsFill.getArg ();
417+ auto height = lhsFill.getNumRows ();
418+ auto width = lhsFill.getNumCols ();
419+ const bool rhsIsSca = CompilerUtils::isScaType (rhs.getType ());
420+ if (rhsIsSca) {
421+ mlir::daphne::EwDivOp newDiv = rewriter.create <mlir::daphne::EwDivOp>(op.getLoc (), fillValue, rhs);
422+ mlir::daphne::FillOp newFill =
423+ rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newDiv, width, height);
424+ rewriter.replaceOp (op, {newFill});
425+ return mlir::success ();
426+ }
427+ }
379428 const bool lhsIsSca = CompilerUtils::isScaType (lhs.getType ());
380429 const bool rhsIsSca = CompilerUtils::isScaType (rhs.getType ());
381430 const bool rhsIsFP = llvm::isa<mlir::FloatType>(CompilerUtils::getValueType (rhs.getType ()));
0 commit comments