Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 2 additions & 54 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
//===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM -===//
//===- ExpandDivs.cpp - Expansion patterns for MemRef operations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file Std transformations to expand Divs operation to help for the
// lowering to LLVM. Currently implemented transformations are Ceil and Floor
// for Signed Integers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/Passes.h"

Expand All @@ -33,44 +27,6 @@ using namespace mlir;

namespace {

/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
/// code.
///
/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
///
/// will be lowered to
///
/// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
/// ^bb0(%current: f32):
/// %1 = arith.maximumf %current, %fval : f32
/// memref.atomic_yield %1 : f32
/// }
struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
PatternRewriter &rewriter) const final {
auto loc = op.getLoc();
auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
loc, op.getMemref(), op.getIndices());
OpBuilder bodyBuilder =
OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());

Value lhs = genericOp.getCurrentValue();
Value rhs = op.getValue();

Value arithOp =
mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);

rewriter.replaceOp(op, genericOp.getResult());
return success();
}
};

/// Converts `memref.reshape` that has a target shape of a statically-known
/// size to `memref.reinterpret_cast`.
struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
Expand Down Expand Up @@ -139,13 +95,6 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
ConversionTarget target(ctx);

target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
[](memref::AtomicRMWOp op) {
constexpr std::array shouldBeExpandedKinds = {
arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
});
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
});
Expand All @@ -158,6 +107,5 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
} // namespace

void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
patterns.getContext());
patterns.add<MemRefReshapeOpConverter>(patterns.getContext());
}
38 changes: 3 additions & 35 deletions mlir/test/Dialect/MemRef/expand-ops.mlir
Original file line number Diff line number Diff line change
@@ -1,42 +1,10 @@
// RUN: mlir-opt -memref-expand %s -split-input-file | FileCheck %s

// CHECK-LABEL: func @atomic_rmw_to_generic
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
%a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
%b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
%c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
%d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %a : f32
}
// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK: [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32
// CHECK: memref.atomic_yield [[MAXIMUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
// CHECK: memref.atomic_yield [[MINIMUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK: [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
// CHECK: memref.atomic_yield [[MAXNUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK: [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
// CHECK: memref.atomic_yield [[MINNUM]] : f32
// CHECK: }
// CHECK: return [[RESULT]] : f32

// -----

// CHECK-LABEL: func @atomic_rmw_no_conversion
func.func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
func.func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> (f32, f32) {
%x = memref.atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
%y = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x, %y : f32, f32
}
// CHECK-NOT: generic_atomic_rmw

Expand Down
Loading