-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][bufferize] Make drop-equivalent-buffer-results support mult blocks #163388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][bufferize] Make drop-equivalent-buffer-results support mult blocks #163388
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-bufferization Author: lonely eagle (linuxlonelyeagle) Changes: Enable drop-equivalent-buffer-results to handle return ops in multiple blocks within a function. Full diff: https://github.com/llvm/llvm-project/pull/163388.diff — 2 files affected:
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 70faa71a5ffbb..9c300cc347ecf 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,18 +41,38 @@ namespace bufferization {
using namespace mlir;
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
+/// Collect every func.return op that terminates a block of `funcOp`.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+ SmallVector<func::ReturnOp> returnOps;
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
+ returnOps.push_back(candidateOp);
}
}
- return returnOp;
+ return returnOps;
+}
+
+/// Collect the operand at index `pos` of each op in `returnOps`, in order.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+ SmallVector<Value> operands;
+ for (func::ReturnOp returnOp : returnOps) {
+ operands.push_back(returnOp.getOperand(pos));
+ }
+ return operands;
+}
+
+/// Return true if all `operands` equal `argument`, looking through memref.cast.
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+ BlockArgument argument) {
+ for (Value val : operands) {
+ while (auto castOp = val.getDefiningOp<memref::CastOp>())
+ val = castOp.getSource();
+
+ if (val != argument)
+ return false;
+ }
+ return true;
}
LogicalResult
@@ -72,40 +92,44 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
for (auto funcOp : module.getOps<func::FuncOp>()) {
if (funcOp.isExternal() || funcOp.isPublic())
continue;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- // TODO: Support functions with multiple blocks.
- if (!returnOp)
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ if (returnOps.empty())
continue;
+ func::ReturnOp returnOp = returnOps.front();
// Compute erased results.
- SmallVector<Value> newReturnValues;
+ SmallVector<SmallVector<Value>> newReturnValues(returnOps.size());
BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
DenseMap<int64_t, int64_t> resultToArgs;
- for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+ for (size_t i = 0, e = returnOp.getOperands().size(); i < e; ++i) {
bool erased = false;
+ SmallVector<Value> returnOperands =
+ getReturnOpsOperandInPos(returnOps, i);
for (BlockArgument bbArg : funcOp.getArguments()) {
- Value val = it.value();
- while (auto castOp = val.getDefiningOp<memref::CastOp>())
- val = castOp.getSource();
-
- if (val == bbArg) {
- resultToArgs[it.index()] = bbArg.getArgNumber();
+ if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+ resultToArgs[i] = bbArg.getArgNumber();
erased = true;
break;
}
}
if (erased) {
- erasedResultIndices.set(it.index());
+ erasedResultIndices.set(i);
} else {
- newReturnValues.push_back(it.value());
+ for (auto [newReturnValue, operand] :
+ llvm::zip(newReturnValues, returnOperands)) {
+ newReturnValue.push_back(operand);
+ }
}
}
// Update function.
if (failed(funcOp.eraseResults(erasedResultIndices)))
return failure();
- returnOp.getOperandsMutable().assign(newReturnValues);
+
+ for (auto [retOp, newOperands] :
+ llvm::zip(returnOps, newReturnValues))
+ retOp.getOperandsMutable().assign(newOperands);
// Update function calls.
for (func::CallOp callOp : callerMap[funcOp]) {
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index b6c72bedef6c5..641dcf0d59990 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -490,3 +490,33 @@ func.func @collapse_shape_regression(
tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
return
}
+
+// -----
+
+// Drop an equivalent buffer result from a callee with multiple return ops.
+// CHECK-LABEL: func private @mult_return_callee(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index {
+func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+ // CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
+ // CHECK: ^bb1:
+ // CHECK: return %[[A]] : index
+ // CHECK: ^bb2:
+ // CHECK: return %[[B]] : index
+ cf.cond_br %cond,^a, ^b
+ ^a:
+ return %casted, %a : tensor<10xf32>, index
+ ^b:
+ return %casted, %b : tensor<10xf32>, index
+}
+
+// CHECK-LABEL: func @mult_return(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
+func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
+ // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
+ %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index)
+ return %t_res, %v : tensor<10xf32>, index
+}
|
@llvm/pr-subscribers-mlir Author: lonely eagle (linuxlonelyeagle) Changes: Enable drop-equivalent-buffer-results to handle return ops in multiple blocks within a function. Full diff: https://github.com/llvm/llvm-project/pull/163388.diff — 2 files affected:
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 70faa71a5ffbb..9c300cc347ecf 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,18 +41,38 @@ namespace bufferization {
using namespace mlir;
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
+/// Collect every func.return op that terminates a block of `funcOp`.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+ SmallVector<func::ReturnOp> returnOps;
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
+ returnOps.push_back(candidateOp);
}
}
- return returnOp;
+ return returnOps;
+}
+
+/// Collect the operand at index `pos` of each op in `returnOps`, in order.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+ SmallVector<Value> operands;
+ for (func::ReturnOp returnOp : returnOps) {
+ operands.push_back(returnOp.getOperand(pos));
+ }
+ return operands;
+}
+
+/// Return true if all `operands` equal `argument`, looking through memref.cast.
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+ BlockArgument argument) {
+ for (Value val : operands) {
+ while (auto castOp = val.getDefiningOp<memref::CastOp>())
+ val = castOp.getSource();
+
+ if (val != argument)
+ return false;
+ }
+ return true;
}
LogicalResult
@@ -72,40 +92,44 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
for (auto funcOp : module.getOps<func::FuncOp>()) {
if (funcOp.isExternal() || funcOp.isPublic())
continue;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- // TODO: Support functions with multiple blocks.
- if (!returnOp)
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ if (returnOps.empty())
continue;
+ func::ReturnOp returnOp = returnOps.front();
// Compute erased results.
- SmallVector<Value> newReturnValues;
+ SmallVector<SmallVector<Value>> newReturnValues(returnOps.size());
BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
DenseMap<int64_t, int64_t> resultToArgs;
- for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+ for (size_t i = 0, e = returnOp.getOperands().size(); i < e; ++i) {
bool erased = false;
+ SmallVector<Value> returnOperands =
+ getReturnOpsOperandInPos(returnOps, i);
for (BlockArgument bbArg : funcOp.getArguments()) {
- Value val = it.value();
- while (auto castOp = val.getDefiningOp<memref::CastOp>())
- val = castOp.getSource();
-
- if (val == bbArg) {
- resultToArgs[it.index()] = bbArg.getArgNumber();
+ if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+ resultToArgs[i] = bbArg.getArgNumber();
erased = true;
break;
}
}
if (erased) {
- erasedResultIndices.set(it.index());
+ erasedResultIndices.set(i);
} else {
- newReturnValues.push_back(it.value());
+ for (auto [newReturnValue, operand] :
+ llvm::zip(newReturnValues, returnOperands)) {
+ newReturnValue.push_back(operand);
+ }
}
}
// Update function.
if (failed(funcOp.eraseResults(erasedResultIndices)))
return failure();
- returnOp.getOperandsMutable().assign(newReturnValues);
+
+ for (auto [retOp, newOperands] :
+ llvm::zip(returnOps, newReturnValues))
+ retOp.getOperandsMutable().assign(newOperands);
// Update function calls.
for (func::CallOp callOp : callerMap[funcOp]) {
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index b6c72bedef6c5..641dcf0d59990 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -490,3 +490,33 @@ func.func @collapse_shape_regression(
tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
return
}
+
+// -----
+
+// Drop an equivalent buffer result from a callee with multiple return ops.
+// CHECK-LABEL: func private @mult_return_callee(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index {
+func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+ // CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
+ // CHECK: ^bb1:
+ // CHECK: return %[[A]] : index
+ // CHECK: ^bb2:
+ // CHECK: return %[[B]] : index
+ cf.cond_br %cond,^a, ^b
+ ^a:
+ return %casted, %a : tensor<10xf32>, index
+ ^b:
+ return %casted, %b : tensor<10xf32>, index
+}
+
+// CHECK-LABEL: func @mult_return(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
+func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
+ // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
+ %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index)
+ return %t_res, %v : tensor<10xf32>, index
+}
|
191d6f0
to
5c4c6de
Compare
Enable drop-equivalent-buffer-results to handle return ops in multiple blocks within a function.