
Commit 4bc495b

[Codegen] Always use ? for non-zero offsets (#19952)
In order to rewrite subspans to buffer descriptors, we might need to fold offsets into those descriptors. This means we need to be able to replace an offset with a different one (specifically 0), because the offset will be applied to the base pointer during buffer casts. If the offset is dynamic, we can always memref.cast the dynamic-ness back in, but we cannot replace one static offset with a different static offset. Therefore, never create buffers with a static non-zero offset during bufferization.
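As a minimal standalone sketch of the casting rule this relies on (illustration only, not part of the change): memref.cast may erase a static offset into a dynamic one, but it may never exchange one static offset for a different one.

  func.func @offset_cast_rules(%buf : memref<16xf32, strided<[1], offset: 0>>)
      -> memref<16xf32, strided<[1], offset: ?>> {
    // Legal: a known static offset can always be hidden behind a dynamic one.
    %dyn = memref.cast %buf
        : memref<16xf32, strided<[1], offset: 0>> to memref<16xf32, strided<[1], offset: ?>>
    // Illegal (fails verification): two distinct static offsets are incompatible,
    // so a buffer typed with offset: 32 could never be retyped to offset: 0.
    // %bad = memref.cast %buf
    //     : memref<16xf32, strided<[1], offset: 0>> to memref<16xf32, strided<[1], offset: 32>>
    return %dyn : memref<16xf32, strided<[1], offset: ?>>
  }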
1 parent d81bb13 commit 4bc495b

5 files changed: +28 -25 lines changed

compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp

Lines changed: 3 additions & 0 deletions
@@ -270,6 +270,9 @@ struct FlattenBindingSubspan final
     if (byteOffset && !matchPattern(byteOffset, m_Zero())) {
       elementOffset = convertByteOffsetToElementOffset(
           rewriter, loc, byteOffset, oldType.getElementType());
+      // The element offset needs to look dynamic.
+      elementOffset =
+          getValueOrCreateConstantIndexOp(rewriter, loc, elementOffset);
       AffineExpr s0, s1;
       bindSymbols(rewriter.getContext(), s0, s1);
       linearShape = affine::makeComposedFoldedAffineApply(
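The effect of the added lines can be seen in a small standalone sketch (hypothetical function, not this pass's literal output): once an offset reaches a type-forming op as an SSA value rather than a constant attribute, the result type must carry offset: ?, even when the value is trivially a constant.

  func.func @offset_as_value(%base : memref<512xf32>)
      -> memref<16xf32, strided<[1], offset: ?>> {
    // The offset is an SSA value here, so the result type uses a dynamic
    // offset instead of folding in the constant 32.
    %off = arith.constant 32 : index
    %view = memref.reinterpret_cast %base to offset: [%off], sizes: [16], strides: [1]
        : memref<512xf32> to memref<16xf32, strided<[1], offset: ?>>
    return %view : memref<16xf32, strided<[1], offset: ?>>
  }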

compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir

Lines changed: 5 additions & 5 deletions
@@ -212,8 +212,8 @@ func.func @elementwise() {
 // CHECK: func.func @elementwise()
 // CHECK-DAG: %[[CST_TENSOR:.+]] = arith.constant dense_resource<__elided__> : tensor<1x10xf32>
 // CHECK-DAG: %[[CST_BUF:.+]] = bufferization.to_memref %[[CST_TENSOR]]
-// CHECK-DAG: %[[IN_BUF:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) {{.+}} : memref<1x10xf32, strided<[10, 1], offset: 128>, #hal.descriptor_type<storage_buffer>>
-// CHECK-DAG: %[[OUT_BUF:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) {{.+}} : memref<1x10xf32, strided<[10, 1], offset: 16>, #hal.descriptor_type<storage_buffer>>
+// CHECK-DAG: %[[IN_BUF:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) {{.+}} : memref<1x10xf32, strided<[10, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
+// CHECK-DAG: %[[OUT_BUF:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) {{.+}} : memref<1x10xf32, strided<[10, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
 // CHECK: scf.for
 // CHECK-DAG: %[[SUB_IN1:.+]] = memref.subview %[[IN_BUF]]
 // CHECK-DAG: %[[SUB_OUT1:.+]] = memref.subview %[[OUT_BUF]]

@@ -2589,8 +2589,8 @@ func.func @reduction_ew() {
 }

 // CHECK: func.func @reduction_ew
-// CHECK: hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c5120) : memref<1001xf32, strided<[1], offset: 1280>, #hal.descriptor_type<storage_buffer>>
-// CHECK: hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c5120) : memref<1x1001xf32, strided<[1001, 1], offset: 1280>, #hal.descriptor_type<storage_buffer>>
+// CHECK: hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c5120) : memref<1001xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
+// CHECK: hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c5120) : memref<1x1001xf32, strided<[1001, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
 // CHECK: hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) : memref<1x1001xf32, #hal.descriptor_type<storage_buffer>>

 // -----

@@ -2714,7 +2714,7 @@ func.func @sub_byte_bufferize_with_offset() {
 // CHECK-LABEL: func.func @sub_byte_bufferize_with_offset()
 // CHECK: %[[C64:.+]] = arith.constant 64 : index
 // CHECK: hal.interface.binding.subspan layout({{.+}}) binding(0)
-// CHECK-SAME: memref<64xi4, strided<[1], offset: 128>
+// CHECK-SAME: memref<64xi4, strided<[1], offset: ?>

 // -----
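Note that these offsets were statically computable before this change: in reduction_ew, for instance, the byte offset %c5120 over f32 (4 bytes per element) works out to 5120 / 4 = 1280 elements, exactly the offset: 1280 the old CHECK lines expected. The pass now deliberately leaves it as ? anyway.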

compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp

Lines changed: 4 additions & 9 deletions
@@ -89,18 +89,13 @@ findOrCreateSubspanBuffer(RewriterBase &rewriter,
   Value byteOffset = subspanOp.getByteOffset();
   MemRefLayoutAttrInterface layoutAttr = {};
   if (byteOffset && !matchPattern(byteOffset, m_Zero())) {
-    OpFoldResult elementOffset = convertByteOffsetToElementOffset(
-        rewriter, subspanOp->getLoc(), subspanOp.getByteOffset(),
-        shapedType.getBoundElementType());
-    std::optional<int64_t> elementOffsetInt =
-        getConstantIntValue(elementOffset);
-    if (!elementOffsetInt) {
-      elementOffsetInt = ShapedType::kDynamic;
-    }
+    // Using buffer resources on AMDGPU will require buffers to be relocated to
+    // offset 0, so any static offset we can compute here might change.
+    // Therefore, always use a ? for the offset field unless it's known to be 0.
     auto tensorType = llvm::cast<RankedTensorType>(shapedType.getBoundType());
     SmallVector<int64_t> strides = getStridesFromShape(tensorType.getShape());
     layoutAttr = StridedLayoutAttr::get(rewriter.getContext(),
-                                        elementOffsetInt.value(), strides);
+                                        ShapedType::kDynamic, strides);
   }
   auto memRefType =
       getMemrefTypeForTensor(shapedType, layoutAttr,
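Only the offset becomes dynamic here: the strides are still derived statically from the bound tensor's shape by getStridesFromShape, so a tensor<1x10xf32> binding still bufferizes to strided<[10, 1], offset: ?>, as the updated bufferization tests above check.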

compiler/src/iree/compiler/Codegen/LLVMCPU/test/convert_to_llvm.mlir

Lines changed: 2 additions & 2 deletions
@@ -56,9 +56,9 @@ func.func @interleave_and_bitcast_lowering() {
   %c3 = arith.constant 3 : index
   %c4096 = arith.constant 4096 : index
   %c8192 = arith.constant 8192 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c4096) flags(ReadOnly) : memref<128xi8, strided<[1], offset: 4096>>
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c4096) flags(ReadOnly) : memref<128xi8, strided<[1], offset: ?>>
   %out_buffer = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c8192) : memref<256x64xi4, strided<[64, 1], offset: 8192>>
-  %2 = vector.load %0[%c0] : memref<128xi8, strided<[1], offset: 4096>>, vector<2xi8>
+  %2 = vector.load %0[%c0] : memref<128xi8, strided<[1], offset: ?>>, vector<2xi8>
   %3 = vector.bitcast %2 : vector<2xi8> to vector<4xi4>
   %4 = vector.insert %3, %cst_0 [3] : vector<4xi4> into vector<4x4xi4>
   %5 = vector.bitcast %4 : vector<4x4xi4> to vector<4x2xi8>

compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir

Lines changed: 14 additions & 9 deletions
@@ -13,15 +13,15 @@ hal.executable @abs_ex_dispatch_0 {
 func.func @abs_ex_dispatch_0() {
   %c0 = arith.constant 0 : index
   %c128 = arith.constant 128 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: 32>>
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: ?>>
   %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<16xi32>
   %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<16xf32>
   %3 = gpu.block_id x
   %4 = gpu.block_dim x
   %5 = gpu.thread_id x
   %6 = arith.muli %3, %4 : index
   %7 = arith.addi %6, %5 : index
-  %9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: 32>>
+  %9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: ?>>
   %10 = memref.load %1[%7] : memref<16xi32>
   %11 = arith.sitofp %10 : i32 to f32
   %12 = arith.addf %9, %11 : f32

@@ -145,15 +145,15 @@ hal.executable @mixed_type {
 func.func @mixed_type() {
   %c0 = arith.constant 0 : index
   %c128 = arith.constant 128 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%c128) : memref<16xf32, strided<[1], offset: 4>>
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%c128) : memref<16xf32, strided<[1], offset: ?>>
   %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%c0) : memref<16xi32>
   %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<16xf32>
   %3 = gpu.block_id x
   %4 = gpu.block_dim x
   %5 = gpu.thread_id x
   %6 = arith.muli %3, %4 : index
   %7 = arith.addi %6, %5 : index
-  %9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: 4>>
+  %9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: ?>>
   %10 = memref.load %1[%7] : memref<16xi32>
   %11 = arith.sitofp %10 : i32 to f32
   %12 = arith.addf %9, %11 : f32

@@ -167,8 +167,13 @@ hal.executable @mixed_type {
 // CHECK-LABEL: llvm.func @mixed_type
 // CHECK-SAME: (%[[ARG0:.+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
 // CHECK-SAME: %{{.*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef})
+// CHECK: %[[BYTES_PER_BIT:.+]] = llvm.mlir.constant(8 : i64) : i64
+// CHECK: %[[BITS_PER_ELEM:.+]] = llvm.mlir.constant(32 : i64) : i64
+// CHECK: %[[BYTE_OFFSET:.+]] = llvm.mlir.constant(128 : index) : i64
+// CHECK: %[[OFFSET_BITS:.+]] = llvm.mul %[[BYTE_OFFSET]], %[[BYTES_PER_BIT]]
+// CHECK: %[[OFFSET_ELEMS:.+]] = llvm.udiv %[[OFFSET_BITS]], %[[BITS_PER_ELEM]]
 // CHECK: nvvm.read.ptx.sreg.tid.x
-// CHECK: llvm.getelementptr %[[ARG0]][4] : (!llvm.ptr) -> !llvm.ptr, f32
+// CHECK: llvm.getelementptr %[[ARG0]][%[[OFFSET_ELEMS]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK: llvm.fadd

 // -----
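The updated CHECK lines above trace the byte-to-element conversion that now happens at lowering time instead of in the type: the byte offset 128 is multiplied by 8 bits per byte and divided by 32 bits per f32 element, giving 128 × 8 / 32 = 32 as the getelementptr index, computed dynamically rather than hard-coded.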
@@ -282,18 +287,18 @@ hal.executable @check_not_readonly {
   %c0 = arith.constant 0 : index
   %c128 = arith.constant 128 : index
   %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<16xi32>
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: 32>>
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: ?>>
   %b11 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) flags(ReadOnly) : memref<16xi32>
-  %b12 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%c128) : memref<16xf32, strided<[1], offset: 32>>
+  %b12 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%c128) : memref<16xf32, strided<[1], offset: ?>>
   %b21 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) flags(ReadOnly) : memref<16xi32>
-  %b22 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: 32>>
+  %b22 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: ?>>
   %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) : memref<16xf32>
   %3 = gpu.block_id x
   %4 = gpu.block_dim x
   %5 = gpu.thread_id x
   %6 = arith.muli %3, %4 : index
   %7 = arith.addi %6, %5 : index
-  %9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: 32>>
+  %9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: ?>>
   %10 = memref.load %1[%7] : memref<16xi32>
   %11 = arith.sitofp %10 : i32 to f32
   %12 = arith.addf %9, %11 : f32
