Skip to content

Commit 823c9bc

Browse files
qedawkins authored and keshavvinayak01 committed
[Codegen] Remove to_buffer from bufferization deny list (iree-org#21505)
Bufferization already uses to_buffer as the replacement for unrealized conversion casts, so only having constants on the deny list is sufficient. This allows for inserting to_buffer ops in earlier passes. This also fixes an issue where bufferization.to_buffer ops on unbufferized constants don't get marked as read-only. This makes it possible to pass the output of bufferization back into iree-comprehensive-bufferize and get the same result. This is required for copy only dispatches to work as they run bufferization thrice. Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 19824e1 commit 823c9bc

File tree

3 files changed

+64
-3
lines changed

3 files changed

+64
-3
lines changed

compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ static IREEOneShotBufferizationOptions getBufferizationOptions() {
149149
// its own logic to handle constants. We'd like to leave the arith.constant
150150
// as is and insert bufferization.to_buffer to convert the tensor to memref.
151151
options.opFilter.denyOperation<arith::ConstantOp>();
152-
options.opFilter.denyOperation<bufferization::ToBufferOp>();
153152

154153
// This type converter converts tensor types to memref types when no exact
155154
// memref type can be inferred from the context.
@@ -250,6 +249,18 @@ void IREEComprehensiveBufferizePass::runOnOperation() {
250249
return signalPassFailure();
251250
}
252251

252+
// All to_buffer ops on single use constants will have already had any
253+
// write conflicts resolved by the analysis, so we can safely mark them as
254+
// read only.
255+
funcOp->walk([](bufferization::ToBufferOp toBuffer) {
256+
if (auto constant =
257+
toBuffer.getTensor().getDefiningOp<arith::ConstantOp>()) {
258+
if (constant->hasOneUse()) {
259+
toBuffer.setReadOnly(true);
260+
}
261+
}
262+
});
263+
253264
// Remove redundant args and unused results.
254265
{
255266
RewritePatternSet patterns(&getContext());

compiler/src/iree/compiler/Codegen/Common/test/bufferize_copy_only_dispatches.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func.func @concatenate_cst() {
9292

9393
// CHECK-LABEL: func.func @concatenate_cst()
9494
// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : tensor<2x3xi32>
95-
// CHECK-DAG: %[[ZERO:.+]] = bufferization.to_buffer %[[CST]] : tensor<2x3xi32> to memref<2x3xi32
95+
// CHECK-DAG: %[[ZERO:.+]] = bufferization.to_buffer %[[CST]] read_only : tensor<2x3xi32> to memref<2x3xi32
9696
// CHECK-DAG: %[[DEST_BINDING:.+]] = hal.interface.binding.subspan
9797
// CHECK-DAG: %[[DEST:.+]] = memref.assume_alignment %[[DEST_BINDING]], 64
9898
// CHECK-DAG: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 2] [2, 3]

compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1403,7 +1403,7 @@ func.func @bufferize_cst_output_tensor() {
14031403

14041404
// CHECK-DAG: %[[CST1:.+]] = arith.constant -2147483648 : i32
14051405
// CHECK-DAG: %[[CST5:.+]] = arith.constant dense<[1, 2, 3, 4, 5]> : tensor<5xi32>
1406-
// CHECK: %[[CAST5:.+]] = bufferization.to_buffer %[[CST5]] : tensor<5xi32> to memref<5xi32>
1406+
// CHECK: %[[CAST5:.+]] = bufferization.to_buffer %[[CST5]] read_only : tensor<5xi32> to memref<5xi32>
14071407
// CHECK: %[[INPUT:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) : memref<5xf32, #hal.descriptor_type<storage_buffer>>
14081408
// CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) : memref<i32, #hal.descriptor_type<storage_buffer>>
14091409
// CHECK: linalg.fill ins(%[[CST1]] : i32) outs(%[[OUTPUT]] : memref<i32{{.+}}>)
@@ -3123,3 +3123,53 @@ func.func @transfer_gather(%source : tensor<?x64xf16>, %indices: vector<8xindex>
31233123
// CHECK: %[[C0:.+]] = arith.constant 0 : index
31243124
// CHECK: %[[BUFFER:.+]] = bufferization.to_buffer %[[SOURCE]]
31253125
// CHECK: iree_vector_ext.transfer_gather %[[BUFFER]][%[[C0]], %[[C0]]][%[[INDICES]]: vector<8xindex>, None]
3126+
3127+
// -----
3128+
3129+
func.func @convert_to_buffer() -> memref<6xf32> {
3130+
%alloc = bufferization.alloc_tensor() : tensor<6xf32>
3131+
%memref = bufferization.to_buffer %alloc : tensor<6xf32> to memref<6xf32>
3132+
return %memref : memref<6xf32>
3133+
}
3134+
3135+
// CHECK-LABEL: func.func @convert_to_buffer
3136+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xf32>
3137+
// CHECK: return %[[ALLOC]]
3138+
3139+
// -----
3140+
3141+
func.func @readonly_constant_bufferize() {
3142+
%c0 = arith.constant 0 : index
3143+
%cst = arith.constant dense<0> : tensor<6xi32>
3144+
%0 = hal.interface.binding.subspan
3145+
layout(<bindings = [#hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>)
3146+
binding(0) alignment(64) offset(%c0) flags(Indirect)
3147+
: !iree_tensor_ext.dispatch.tensor<readwrite:tensor<6xi32>>
3148+
iree_tensor_ext.dispatch.tensor.store %cst, %0, offsets = [0], sizes = [6], strides = [1]
3149+
: tensor<6xi32> -> !iree_tensor_ext.dispatch.tensor<readwrite:tensor<6xi32>>
3150+
return
3151+
}
3152+
3153+
// CHECK-LABEL: func.func @readonly_constant_bufferize
3154+
// CHECK: %[[CST:.+]] = arith.constant dense<0> : tensor<6xi32>
3155+
// CHECK: %[[MEMREF:.+]] = bufferization.to_buffer %[[CST]] read_only
3156+
// CHECK: linalg.generic {{.*}} ins(%[[MEMREF]]
3157+
3158+
// -----
3159+
3160+
func.func @retry_constant_bufferize() {
3161+
%c0 = arith.constant 0 : index
3162+
%cst = arith.constant dense<0> : tensor<6xi32>
3163+
%0 = bufferization.to_buffer %cst read_only : tensor<6xi32> to memref<6xi32>
3164+
%1 = hal.interface.binding.subspan
3165+
layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>)
3166+
binding(0) alignment(64) offset(%c0) flags(Indirect)
3167+
: memref<6xi32, #hal.descriptor_type<storage_buffer>>
3168+
memref.copy %0, %1 : memref<6xi32> to memref<6xi32, #hal.descriptor_type<storage_buffer>>
3169+
return
3170+
}
3171+
3172+
// CHECK-LABEL: func.func @retry_constant_bufferize
3173+
// CHECK: %[[CST:.+]] = arith.constant dense<0> : tensor<6xi32>
3174+
// CHECK: %[[MEMREF:.+]] = bufferization.to_buffer %[[CST]] read_only
3175+
// CHECK: memref.copy %[[MEMREF]]

0 commit comments

Comments
 (0)