@@ -255,22 +255,59 @@ template <class Operation> bool pushDownUnary(Operation op, mlir::PatternRewrite
255255 }
256256 // This will check for the rand operation to push down the arithmetic inside
257257 // of it
258- if (rand && supportsPushDownLinear ) {
258+ if (rand && supportsPushDown ) {
259259 auto max = rand.getMax ();
260260 auto min = rand.getMin ();
261261 auto height = rand.getNumRows ();
262262 auto width = rand.getNumCols ();
263263 auto sparsity = rand.getSparsity ();
264264 auto seed = rand.getSeed ();
265- auto newMax =
266- rewriter.create <Operation>(op.getLoc (), CompilerUtils::getValueType (op.getResult ().getType ()), max);
267- auto newMin =
268- rewriter.create <Operation>(op.getLoc (), CompilerUtils::getValueType (op.getResult ().getType ()), min);
269265
270- auto newCombinedOpAfterPushDown = rewriter.create <mlir::daphne::RandMatrixOp>(
271- op.getLoc (), op.getResult ().getType (), width, height, newMin, newMax, sparsity, seed);
272- rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
273- return true ;
266+ if (supportsPushDownLinear) {
267+ auto newMax =
268+ rewriter.create <Operation>(op.getLoc (), CompilerUtils::getValueType (op.getResult ().getType ()), max);
269+ auto newMin =
270+ rewriter.create <Operation>(op.getLoc (), CompilerUtils::getValueType (op.getResult ().getType ()), min);
271+
272+ auto newCombinedOpAfterPushDown = rewriter.create <mlir::daphne::RandMatrixOp>(
273+ op.getLoc (), op.getResult ().getType (), width, height, newMin, newMax, sparsity, seed);
274+ rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
275+ return true ;
276+ }
277+ // Handle the special case of RandMatrixOp and EwAbsOp which only works
278+ // if the max and min are both positive (Abs is no-op) or both are
279+ // negative (swap them)
280+ if constexpr (std::is_same<Operation, mlir::daphne::EwAbsOp>()) {
281+ auto maxValueInt = CompilerUtils::isConstant<int >(max);
282+ auto minValueInt = CompilerUtils::isConstant<int >(min);
283+ auto maxValueDouble = CompilerUtils::isConstant<double >(max);
284+ auto minValueDouble = CompilerUtils::isConstant<double >(min);
285+
286+ // will be int or double. Whichever it isn't will default to 0
287+ // so they can simply be added together here
288+
289+ auto maxValue = maxValueDouble.second + maxValueInt.second ;
290+ auto minValue = minValueDouble.second + minValueInt.second ;
291+ if (minValue >= 0 && maxValue > minValue) {
292+ // simply remove Abs function
293+
294+ auto newCombinedOpAfterPushDown = rewriter.create <mlir::daphne::RandMatrixOp>(
295+ op.getLoc (), op.getResult ().getType (), width, height, min, max, sparsity, seed);
296+ rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
297+ return true ;
298+ }
299+ if (minValue <= 0 && maxValue > minValue) {
300+ // swap max and min
301+ auto newMax =
302+ rewriter.create <Operation>(op.getLoc (), CompilerUtils::getValueType (op.getResult ().getType ()), min);
303+ auto newMin =
304+ rewriter.create <Operation>(op.getLoc (), CompilerUtils::getValueType (op.getResult ().getType ()), max);
305+ auto newCombinedOpAfterPushDown = rewriter.create <mlir::daphne::RandMatrixOp>(
306+ op.getLoc (), op.getResult ().getType (), width, height, newMin, newMax, sparsity, seed);
307+ rewriter.replaceOp (op, {newCombinedOpAfterPushDown});
308+ return true ;
309+ }
310+ }
274311 }
275312 return false ;
276313}
@@ -300,7 +337,7 @@ template <class Operation> bool pushDownBinary(Operation op, mlir::PatternRewrit
300337 }
301338 // This will check for the rand operation to push down the arithmetic inside
302339 // of it
303- if (lhsRand && rhsIsSca && supportsPushDown) {
340+ if (lhsRand && rhsIsSca && supportsPushDown && supportsPushDownLinear ) {
304341 auto max = lhsRand.getMax ();
305342 auto min = lhsRand.getMin ();
306343 auto height = lhsRand.getNumRows ();
@@ -317,7 +354,7 @@ template <class Operation> bool pushDownBinary(Operation op, mlir::PatternRewrit
317354
318355 // This will check for the seq operation to push down the arithmetic inside
319356 // of it
320- if (lhsSeq && rhsIsSca && supportsPushDownLinear) {
357+ if (lhsSeq && rhsIsSca && supportsPushDown && supportsPushDownLinear) {
321358 auto from = lhsSeq.getFrom ();
322359 auto to = lhsSeq.getTo ();
323360 auto inc = lhsSeq.getInc ();
@@ -357,7 +394,6 @@ template <class Operation> bool tryPushDown(Operation op, mlir::PatternRewriter
357394 std::is_same<Operation, mlir::daphne::EwLogOp>() || std::is_same<Operation, mlir::daphne::EwModOp>()
358395
359396 ) {
360- spdlog::warn (" binary" );
361397 return pushDownBinary (op, rewriter);
362398 }
363399 return false ;
@@ -381,7 +417,6 @@ template <class Operation> bool tryPushDown(Operation op, mlir::PatternRewriter
381417mlir::LogicalResult mlir::daphne::EwAddOp::canonicalize (mlir::daphne::EwAddOp op, PatternRewriter &rewriter) {
382418 mlir::Value lhs = op.getLhs ();
383419 mlir::Value rhs = op.getRhs ();
384- const bool rhsIsSca = CompilerUtils::isScaType (rhs.getType ());
385420 if (tryPushDown<mlir::daphne::EwAddOp>(op, rewriter)) {
386421 return mlir::success ();
387422 }
0 commit comments