Commit 6dc81fd

[GPU] Do not do c promotion for unaligned (I)GEMMs (iree-org#21823)
We do not need to do C promotion, and the performance with and without it seems similar. Here are some GEMM shapes and a conv shape I checked.

### GEMMs

| Shape (MxNxK) | No C promotion (this PR) (us) | C promotion (us) |
|---------------|-------------------------------|------------------|
| 1023x512x512  | 8                             | 10               |
| 1023x512x5121 | 302                           | 303              |
| 1023x256x5121 | 305                           | 303              |

### Convs

| Command | No C promotion (this PR) (us) | C promotion (us) |
|---------|-------------------------------|------------------|
| convbfp16 -n 1 -c 77 -H 7 -W 77 -k 77 -y 1 -x 1 -p 3 -q 3 -u 1 -v 1 -l 1 -j 1 -m conv -g 1 -F 1 -t 1 --in_layout NHWC --out_layout NHWC --fil_layout NHWC | 3.4 | 3.5 |

I did not observe any difference above noise thresholds between the two configurations, but skipping C promotion has the advantage that, under the right configuration, we can use larger tile sizes.

Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 394ddbc commit 6dc81fd
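In practice, "C promotion" roughly means including the accumulator (the matmul's operand 2) in `promote_operands`, so it gets a shared-memory staging buffer alongside the LHS and RHS. The sketch below reuses the tile sizes from the updated `unaligned_to_intrinsic_batched_matmul` test to illustrate the attribute-level change; the values are illustrative, not prescriptive.

```mlir
// Lowering config for an unaligned batched matmul after this change.
// Only the LHS (0) and RHS (1) are promoted; previously the unaligned
// path used promote_operands = [0, 1, 2], i.e. it also promoted C.
#iree_gpu.lowering_config<{
  mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
  padding = [1, 16, 16, 4],
  promote_operands = [0, 1],
  reduction = [0, 0, 0, 1],
  subgroup = [0, 1, 1, 0],
  workgroup = [1, 16, 16, 0]
}>
```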

File tree

7 files changed, +174 -42 lines changed

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 16 additions & 26 deletions
@@ -582,7 +582,6 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
      llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();

  bool mustBeAligned = true;
-  bool doCPromotion = false;
  std::optional<GPUMMASchedule> schedule = getMmaScheduleFromProblemAndTarget(
      target, problem, transposedLhs, transposedRhs, isGemm,
      /*mustBeAligned*/ true,
@@ -595,10 +594,9 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
  if (!schedule && canSupportUnaligned) {
    LDBG() << "Attempting to deduce unaligned TileAndFuse MMA schedulee";
    mustBeAligned = false;
-    doCPromotion = true;
    schedule = getMmaScheduleFromProblemAndTarget(
        target, problem, transposedLhs, transposedRhs, isGemm, mustBeAligned,
-        doCPromotion, scaled);
+        /*doCPromotion=*/false, scaled);
  }

  if (!schedule) {
@@ -674,27 +672,19 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
  attrs.emplace_back(StringAttr::get(context, "subgroup"),
                     b.getI64ArrayAttr(subgroupTileSizes));
  attrs.emplace_back(StringAttr::get(context, "mma_kind"), kind);
-  if (mustBeAligned) {
-    Attribute useGlobalDma = IREE::GPU::UseGlobalLoadDMAAttr::get(context);
-    SmallVector<Attribute> promotionArray = {useGlobalDma, useGlobalDma};
-    SmallVector<int64_t> promotionList = {0, 1};
-    if (scaled) {
-      promotionArray.append({useGlobalDma, useGlobalDma});
-      promotionList.append({2, 3});
-    }
-    ArrayRef<Attribute> promotionTypes =
-        useDirectLoad ? ArrayRef<Attribute>(promotionArray)
-                      : ArrayRef<Attribute>{};
-    GPU::appendPromotedOperandsList(context, attrs, promotionList,
-                                    promotionTypes);
-  } else {
-    // TODO (nirvedhmeshram, Max191, jerryyin) : Add support so that unaligned
-    // shapes do not require c promotion.
-    SmallVector<int64_t> promotionList = {0, 1, 2};
-    if (scaled) {
-      promotionList.append({3, 4});
-    }
-    GPU::appendPromotedOperandsList(context, attrs, promotionList);
+  Attribute useGlobalDma = IREE::GPU::UseGlobalLoadDMAAttr::get(context);
+  SmallVector<Attribute> promotionArray = {useGlobalDma, useGlobalDma};
+  SmallVector<int64_t> promotionList = {0, 1};
+  if (scaled) {
+    promotionArray.append({useGlobalDma, useGlobalDma});
+    promotionList.append({2, 3});
+  }
+  ArrayRef<Attribute> promotionTypes = useDirectLoad
+                                           ? ArrayRef<Attribute>(promotionArray)
+                                           : ArrayRef<Attribute>{};
+  GPU::appendPromotedOperandsList(context, attrs, promotionList,
+                                  promotionTypes);
+  if (!mustBeAligned) {
    SmallVector<int64_t> paddingTileSizes = workgroupTileSizes;

    // Initialize inner and outer padding sizes from reductionTileSizes.
@@ -712,8 +702,8 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
    }
    paddingTileSizes[innerKDim] *= kPackFactor;

-    // Create `padding_conv` attribute when padding convolutions before IGEMM is
-    // possible, otherwise fallback to pad IGEMM.
+    // Create `padding_conv` attribute when padding convolutions before IGEMM
+    // is possible, otherwise fallback to pad IGEMM.
    if (auto attr =
            getPaddingConvSizes(b, bounds, paddingTileSizes, workgroupTileSizes,
                                reductionTileSizes, convToIgemmInfo)) {
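
As a concrete reference, the unaligned IGEMM path now emits configs of the following shape, with the accumulator absent from `promote_operands` and the unaligned dimensions handled by `padding`. This is the same lowering config exercised by the new pipeline test added below, reproduced here only as a sketch.

```mlir
// Unaligned IGEMM config from the new conv_nhwc_unaligned_stride_2_nocpromo
// test: LHS/RHS are promoted, the accumulator is not, and padding covers
// the unaligned tile boundaries.
#config = #iree_gpu.lowering_config<{
  padding = [2, 1, 32, 16, 16],
  workgroup = [2, 1, 32, 16, 0],
  reduction = [0, 0, 0, 0, 1],
  subgroup = [1, 1, 1, 1, 0],
  mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
  promote_operands = [0, 1]
}>
```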

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir

Lines changed: 11 additions & 11 deletions
@@ -99,13 +99,13 @@ func.func @nhwc_conv_unaligned_mfma() {
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>

// GFX942-SAME: padding = [2, 1, 32, 64, 32]
-// GFX942-SAME: promote_operands = [0, 1, 2]
+// GFX942-SAME: promote_operands = [0, 1]
// GFX942-SAME: reduction = [0, 0, 0, 0, 8]
// GFX942-SAME: subgroup = [2, 1, 1, 1, 0]
// GFX942-SAME: workgroup = [2, 1, 32, 64, 0]

// MI300X-SAME: padding = [2, 1, 32, 32, 32]
-// MI300X-SAME: promote_operands = [0, 1, 2]
+// MI300X-SAME: promote_operands = [0, 1]
// MI300X-SAME: reduction = [0, 0, 0, 0, 8]
// MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
// MI300X-SAME: workgroup = [2, 1, 32, 32, 0]
@@ -138,13 +138,13 @@ func.func @nchw_conv_unaligned_mfma() {
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>

// GFX942-SAME: padding = [1, 64, 4, 32, 32]
-// GFX942-SAME: promote_operands = [0, 1, 2]
+// GFX942-SAME: promote_operands = [0, 1]
// GFX942-SAME: reduction = [0, 0, 0, 0, 8]
// GFX942-SAME: subgroup = [1, 2, 2, 1, 0]
// GFX942-SAME: workgroup = [1, 64, 4, 32, 0]

// MI300X-SAME: padding = [1, 32, 2, 32, 32]
-// MI300X-SAME: promote_operands = [0, 1, 2]
+// MI300X-SAME: promote_operands = [0, 1]
// MI300X-SAME: reduction = [0, 0, 0, 0, 8]
// MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
// MI300X-SAME: workgroup = [1, 32, 2, 32, 0]
@@ -177,13 +177,13 @@ func.func @conv_nhwc_fhwc_unaligned_channel(%arg0: tensor<16x26x19x287xf16>, %ar
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>

// GFX942-SAME: padding = [1, 4, 32, 32, 32]
-// GFX942-SAME: promote_operands = [0, 1, 2]
+// GFX942-SAME: promote_operands = [0, 1]
// GFX942-SAME: reduction = [0, 0, 0, 0, 2]
// GFX942-SAME: subgroup = [1, 2, 1, 1, 0]
// GFX942-SAME: workgroup = [1, 4, 32, 32, 0]

// MI300X-SAME: padding = [1, 2, 32, 32, 32]
-// MI300X-SAME: promote_operands = [0, 1, 2]
+// MI300X-SAME: promote_operands = [0, 1]
// MI300X-SAME: reduction = [0, 0, 0, 0, 2]
// MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
// MI300X-SAME: workgroup = [1, 2, 32, 32, 0]
@@ -215,7 +215,7 @@ func.func @conv_chwn_chwf_unaligned_batch(%arg0: tensor<16x193x129x40xbf16>, %ar
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16>
// CHECK-SAME: padding = [16, 1, 1, 16, 64]
-// CHECK-SAME: promote_operands = [0, 1, 2]
+// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 0, 0, 4]
// CHECK-SAME: subgroup = [1, 1, 1, 1, 0]
// CHECK-SAME: workgroup = [16, 1, 1, 16, 0]
@@ -247,13 +247,13 @@ func.func @group_conv_hwgc_gfhwc_unaligned(%arg0: tensor<61x93x16x56xbf16>, %arg
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// GFX942-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16>
// GFX942-SAME: padding = [1, 32, 1, 64, 64]
-// GFX942-SAME: promote_operands = [0, 1, 2]
+// GFX942-SAME: promote_operands = [0, 1]
// GFX942-SAME: reduction = [0, 0, 0, 0, 4]
// GFX942-SAME: subgroup = [1, 1, 0, 1, 0]
// GFX942-SAME: workgroup = [1, 32, 1, 64, 0]

// MI300X-SAME: padding = [1, 32, 1, 64, 64]
-// MI300X-SAME: promote_operands = [0, 1, 2]
+// MI300X-SAME: promote_operands = [0, 1]
// MI300X-SAME: reduction = [0, 0, 0, 0, 4]
// MI300X-SAME: subgroup = [1, 1, 0, 1, 0]
// MI300X-SAME: workgroup = [1, 32, 1, 64, 0]
@@ -287,13 +287,13 @@ module {
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// GFX942-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16>
// GFX942-SAME: padding = [2, 2, 32, 64, 64]
-// GFX942-SAME: promote_operands = [0, 1, 2]
+// GFX942-SAME: promote_operands = [0, 1]
// GFX942-SAME: reduction = [0, 0, 0, 0, 4]
// GFX942-SAME: subgroup = [2, 1, 1, 2, 0]
// GFX942-SAME: workgroup = [2, 2, 32, 64, 0]

// MI300X-SAME: padding = [1, 2, 32, 32, 64]
-// MI300X-SAME: promote_operands = [0, 1, 2]
+// MI300X-SAME: promote_operands = [0, 1]
// MI300X-SAME: reduction = [0, 0, 0, 0, 4]
// MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
// MI300X-SAME: workgroup = [1, 2, 32, 32, 0]

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir

Lines changed: 3 additions & 3 deletions
@@ -330,7 +330,7 @@ func.func @unaligned_to_intrinsic_batched_matmul(%lhs : tensor<12x2x577xf32>, %r
// LATE-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
// LATE: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
// LATE-SAME: padding = [1, 16, 16, 4]
-// LATE-SAME: promote_operands = [0, 1, 2]
+// LATE-SAME: promote_operands = [0, 1]
// LATE-SAME: reduction = [0, 0, 0, 1]
// LATE-SAME: subgroup = [0, 1, 1, 0]
// LATE-SAME: workgroup = [1, 16, 16, 0]
@@ -357,7 +357,7 @@ func.func @unaligned_matmul_with_two_reduce_dim(%arg0: tensor<196x9x4xf32>, %arg
// LATE: linalg.generic
// LATE-SAME: {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
// LATE-SAME: padding = [16, 1, 16, 4]
-// LATE-SAME: promote_operands = [0, 1, 2]
+// LATE-SAME: promote_operands = [0, 1]
// LATE-SAME: reduction = [0, 1, 0, 1],
// LATE-SAME: subgroup = [1, 0, 1, 0],
// LATE-SAME: workgroup = [16, 0, 16, 0]}
@@ -437,7 +437,7 @@ func.func @unaligned_to_intrinsic_batched_matmul_tiling_check(%lhs : tensor<12x5
// LATE-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
// LATE: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
// LATE-SAME: padding = [1, 16, 64, 4]
-// LATE-SAME: promote_operands = [0, 1, 2]
+// LATE-SAME: promote_operands = [0, 1]
// LATE-SAME: reduction = [0, 0, 0, 1]
// LATE-SAME: subgroup = [0, 1, 2, 0]
// LATE-SAME: workgroup = [1, 16, 64, 0]

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ func.func @small_scaled_matmul(
// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = true, use_igemm_convolution = false>
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>
-// CHECK-SAME: promote_operands = [0, 1, 2, 3, 4]
+// CHECK-SAME: promote_operands = [0, 1, 2, 3]
// CHECK-SAME: reduction = [0, 0, 1, 1]
// CHECK-SAME: subgroup = [1, 1, 0, 0]
// CHECK-SAME: workgroup = [16, 16, 0, 0]

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir

Lines changed: 74 additions & 0 deletions
@@ -163,6 +163,80 @@ hal.executable private @main {

// -----

+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+#translation = #iree_codegen.translation_info<pipeline =
+  LLVMGPUTileAndFuse
+  workgroup_size = [256, 1, 1]
+  subgroup_size = 64,
+  {
+    gpu_pipeline_options = #iree_gpu.pipeline_options<
+      prefetch_shared_memory = false,
+      no_reduce_shared_memory_bank_conflicts = false,
+      use_igemm_convolution = true>
+  }>
+#config = #iree_gpu.lowering_config<{
+  padding = [2, 1, 32, 16, 16],
+  workgroup = [2, 1, 32, 16, 0],
+  reduction = [0, 0, 0, 0, 1],
+  subgroup = [1, 1, 1, 1, 0],
+  mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+  promote_operands = [0, 1]
+}>
+hal.executable private @main {
+  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @conv_dispatch_0_conv_2d_nhwc_hwcf_2x17x17x1281x3x3x1281_f16xf16xf32 ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
+      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @conv_nhwc_unaligned_stride_2_nocpromo() attributes {translation_info = #translation} {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x35x35x1281xf16>> %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<3x3x1281x1281xf16>>
+        %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x17x17x1281xf32>>
+        %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 35, 35, 1281], strides = [1, 1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x35x35x1281xf16>> -> tensor<2x35x35x1281xf16>
+        %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 1281, 1281], strides = [1, 1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<3x3x1281x1281xf16>> -> tensor<3x3x1281x1281xf16>
+        %5 = tensor.empty() : tensor<2x17x17x1281xf32>
+        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x17x17x1281xf32>) -> tensor<2x17x17x1281xf32>
+        %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, lowering_config = #config, strides = dense<2> : tensor<2xi64>} ins(%3, %4 : tensor<2x35x35x1281xf16>, tensor<3x3x1281x1281xf16>) outs(%6 : tensor<2x17x17x1281xf32>) -> tensor<2x17x17x1281xf32>
+        iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 17, 17, 1281], strides = [1, 1, 1, 1] : tensor<2x17x17x1281xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x17x17x1281xf32>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func @conv_nhwc_unaligned
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C721:.+]] = arith.constant 721 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-NOT: memref.alloc() {{.*}}xf32
+// CHECK-DAG: memref.alloc() : memref<16x20xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: memref.alloc() : memref<2x1x32x20xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
+// CHECK-DAG: %[[ASSUMED_B0:.+]] = memref.assume_alignment %[[B0]], 64
+// CHECK-DAG: %[[BUF0:.+]] = amdgpu.fat_raw_buffer_cast %[[ASSUMED_B0]]
+// CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
+// CHECK-DAG: %[[ASSUMED_B1:.+]] = memref.assume_alignment %[[B1]], 64
+// CHECK-DAG: %[[BUF1:.+]] = amdgpu.fat_raw_buffer_cast %[[ASSUMED_B1]]
+// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
+// CHECK-DAG: %[[ASSUMED_B2:.+]] = memref.assume_alignment %[[B2]], 64
+// CHECK-DAG: %[[BUF2:.+]] = amdgpu.fat_raw_buffer_cast %[[ASSUMED_B2]]
+// CHECK: scf.forall ({{.*}}) in (17, 1, 81) {
+// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C721]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
+// CHECK: gpu.barrier
+// CHECK-DAG: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<4xf16>
+// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<4xf16>
+// CHECK-COUNT-1: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
+// CHECK-NOT: scf.for
+// CHECK: } {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
+
+// -----
+
#pipeline_layout = #hal.pipeline.layout<bindings = [
  #hal.pipeline.binding<storage_buffer, "ReadOnly">,
  #hal.pipeline.binding<storage_buffer, "ReadOnly">,