1- // ===- StdExpandDivs .cpp - Code to prepare Std for lowering Divs to LLVM -===//
1+ // ===- ExpandDivs .cpp - Expansion patterns for MemRef operations --------- -===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===----------------------------------------------------------------------===//
8- //
9- // This file Std transformations to expand Divs operation to help for the
10- // lowering to LLVM. Currently implemented transformations are Ceil and Floor
11- // for Signed Integers.
12- //
13- // ===----------------------------------------------------------------------===//
148
159#include " mlir/Dialect/MemRef/Transforms/Passes.h"
1610
@@ -33,44 +27,6 @@ using namespace mlir;
3327
3428namespace {
3529
36- // / Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
37- // / AtomicRMWOpLowering pattern, such as minimum and maximum operations for
38- // / floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
39- // / code.
40- // /
41- // / %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
42- // /
43- // / will be lowered to
44- // /
45- // / %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
46- // / ^bb0(%current: f32):
47- // / %1 = arith.maximumf %current, %fval : f32
48- // / memref.atomic_yield %1 : f32
49- // / }
50- struct AtomicRMWOpConverter : public OpRewritePattern <memref::AtomicRMWOp> {
51- public:
52- using OpRewritePattern::OpRewritePattern;
53-
54- LogicalResult matchAndRewrite (memref::AtomicRMWOp op,
55- PatternRewriter &rewriter) const final {
56- auto loc = op.getLoc ();
57- auto genericOp = rewriter.create <memref::GenericAtomicRMWOp>(
58- loc, op.getMemref (), op.getIndices ());
59- OpBuilder bodyBuilder =
60- OpBuilder::atBlockEnd (genericOp.getBody (), rewriter.getListener ());
61-
62- Value lhs = genericOp.getCurrentValue ();
63- Value rhs = op.getValue ();
64-
65- Value arithOp =
66- mlir::arith::getReductionOp (op.getKind (), bodyBuilder, loc, lhs, rhs);
67- bodyBuilder.create <memref::AtomicYieldOp>(loc, arithOp);
68-
69- rewriter.replaceOp (op, genericOp.getResult ());
70- return success ();
71- }
72- };
73-
7430// / Converts `memref.reshape` that has a target shape of a statically-known
7531// / size to `memref.reinterpret_cast`.
7632struct MemRefReshapeOpConverter : public OpRewritePattern <memref::ReshapeOp> {
@@ -139,13 +95,6 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
13995 ConversionTarget target (ctx);
14096
14197 target.addLegalDialect <arith::ArithDialect, memref::MemRefDialect>();
142- target.addDynamicallyLegalOp <memref::AtomicRMWOp>(
143- [](memref::AtomicRMWOp op) {
144- constexpr std::array shouldBeExpandedKinds = {
145- arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
146- arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
147- return !llvm::is_contained (shouldBeExpandedKinds, op.getKind ());
148- });
14998 target.addDynamicallyLegalOp <memref::ReshapeOp>([](memref::ReshapeOp op) {
15099 return !cast<MemRefType>(op.getShape ().getType ()).hasStaticShape ();
151100 });
@@ -158,6 +107,5 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
158107} // namespace
159108
160109void mlir::memref::populateExpandOpsPatterns (RewritePatternSet &patterns) {
161- patterns.add <AtomicRMWOpConverter, MemRefReshapeOpConverter>(
162- patterns.getContext ());
110+ patterns.add <MemRefReshapeOpConverter>(patterns.getContext ());
163111}
0 commit comments