diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp index 020aabd9db6df..a617029ce470f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -1,16 +1,10 @@ -//===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM -===// +//===- ExpandDivs.cpp - Expansion patterns for MemRef operations ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// This file Std transformations to expand Divs operation to help for the -// lowering to LLVM. Currently implemented transformations are Ceil and Floor -// for Signed Integers. -// -//===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/Transforms/Passes.h" @@ -33,44 +27,6 @@ using namespace mlir; namespace { -/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with -/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for -/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded -/// code. -/// -/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32 -/// -/// will be lowered to -/// -/// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> { -/// ^bb0(%current: f32): -/// %1 = arith.maximumf %current, %fval : f32 -/// memref.atomic_yield %1 : f32 -/// } -struct AtomicRMWOpConverter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(memref::AtomicRMWOp op, - PatternRewriter &rewriter) const final { - auto loc = op.getLoc(); - auto genericOp = rewriter.create( - loc, op.getMemref(), op.getIndices()); - OpBuilder bodyBuilder = - OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener()); - - Value lhs = genericOp.getCurrentValue(); - Value rhs = op.getValue(); - - Value arithOp = - mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs); - bodyBuilder.create(loc, arithOp); - - rewriter.replaceOp(op, genericOp.getResult()); - return success(); - } -}; - /// Converts `memref.reshape` that has a target shape of a statically-known /// size to `memref.reinterpret_cast`. struct MemRefReshapeOpConverter : public OpRewritePattern { @@ -139,13 +95,6 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase { ConversionTarget target(ctx); target.addLegalDialect(); - target.addDynamicallyLegalOp( - [](memref::AtomicRMWOp op) { - constexpr std::array shouldBeExpandedKinds = { - arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf, - arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf}; - return !llvm::is_contained(shouldBeExpandedKinds, op.getKind()); - }); target.addDynamicallyLegalOp([](memref::ReshapeOp op) { return !cast(op.getShape().getType()).hasStaticShape(); }); @@ -158,6 +107,5 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase { } // namespace void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns.add(patterns.getContext()); } diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir index 65932b5814a66..fc8db546d918d 100644 --- a/mlir/test/Dialect/MemRef/expand-ops.mlir +++ b/mlir/test/Dialect/MemRef/expand-ops.mlir @@ -1,42 +1,10 @@ // RUN: mlir-opt -memref-expand %s -split-input-file | FileCheck %s -// CHECK-LABEL: func @atomic_rmw_to_generic -// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) -func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { - %a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 - %b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 - %c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 - %d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 - return %a : f32 -} -// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { -// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): -// CHECK: [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32 -// CHECK: memref.atomic_yield [[MAXIMUM]] : f32 -// CHECK: } -// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { -// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): -// CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32 -// CHECK: memref.atomic_yield [[MINIMUM]] : f32 -// CHECK: } -// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { -// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): -// CHECK: [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32 -// CHECK: memref.atomic_yield [[MAXNUM]] : f32 -// CHECK: } -// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { -// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): -// CHECK: [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32 -// CHECK: memref.atomic_yield [[MINNUM]] : f32 -// CHECK: } -// CHECK: return [[RESULT]] : f32 - -// ----- - // CHECK-LABEL: func @atomic_rmw_no_conversion -func.func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { +func.func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> (f32, f32) { %x = memref.atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32 - return %x : f32 + %y = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + return %x, %y : f32, f32 } // CHECK-NOT: generic_atomic_rmw