Skip to content

Commit 23ddd24

Browse files
sebvincepstarkcdpr
authored and committed
[AMDGPU] Cache_swizzle stride for fat raw buffer loads should be in bytes (iree-org#22314)
Use stride in bytes for L1 Cache_swizzle as described in CDNA3/4 doc. In the case of #iree_gpu.promote_with_cache_swizzle, we set the stride to 0 if it is not a multiple of 8 bits.
1 parent cbfb33a commit 23ddd24

File tree

5 files changed

+130
-23
lines changed

5 files changed

+130
-23
lines changed

compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_bf16.mlir

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ util.func private @pingpong_large_bf16(%lhs_base: !bf16_in_ty, %rhs_base: !bf16_
4343
%rhs_shared_base = memref.alloc() : !bf16_flat_shared
4444

4545
%dim = tensor.dim %lhs_base, %c1 : !bf16_in_ty
46-
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !bf16_in_ty
47-
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !bf16_in_ty
46+
%dim_bytes = arith.muli %dim, %c2 overflow<nsw, nuw>: index
47+
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim_bytes) : !bf16_in_ty
48+
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim_bytes) : !bf16_in_ty
4849

4950
%lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
5051
%rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
@@ -266,8 +267,9 @@ util.func private @pingpong_medium_bf16_expanded(%lhs_base: !mexp_in_ty_bf16, %r
266267
%rhs_shared_base = memref.alloc() : !flat_shared_bf16
267268

268269
%dim = tensor.dim %rhs_base, %c1 : !in_ty_bf16
269-
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !mexp_in_ty_bf16
270-
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty_bf16
270+
%dim_bytes = arith.muli %dim, %c2 overflow<nsw, nuw>: index
271+
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim_bytes) : !mexp_in_ty_bf16
272+
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim_bytes) : !in_ty_bf16
271273

272274
%lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !mflat_shared_bf16
273275
%rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared_bf16
@@ -453,8 +455,9 @@ util.func private @pingpong_large_bf16_expanded(%lhs_base: !bf16_exp_in_ty, %rhs
453455
%rhs_shared_base = memref.alloc() : !bf16_flat_shared
454456

455457
%dim = tensor.dim %rhs_base, %c1 : !bf16_in_ty
456-
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !bf16_exp_in_ty
457-
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !bf16_in_ty
458+
%dim_bytes = arith.muli %dim, %c2 overflow<nsw, nuw>: index
459+
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim_bytes) : !bf16_exp_in_ty
460+
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim_bytes) : !bf16_in_ty
458461

459462
%lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
460463
%rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared

compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f16.mlir

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ util.func private @pingpong_large_f16(%lhs_base: !in_ty, %rhs_base: !in_ty, %unu
3535
%rhs_shared_base = memref.alloc() : !flat_shared
3636

3737
%dim = tensor.dim %lhs_base, %c1 : !in_ty
38-
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !in_ty
39-
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty
38+
%dim_bytes = arith.muli %dim, %c2 overflow<nsw, nuw>: index
39+
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim_bytes) : !in_ty
40+
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim_bytes) : !in_ty
4041

4142
%lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
4243
%rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
@@ -256,8 +257,9 @@ util.func private @pingpong_medium_f16_expanded(%lhs_base: !mexp_in_ty, %rhs_bas
256257
%rhs_shared_base = memref.alloc() : !flat_shared
257258

258259
%dim = tensor.dim %rhs_base, %c1 : !in_ty
259-
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !mexp_in_ty
260-
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty
260+
%dim_bytes = arith.muli %dim, %c2 overflow<nsw, nuw>: index
261+
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim_bytes) : !mexp_in_ty
262+
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim_bytes) : !in_ty
261263

262264
%lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !mflat_shared
263265
%rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
@@ -443,8 +445,9 @@ util.func private @pingpong_large_f16_expanded(%lhs_base: !exp_in_ty, %rhs_base:
443445
%rhs_shared_base = memref.alloc() : !flat_shared
444446

445447
%dim = tensor.dim %rhs_base, %c1 : !in_ty
446-
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !exp_in_ty
447-
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty
448+
%dim_bytes = arith.muli %dim, %c2 overflow<nsw, nuw>: index
449+
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim_bytes) : !exp_in_ty
450+
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim_bytes) : !in_ty
448451

449452
%lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
450453
%rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-promote-matmul-operands))" | FileCheck %s
1+
// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-promote-matmul-operands),canonicalize)" | FileCheck %s
22

33
#lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>
44

@@ -214,8 +214,88 @@ func.func @promote_with_cache_swizzle(%a: tensor<2x34x34x128xf32>, %b: tensor<2x
214214
// CHECK-LABEL: func.func @promote_with_cache_swizzle
215215
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<2x34x34x128xf32>
216216
// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<2x8x256xf32>
217-
// CHECK-DAG: %[[SWIZZLE_A:.+]] = iree_gpu.buffer_resource_cast %[[A]] cacheSwizzleStride(%c128)
218-
// CHECK-DAG: %[[SWIZZLE_B:.+]] = iree_gpu.buffer_resource_cast %[[B]] cacheSwizzleStride(%c256)
217+
// CHECK-DAG: %[[SWIZZLE_A:.+]] = iree_gpu.buffer_resource_cast %[[A]] cacheSwizzleStride(%c512)
218+
// CHECK-DAG: %[[SWIZZLE_B:.+]] = iree_gpu.buffer_resource_cast %[[B]] cacheSwizzleStride(%c1024)
219+
// CHECK: %[[PA:.+]] = iree_linalg_ext.im2col
220+
// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config
221+
// CHECK-SAME: ins(%[[SWIZZLE_A]]
222+
// CHECK: %[[PB:.+]] = linalg.copy
223+
// CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma
224+
// CHECK-SAME: ins(%[[SWIZZLE_B]]
225+
// CHECK: linalg.batch_matmul {{.*}} ins(%[[PA]], %[[PB]]
226+
227+
228+
// -----
229+
230+
#lowering_config = #iree_gpu.lowering_config<{
231+
promote_operands = [0, 1],
232+
promotion_types = [
233+
#iree_gpu.promote_with_cache_swizzle<#iree_gpu.derived_thread_config>,
234+
#iree_gpu.promote_with_cache_swizzle<#iree_gpu.use_global_load_dma>]}>
235+
236+
func.func @promote_with_cache_swizzle_f4(%a: tensor<2x34x34x128xf4E2M1FN>, %b: tensor<2x8x256xf4E2M1FN>) -> tensor<2x128x256xf32> {
237+
%cst = arith.constant 0.000000e+00 : f32
238+
%empty = tensor.empty() : tensor<2x128x256xf32>
239+
%im2col_empty = tensor.empty() : tensor<2x128x8xf4E2M1FN>
240+
241+
%im2col = iree_linalg_ext.im2col
242+
strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
243+
m_offset = [0] * [1] k_offset = [0] * [1]
244+
batch_pos = [0] m_pos = [2, 3] k_pos = [1]
245+
input_k_perm = [0, 1, 2]
246+
ins(%a : tensor<2x34x34x128xf4E2M1FN>)
247+
outs(%im2col_empty : tensor<2x128x8xf4E2M1FN>) -> tensor<2x128x8xf4E2M1FN>
248+
249+
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x128x256xf32>) -> tensor<2x128x256xf32>
250+
%mm = linalg.batch_matmul {lowering_config = #lowering_config}
251+
ins(%im2col, %b : tensor<2x128x8xf4E2M1FN>, tensor<2x8x256xf4E2M1FN>) outs(%fill : tensor<2x128x256xf32>) -> tensor<2x128x256xf32>
252+
return %mm : tensor<2x128x256xf32>
253+
}
254+
255+
// CHECK-LABEL: func.func @promote_with_cache_swizzle_f4
256+
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<2x34x34x128xf4E2M1FN>
257+
// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<2x8x256xf4E2M1FN>
258+
// CHECK-DAG: %[[SWIZZLE_A:.+]] = iree_gpu.buffer_resource_cast %[[A]] cacheSwizzleStride(%c64)
259+
// CHECK-DAG: %[[SWIZZLE_B:.+]] = iree_gpu.buffer_resource_cast %[[B]] cacheSwizzleStride(%c128)
260+
// CHECK: %[[PA:.+]] = iree_linalg_ext.im2col
261+
// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config
262+
// CHECK-SAME: ins(%[[SWIZZLE_A]]
263+
// CHECK: %[[PB:.+]] = linalg.copy
264+
// CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma
265+
// CHECK-SAME: ins(%[[SWIZZLE_B]]
266+
// CHECK: linalg.batch_matmul {{.*}} ins(%[[PA]], %[[PB]]
267+
268+
// -----
269+
#lowering_config = #iree_gpu.lowering_config<{
270+
promote_operands = [0, 1],
271+
promotion_types = [
272+
#iree_gpu.promote_with_cache_swizzle<#iree_gpu.derived_thread_config>,
273+
#iree_gpu.promote_with_cache_swizzle<#iree_gpu.use_global_load_dma>]}>
274+
275+
func.func @promote_with_cache_swizzle_f4_no_stride(%a: tensor<2x34x34x129xf4E2M1FN>, %b: tensor<2x8x256xf4E2M1FN>) -> tensor<2x129x256xf32> {
276+
%cst = arith.constant 0.000000e+00 : f32
277+
%empty = tensor.empty() : tensor<2x129x256xf32>
278+
%im2col_empty = tensor.empty() : tensor<2x129x8xf4E2M1FN>
279+
280+
%im2col = iree_linalg_ext.im2col
281+
strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
282+
m_offset = [0] * [1] k_offset = [0] * [1]
283+
batch_pos = [0] m_pos = [2, 3] k_pos = [1]
284+
input_k_perm = [0, 1, 2]
285+
ins(%a : tensor<2x34x34x129xf4E2M1FN>)
286+
outs(%im2col_empty : tensor<2x129x8xf4E2M1FN>) -> tensor<2x129x8xf4E2M1FN>
287+
288+
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x129x256xf32>) -> tensor<2x129x256xf32>
289+
%mm = linalg.batch_matmul {lowering_config = #lowering_config}
290+
ins(%im2col, %b : tensor<2x129x8xf4E2M1FN>, tensor<2x8x256xf4E2M1FN>) outs(%fill : tensor<2x129x256xf32>) -> tensor<2x129x256xf32>
291+
return %mm : tensor<2x129x256xf32>
292+
}
293+
294+
// CHECK-LABEL: func.func @promote_with_cache_swizzle_f4_no_stride
295+
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<2x34x34x129xf4E2M1FN>
296+
// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<2x8x256xf4E2M1FN>
297+
// CHECK-DAG: %[[SWIZZLE_A:.+]] = iree_gpu.buffer_resource_cast %[[A]] cacheSwizzleStride(%c0)
298+
// CHECK-DAG: %[[SWIZZLE_B:.+]] = iree_gpu.buffer_resource_cast %[[B]] cacheSwizzleStride(%c128)
219299
// CHECK: %[[PA:.+]] = iree_linalg_ext.im2col
220300
// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config
221301
// CHECK-SAME: ins(%[[SWIZZLE_A]]

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,14 @@ def IREEGPU_PromoteWithCacheSwizzle :
111111
swizzle inserted if possible. For example,
112112

113113
```
114-
%0 = tensor_ext.dispatch.tensor.load : tensor<?x8192>
114+
%0 = tensor_ext.dispatch.tensor.load : tensor<?x4096xf16>
115115
%1 = linalg.matmul ins(%0, ...)
116116
```
117117

118118
Becomes with `#iree_gpu.promote_with_cache_swizzle<#iree_gpu.derived_thread_config>`
119119

120120
```
121-
%0 = tensor_ext.dispatch.tensor.load : tensor<?x8192>
121+
%0 = tensor_ext.dispatch.tensor.load : tensor<?x4096xf16>
122122
%1 = iree_gpu.buffer_resource_cast cache_swizzle(8192)
123123
%2 = linalg.copy lowering_config = #iree_gpu.derived_thread_config
124124
%3 = linalg.matmul ins(%2, ...)

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
1212
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
1313
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
14+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1415
#include "mlir/Dialect/Arith/Utils/Utils.h"
1516
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1617
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -111,14 +112,34 @@ Value cacheSwizzlePromotionImpl(OpBuilder &builder, OpOperand &operand,
111112
}
112113

113114
Location loc = promotedValue.getLoc();
114-
// Use the size of the inner most dimension as the cache swizzle value.
115-
// This is a very rudimentary choice, but functions well enough as a
115+
// Use the size in bytes of the inner most dimension as the cache swizzle
116+
// value. This is a very rudimentary choice, but functions well enough as a
116117
// default.
117-
Value cacheSwizzleVal = getValueOrCreateConstantIndexOp(
118+
AffineExpr s0, s1;
119+
bindSymbols(builder.getContext(), s0, s1);
120+
Value dtype =
121+
arith::ConstantIndexOp::create(
122+
builder, loc, tensorType.getElementType().getIntOrFloatBitWidth())
123+
->getResult(0);
124+
125+
OpFoldResult dim = tensor::getMixedSize(builder, loc, bufferCastValue,
126+
tensorType.getRank() - 1);
127+
Value zero =
128+
getValueOrCreateConstantIntOp(builder, loc, builder.getIndexAttr(0));
129+
Value strideBytes = getValueOrCreateConstantIndexOp(
118130
builder, loc,
119-
tensor::getMixedSize(builder, loc, bufferCastValue,
120-
tensorType.getRank() - 1));
121-
131+
affine::makeComposedFoldedAffineApply(builder, loc, (s0 * s1).ceilDiv(8),
132+
{dim, dtype}));
133+
Value strideBitsMod8 = getValueOrCreateConstantIntOp(
134+
builder, loc,
135+
affine::makeComposedFoldedAffineApply(builder, loc, (s0 * s1) % 8,
136+
{dim, dtype}));
137+
// If the stride in bits is not a multiple of 8, set the value to 0. This will
138+
// be ignored by cacheSwizzleStride.
139+
Value cmp = arith::CmpIOp::create(
140+
builder, loc, mlir::arith::CmpIPredicate::eq, strideBitsMod8, zero);
141+
Value cacheSwizzleVal =
142+
arith::SelectOp::create(builder, loc, cmp, strideBytes, zero).getResult();
122143
// Insert the resource cast optimistically. If the input is not castable
123144
// (e.g. another producer) later patterns will drop it anyway as it is treated
124145
// like a hint.

0 commit comments

Comments
 (0)