Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
/// insert_slice ops.
std::unique_ptr<Pass> createEmptyTensorEliminationPass();

/// Create a pass that enforces read-only buffers for function arguments:
/// every memref argument that is written to inside the function is replaced
/// by a freshly allocated copy, so the argument itself is never modified.
std::unique_ptr<Pass> createEnforceImmutableFuncArgsPass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -595,4 +595,18 @@ def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> {
let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()";
}

def EnforceImmutableFuncArgs : Pass<"enforce-immutable-func-args", "func::FuncOp"> {
  let summary = "Enforce immutability of function arguments by inserting alloc and copy ops";
  let description = [{
    This pass allocates a new buffer for each memref argument of the function
    that is written to (directly or through view-like ops) and copies the
    argument into the allocated buffer. All other uses of the argument are
    redirected to the new buffer.
    This avoids in-place memory updates of the function's arguments and makes
    them immutable/read-only buffers.

    The pass fails if the function has an argument that is not of `MemRefType`.
  }];
  let constructor = "mlir::bufferization::createEnforceImmutableFuncArgsPass()";
  let dependentDialects = ["memref::MemRefDialect"];
}


#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
OwnershipBasedBufferDeallocation.cpp
TensorCopyInsertion.cpp
OptimizeAllocationLiveness.cpp
EnforceImmutableFuncArgs.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//===- EnforceImmutableFuncArgs.cpp - Enforce immutable func args pass ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that enforces read-only function arguments.
// For every memref function argument that is written to, the pass allocates a
// new buffer, copies the argument into it and redirects the uses of the
// argument to the new buffer.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"

// Debug category of this file; it is also used as the prefix of every debug
// message emitted through the helpers below.
#define DEBUG_TYPE "enforce-immutable-func-args"
// Debug output stream with the "[<DEBUG_TYPE>] " prefix already written.
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
// Emits `X` followed by a newline on the prefixed debug stream, wrapped in
// LLVM_DEBUG.
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

// Pull in the generated definitions for this pass from Passes.h.inc; the pass
// struct below derives from `bufferization::impl::EnforceImmutableFuncArgsBase`.
namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_ENFORCEIMMUTABLEFUNCARGS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

/// Returns true if any user of `buffer` has a write memory effect on it.
/// Users that are view-like ops are followed transitively: a write to the
/// result of such an op counts as a write to `buffer`.
/// Precondition (asserted in the definition): `buffer` has a memref type.
static bool isWrittenTo(Value buffer);

namespace {
/// This pass allocates a new buffer for each input argument of the function
/// which is being written to, also copying the argument into the allocated
/// buffer.
/// This avoids in-place memory updates of the function's arguments and makes
/// them immutable/read-only buffers.
struct EnforceImmutableFuncArgsPass
: public bufferization::impl::EnforceImmutableFuncArgsBase<
EnforceImmutableFuncArgsPass> {
/// Entry point of the pass; runs on a single `func::FuncOp`.
void runOnOperation() final;
};
} // end anonymous namespace.

/// Returns true if any user of `buffer` has a write memory effect on it,
/// either directly or through a chain of view-like ops whose results alias
/// `buffer`.
///
/// `buffer` must have a memref type. Unranked memrefs are accepted because
/// the recursion below follows view-like ops, and such ops may produce an
/// unranked memref (e.g. a `memref.cast` to an unranked type).
///
/// NOTE(review): only effects reported by `hasEffect` are considered. Users
/// with unknown memory effects (e.g. calls) are presumably not reported as
/// writers -- confirm whether they should be handled conservatively.
static bool isWrittenTo(Value buffer) {
  assert(isa<BaseMemRefType>(buffer.getType()) &&
         "expected a value of memref type");

  for (Operation *user : buffer.getUsers()) {
    // Direct write to the buffer.
    if (hasEffect<MemoryEffects::Write>(user, buffer))
      return true;
    // A view-like op aliases the buffer, so writes to its result are writes
    // to the buffer as well.
    if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(user)) {
      assert(viewLikeOp->getNumResults() == 1 &&
             "expected view-like op to have exactly one result");
      if (isWrittenTo(viewLikeOp->getResult(0)))
        return true;
    }
  }
  return false;
}

void EnforceImmutableFuncArgsPass::runOnOperation() {

func::FuncOp funcOp = getOperation();

LDBG("enforcing immutable function arguments in func " << funcOp.getName());

IRRewriter rewriter(funcOp->getContext());
rewriter.setInsertionPointToStart(&funcOp.getBody().front());
for (auto argument : funcOp.getArguments()) {

auto argType = dyn_cast<MemRefType>(argument.getType());
if (!argType) {
emitError(argument.getLoc(),
"function has argument with non memref type");
return signalPassFailure();
}

if (!isWrittenTo(argument))
continue;

LDBG("Found a function argument is being written to " << argument);
Value allocatedMemref =
rewriter.create<memref::AllocOp>(funcOp.getLoc(), argType);
rewriter.replaceAllUsesWith(argument, allocatedMemref);
rewriter.create<memref::CopyOp>(funcOp.getLoc(), argument, allocatedMemref);
}
}

//===----------------------------------------------------------------------===//
// EnforceImmutableFuncArgs construction
//===----------------------------------------------------------------------===//

/// Factory for the pass: returns a new `EnforceImmutableFuncArgsPass`
/// instance owned by the caller.
std::unique_ptr<Pass>
mlir::bufferization::createEnforceImmutableFuncArgsPass() {
return std::make_unique<EnforceImmutableFuncArgsPass>();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
// RUN: mlir-opt --split-input-file --enforce-immutable-func-args %s -o - | FileCheck %s


// CHECK-LABEL: func.func @func_no_input() {
// CHECK: return
// CHECK: }

// A function without arguments: the expected output is unchanged.
func.func @func_no_input() {
return
}

// -----

// CHECK-LABEL: func.func private @func_with_returned_argument(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: return %[[VAL_0]] : memref<1x13x21x3xf32>
// CHECK: }

// The argument is only returned and never written to, so no allocation or
// copy is expected.
func.func private @func_with_returned_argument(%arg0: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>) {
return %arg0 : memref<1x13x21x3xf32>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_directly(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
// CHECK: linalg.yield %[[VAL_7]] : f32
// CHECK: }
// CHECK: return %[[VAL_3]] : memref<1x13x21x3xf32>
// CHECK: }

// %arg0 is both an input and the output of the linalg.generic, i.e. it is
// written to directly. It is expected to be replaced by an allocated copy;
// %arg1 is only read and stays as is.
func.func private @func_with_modified_argument_directly(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}
ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
outs(%arg0 : memref<1x13x21x3xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
return %alloc : memref<1x13x21x3xf32>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_directly_and_returned(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
// CHECK: linalg.yield %[[VAL_6]] : f32
// CHECK: }
// CHECK: return %[[VAL_2]] : memref<1x13x21x3xf32>
// CHECK: }

// %arg0 is written to by the linalg.generic and then returned. Both the
// write and the return are expected to use the allocated copy instead.
func.func private @func_with_modified_argument_directly_and_returned(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}
ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
outs(%arg0 : memref<1x13x21x3xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
return %arg0 : memref<1x13x21x3xf32>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_directly_twice(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
// CHECK: linalg.yield %[[VAL_7]] : f32
// CHECK: }
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32):
// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32
// CHECK: linalg.yield %[[VAL_11]] : f32
// CHECK: }
// CHECK: return %[[VAL_3]] : memref<1x13x21x3xf32>
// CHECK: }

// %arg0 is written to by two consecutive linalg.generic ops. A single
// allocation and copy is expected, shared by both ops.
func.func private @func_with_modified_argument_directly_twice(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}
ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
outs(%arg0 : memref<1x13x21x3xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}
ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
outs(%arg0 : memref<1x13x21x3xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
return %alloc : memref<1x13x21x3xf32>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_directly(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xi32, 1>, %[[VAL_1:.*]]: memref<5xi32, 1>, %[[VAL_2:.*]]: memref<5xi32, 1>) -> memref<5xi32, 1> {
// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<5xi32, 1>
// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<5xi32, 1> to memref<5xi32, 1>
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : i32
// CHECK: memref.store %[[VAL_12]], %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
// CHECK: }
// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<5xi32, 1>
// CHECK: memref.copy %[[VAL_3]], %[[VAL_13]] : memref<5xi32, 1> to memref<5xi32, 1>
// CHECK: return %[[VAL_13]] : memref<5xi32, 1>
// CHECK: }

// %arg2 is written to by a memref.store nested inside an scf.for loop, while
// %arg0 and %arg1 are only loaded from. Only %arg2 is expected to be replaced
// by an allocated copy.
func.func private @func_with_modified_argument_directly(%arg0: memref<5xi32, 1>, %arg1: memref<5xi32, 1>, %arg2: memref<5xi32, 1>) -> (memref<5xi32, 1>){
%c1 = arith.constant 1 : index
%c5 = arith.constant 5 : index
%c0 = arith.constant 0 : index
scf.for %arg3 = %c0 to %c5 step %c1 {
%0 = memref.load %arg0[%arg3] : memref<5xi32, 1>
%1 = arith.index_cast %0 : i32 to index
%2 = memref.load %arg1[%arg3] : memref<5xi32, 1>
%3 = memref.load %arg2[%1] : memref<5xi32, 1>
%4 = arith.addi %2, %3 : i32
memref.store %4, %arg2[%1] : memref<5xi32, 1>
}
%alloc = memref.alloc() : memref<5xi32, 1>
memref.copy %arg2, %alloc : memref<5xi32, 1> to memref<5xi32, 1>
return %alloc : memref<5xi32, 1>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_indirectly(
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x3x4xf32, 1>) -> memref<3x3x4xf32, 1> {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<3x3x4xf32, 1>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_1]] : memref<3x3x4xf32, 1> to memref<3x3x4xf32, 1>
// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
// CHECK: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%[[VAL_3]] : memref<3x3x4xf32, 1>) {
// CHECK: ^bb0(%[[VAL_4:.*]]: f32):
// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_4]] : f32
// CHECK: linalg.yield %[[VAL_5]] : f32
// CHECK: }
// CHECK: return %[[VAL_3]] : memref<3x3x4xf32, 1>
// CHECK: }

// %arg0 is not written to directly: the linalg.generic writes to the result
// of a collapse_shape/expand_shape chain rooted at %arg0. The argument is
// still expected to be replaced by an allocated copy.
func.func private @func_with_modified_argument_indirectly(%arg0: memref<3x3x4xf32, 1>) -> (memref<3x3x4xf32, 1>) {
%collapse_arg = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
%expand_arg = memref.expand_shape %collapse_arg [[0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]
}
outs(%expand_arg : memref<3x3x4xf32, 1>) {
^bb0(%out: f32):
%0 = arith.addf %out, %out : f32
linalg.yield %0 : f32
}
return %expand_arg: memref<3x3x4xf32, 1>
}

// -----

// CHECK-LABEL: func.func private @func_with_modified_argument_subview(
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x4x4xi32, 1>) -> memref<4x4xi32, 1> {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<2x4x4xi32, 1>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_1]] : memref<2x4x4xi32, 1> to memref<2x4x4xi32, 1>
// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_1]][0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
// CHECK: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_4]] : memref<4x4xi32, 1>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: i32):
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : i32
// CHECK: linalg.yield %[[VAL_6]] : i32
// CHECK: }
// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<4x4xi32, 1>
// CHECK: memref.copy %[[VAL_4]], %[[VAL_7]] : memref<4x4xi32, 1> to memref<4x4xi32, 1>
// CHECK: return %[[VAL_7]] : memref<4x4xi32, 1>
// CHECK: }

// The linalg.generic writes to a value derived from %arg0 through a
// subview -> collapse_shape -> cast chain. The argument is expected to be
// replaced by an allocated copy at the root of that chain.
func.func private @func_with_modified_argument_subview(%arg0: memref<2x4x4xi32, 1>) -> ( memref<4x4xi32, 1>){
%subview = memref.subview %arg0[0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
%collapse_shape = memref.collapse_shape %subview [[0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
%cast = memref.cast %collapse_shape : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
}
outs(%cast : memref<4x4xi32, 1>) {
^bb0(%out: i32):
%0 = arith.addi %out, %out : i32
linalg.yield %0 : i32
}
%alloc = memref.alloc() : memref<4x4xi32, 1>
memref.copy %cast, %alloc : memref<4x4xi32, 1> to memref<4x4xi32, 1>
return %alloc : memref<4x4xi32, 1>
}

Loading