@@ -249,6 +249,7 @@ mlir::daphne::FillOp pushDownFillIntoEwSub(mlir::daphne::FillOp fillOp, mlir::da
249249 mlir::daphne::EwSubOp newSub = rewriter.create <mlir::daphne::EwSubOp>(op.getLoc (), fillValue, scalar);
250250 return rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newSub, width, height);
251251}
252+
252253mlir::daphne::FillOp pushDownFillIntoEwMul (mlir::daphne::FillOp fillOp, mlir::daphne::EwMulOp op, mlir::Value scalar,
253254 mlir::PatternRewriter &rewriter) {
254255 auto fillValue = fillOp.getArg ();
@@ -257,6 +258,7 @@ mlir::daphne::FillOp pushDownFillIntoEwMul(mlir::daphne::FillOp fillOp, mlir::da
257258 mlir::daphne::EwMulOp newMul = rewriter.create <mlir::daphne::EwMulOp>(op.getLoc (), fillValue, scalar);
258259 return rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newMul, width, height);
259260}
261+
260262mlir::daphne::FillOp pushDownFillIntoEwDiv (mlir::daphne::FillOp fillOp, mlir::daphne::EwDivOp op, mlir::Value scalar,
261263 mlir::PatternRewriter &rewriter) {
262264 auto fillValue = fillOp.getArg ();
@@ -265,6 +267,7 @@ mlir::daphne::FillOp pushDownFillIntoEwDiv(mlir::daphne::FillOp fillOp, mlir::da
265267 mlir::daphne::EwDivOp newDiv = rewriter.create <mlir::daphne::EwDivOp>(op.getLoc (), fillValue, scalar);
266268 return rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newDiv, width, height);
267269}
270+
268271mlir::daphne::FillOp pushDownFillIntoEwPow (mlir::daphne::FillOp fillOp, mlir::daphne::EwPowOp op, mlir::Value scalar,
269272 mlir::PatternRewriter &rewriter) {
270273 auto fillValue = fillOp.getArg ();
@@ -273,6 +276,7 @@ mlir::daphne::FillOp pushDownFillIntoEwPow(mlir::daphne::FillOp fillOp, mlir::da
273276 mlir::daphne::EwPowOp newPow = rewriter.create <mlir::daphne::EwPowOp>(op.getLoc (), fillValue, scalar);
274277 return rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newPow, width, height);
275278}
279+ // AMLS_TODO: push down naming needs to be other way around
276280mlir::daphne::FillOp pushDownFillIntoEwMod (mlir::daphne::FillOp fillOp, mlir::daphne::EwModOp op, mlir::Value scalar,
277281 mlir::PatternRewriter &rewriter) {
278282 auto fillValue = fillOp.getArg ();
@@ -281,14 +285,17 @@ mlir::daphne::FillOp pushDownFillIntoEwMod(mlir::daphne::FillOp fillOp, mlir::da
281285 mlir::daphne::EwModOp newMod = rewriter.create <mlir::daphne::EwModOp>(op.getLoc (), fillValue, scalar);
282286 return rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newMod, width, height);
283287}
288+
284289mlir::daphne::FillOp pushDownFillIntoEwLog (mlir::daphne::FillOp fillOp, mlir::daphne::EwLogOp op, mlir::Value scalar,
285290 mlir::PatternRewriter &rewriter) {
286291 auto fillValue = fillOp.getArg ();
287292 auto height = fillOp.getNumRows ();
288293 auto width = fillOp.getNumCols ();
289- // AMLS_TODO: this can lead to error:
294+ // AMLS_TODO: this can lead to error if the log resolves cleanly
295+ // e.g. 8 with base 2
290296 // no kernel for operation `fill` available for the required input types `(si64, index, index)` and output types
291297 // `(!daphne.Matrix<?x?xf64>)
298+ // Problem with Log function?
292299
293300 mlir::daphne::EwLogOp newLog = rewriter.create <mlir::daphne::EwLogOp>(op.getLoc (), fillValue, scalar);
294301 return rewriter.create <mlir::daphne::FillOp>(op.getLoc (), op.getResult ().getType (), newLog, width, height);
@@ -642,18 +649,28 @@ mlir::LogicalResult mlir::daphne::EwDivOp::canonicalize(mlir::daphne::EwDivOp op
642649 */
643650mlir::LogicalResult mlir::daphne::EwLogOp::canonicalize (mlir::daphne::EwLogOp op, PatternRewriter &rewriter) {
644651 // AMLS_TODO: reactivate
645- // mlir::Value lhs = op.getLhs();
646- // mlir::Value rhs = op.getRhs();
647- // // This will check for the fill operation to push down the arithmetic inside
648- // // of it
649- // // Since the rhs is the base, the FillOp can only appear legally in lhs
650- // mlir::daphne::FillOp lhsFill = lhs.getDefiningOp<mlir::daphne::FillOp>();
651- // const bool rhsIsSca = CompilerUtils::isScaType(rhs.getType());
652- // if (lhsFill && rhsIsSca) {
653- // auto newFill = pushDownFillIntoEwLog(lhsFill, op, rhs, rewriter);
654- // rewriter.replaceOp(op, {newFill});
655- // return mlir::success();
656- // }
652+ mlir::Value lhs = op.getLhs ();
653+ mlir::Value rhs = op.getRhs ();
654+ // This will check for the fill operation to push down the arithmetic inside
655+ // of it
656+ // Since the rhs is the base, the FillOp can only appear legally in lhs
657+ mlir::daphne::FillOp lhsFill = lhs.getDefiningOp <mlir::daphne::FillOp>();
658+ const bool rhsIsSca = CompilerUtils::isScaType (rhs.getType ());
659+ if (lhsFill && rhsIsSca) {
660+ auto newFill = pushDownFillIntoEwLog (lhsFill, op, rhs, rewriter);
661+ rewriter.replaceOp (op, {newFill});
662+ return mlir::success ();
663+ }
664+
665+ // This will check for the rand operation to push down the arithmetic inside
666+ // of it
667+ // Since the rhs is the base, the RandOp can only appear legally in lhs
668+ mlir::daphne::RandMatrixOp lhsRand = lhs.getDefiningOp <mlir::daphne::RandMatrixOp>();
669+ if (lhsRand && rhsIsSca) {
670+ auto newRand = pushDownRandomIntoEwLog (lhsRand, op, rhs, rewriter);
671+ rewriter.replaceOp (op, {newRand});
672+ return mlir::success ();
673+ }
657674 return mlir::failure ();
658675}
659676/* *
0 commit comments