
Commit 5f96735

lialan authored and Your Name committed
[GPU] Address PR #23365 review comments.
* Change gfx942 → gfx950 in gpu_convert_to_coalesced_dma tests.
* Add in_bounds semantics documentation to CoalescedGatherDMAOp.
* Remove hardware-specific references from op verifier comment.
* Rewrite misleading "ONE level of extract_slice" fallback comment.
* Add inner-dim padding OOB lowering test (64x62xf32 → 64x64xf32).
* Fix missing trailing periods on comments.
1 parent 24d1821 commit 5f96735

5 files changed: +100 −30 lines


compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp

Lines changed: 6 additions & 6 deletions
@@ -370,8 +370,8 @@ static LogicalResult createDMAInForall(scf::ForallOp threadForallOp,
     }
   }
 
-  // Fallback: original behavior without tensor.pad fusion.
-  // Only trace through ONE level of extract_slice (the immediate input).
+  // Fallback: no tensor.pad fusion. The input is an extract_slice from
+  // tiling; trace through it to get the actual source.
   if (!source) {
     if (auto extractSlice = input.getDefiningOp<tensor::ExtractSliceOp>()) {
       source = extractSlice.getSource();
@@ -506,19 +506,19 @@ struct ConvertPadFusionCopyToCoalescedDMA
 
   LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                 PatternRewriter &rewriter) const override {
-    // Only match copies with use_global_load_dma config
+    // Only match copies with use_global_load_dma config.
     auto config = getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(copyOp);
     if (!config) {
      return failure();
    }
 
-    // Check if this is a tensor.pad fusion case
+    // Check if this is a tensor.pad fusion case.
    auto pad = traceToTensorPad(copyOp.getInputs()[0]);
    if (!pad) {
      return failure(); // Not a pad fusion case
    }
 
-    // Check if padding exists (non-zero low/high pad)
+    // Check if padding exists (non-zero low/high pad).
    bool hasPadding = false;
    for (auto [low, high] :
         llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) {
@@ -894,7 +894,7 @@ struct GPUConvertToCoalescedDMAPass final
     bool isPadFusion = false;
     if (auto copyOp = dyn_cast<linalg::CopyOp>(op.getOperation())) {
       if (auto pad = traceToTensorPad(copyOp.getInputs()[0])) {
-        // Check if padding exists (non-zero low/high pad)
+        // Check if padding exists (non-zero low/high pad).
         for (auto [low, high] :
              llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) {
           if (!isConstantIntValue(low, 0) || !isConstantIntValue(high, 0)) {
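
For context, a condensed sketch of the input IR that ConvertPadFusionCopyToCoalescedDMA is meant to match, taken from the @copy_with_tensor_pad_fusion test updated later in this commit (value names and shapes come from that test, not from this file):

// Sketch condensed from the pad-fusion test: a dynamic extract_slice, a
// zero-padding tensor.pad, and a linalg.copy tagged with use_global_load_dma.
%extracted = tensor.extract_slice %source[%off, 0] [%sz, 64] [1, 1]
    : tensor<121x64xf32> to tensor<?x64xf32>
%cst = arith.constant 0.0 : f32
%padded = tensor.pad %extracted low[0, 0] high[%high, 0] {
^bb0(%arg0: index, %arg1: index):
  tensor.yield %cst : f32
} : tensor<?x64xf32> to tensor<4x64xf32>
// The pattern fires on this copy, folds the pad into the DMA, and records the
// padded M dimension via in_bounds = [false, true].
%result = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
    ins(%padded : tensor<4x64xf32>)
    outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>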

compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir

Lines changed: 58 additions & 0 deletions
@@ -1224,3 +1224,61 @@ func.func @gather_dma_non_outermost_oob_check(
   } {mapping = [#gpu.thread<linear_dim_0>]}
   return
 }
+
+// -----
+
+// Test: Inner-dim padding OOB check with <64x62xf32> source padded to <64x64xf32>.
+// Only inner dim (dim 1) has padding: 62 → 64. in_bounds = [true, false].
+// Raw buffer OOB is 1D (linear): reading <4 x f32> at [0, 60] would compute a
+// linear offset within the buffer and wrap to [1, 0], [1, 1] instead of returning 0.
+// Fix: when srcIndices[1] >= 62, replace srcIndices[0] with 64 (past buffer end)
+// so the linearized offset exceeds buffer size → hardware returns 0.
+
+#executable_target_rocm_hsaco_fb_inner_pad = #hal.executable.target<"rocm",
+  "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+  arch = "gfx950", features = "", wgp = <
+    compute = fp32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [64, 64],
+    max_workgroup_sizes = [1024, 1024, 1024],
+    max_thread_count_per_workgroup = 1024,
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+    max_load_instruction_bits = 128, simds_per_wgp = 4,
+    vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
+
+#translation_64_inner_pad = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
+
+// CHECK-LABEL: func.func @gather_dma_inner_dim_oob_64x62
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<64x62xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME:    %[[DST:[a-zA-Z0-9]+]]: memref<64x64xf32, #gpu.address_space<workgroup>>
+func.func @gather_dma_inner_dim_oob_64x62(
+    %source: memref<64x62xf32, #amdgpu.address_space<fat_raw_buffer>>,
+    %dest: memref<64x64xf32, #gpu.address_space<workgroup>>)
+  attributes {
+    hal.executable.target = #executable_target_rocm_hsaco_fb_inner_pad,
+    translation_info = #translation_64_inner_pad} {
+  // CHECK: scf.forall (%[[LANE_ID:[a-zA-Z0-9]+]]) in (64)
+  scf.forall (%arg6) in (64) {
+    // Each lane transfers vector<4xf32> (dma_sizes [128] = 128 bits = 4 x f32).
+    // CHECK: %[[C4:[a-zA-Z0-9_]+]] = arith.constant 4
+    // CHECK: %[[LANE_OFFSET:[a-zA-Z0-9_]+]] = arith.muli %[[LANE_ID]], %[[C4]]
+    //
+    // Transfer 1: linearOffset = 0
+    // CHECK: %[[C0:.+]] = arith.constant 0 : index
+    // CHECK: %[[SRC_LIN0:.+]] = arith.addi %[[C0]], %[[LANE_OFFSET]]
+    // CHECK: %[[SRC_DELIN0:.+]]:2 = affine.delinearize_index %[[SRC_LIN0]] into (64, 64)
+    // CHECK: %[[DST_DELIN0:.+]]:2 = affine.delinearize_index %[[C0]] into (64, 64)
+    //
+    // Bounds check: compare srcIndices[1] >= 62 (source inner dim size).
+    // CHECK: %[[C62:.+]] = arith.constant 62 : index
+    // CHECK: %[[OOB:.+]] = arith.cmpi uge, %[[SRC_DELIN0]]#1, %[[C62]] : index
+    // Replace outermost index with 64 (source dim 0 size) to force hardware OOB.
+    // CHECK: %[[C64_OOB:.+]] = arith.constant 64 : index
+    // CHECK: %[[FIXED_IDX:.+]] = arith.select %[[OOB]], %[[C64_OOB]], %[[SRC_DELIN0]]#0 : index
+    // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[FIXED_IDX]], %[[SRC_DELIN0]]#1], %[[DST]][%[[DST_DELIN0]]#0, %[[DST_DELIN0]]#1] : vector<4xf32>
+    // CHECK-NOT: iree_gpu.coalesced_gather_dma
+    iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) in_bounds [true, false] :
+      memref<64x62xf32, #amdgpu.address_space<fat_raw_buffer>>,
+      memref<64x64xf32, #gpu.address_space<workgroup>>, index
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  return
+}
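
To make the CHECK lines above easier to follow, here is a minimal standalone sketch of the select-based redirect they verify. The function and value names are hypothetical (this is not code from the patch); it only isolates the index arithmetic: when the delinearized column lands in the padded region, the row index is pushed past the 64x62 source so the linearized raw-buffer address is out of bounds and the hardware returns zeros.

// Sketch only; %row and %col stand in for the delinearized source indices.
func.func @oob_row_redirect(%row: index, %col: index) -> index {
  %c62 = arith.constant 62 : index  // source inner-dim size
  %c64 = arith.constant 64 : index  // one past the last source row
  // Does the column fall in the padded region [62, 64)?
  %oob = arith.cmpi uge, %col, %c62 : index
  // If so, point the row past the buffer end; the linearized offset then
  // exceeds the buffer size and the raw-buffer load returns 0.
  %fixed = arith.select %oob, %c64, %row : index
  return %fixed : index
}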

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_convert_to_coalesced_dma.mlir

Lines changed: 19 additions & 19 deletions
@@ -1,6 +1,6 @@
 // RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-convert-to-coalesced-dma,canonicalize))" %s --split-input-file | FileCheck %s
 
-#gpu_target_copy = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_copy = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [32],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
@@ -49,7 +49,7 @@ func.func @copy(%source: tensor<64x512xf32>, %init: tensor<64x512xf32>) -> tenso
 
 // -----
 
-#gpu_target_gather = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_gather = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [64],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
@@ -103,7 +103,7 @@ func.func @gather(%source: tensor<64x512xf32>, %indices: tensor<64xi32>, %init:
 // Negative test: Skip coalesced DMA when innermost dimension < subgroup size. This is to ensure we do not go down
 // the slow path (which is not implemented yet).
 
-#gpu_target_small_inner = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_small_inner = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [64],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
@@ -176,7 +176,7 @@ func.func @copy_not_aligned_to_dma(%source_buffer: memref<320xbf16, #amdgpu.addr
 // - Instead, we should tile rows to 16 (64/4) and keep columns whole (128)
 // This ensures subviews are contiguous in memory.
 
-#gpu_target_contiguous = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_contiguous = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [64],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
@@ -236,7 +236,7 @@ func.func @copy_prefer_contiguous_subview(%source: tensor<64x128xf32>, %init: te
 // When output comes from tensor.empty(), we can use total elements instead of
 // innermost dimension for the size check, enabling coalesced DMA.
 
-#gpu_target_linearize = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_linearize = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [64],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
@@ -296,7 +296,7 @@ func.func @copy_small_innermost_linearized(%source: tensor<128x16xf32>) -> tenso
 // Test: 1D tensor copy distributes warps across the single dimension.
 // This tests the 1D tile size computation logic for flattened copies.
 
-#gpu_target_1d = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_1d = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [64],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
@@ -358,7 +358,7 @@ func.func @copy_1d_tensor(%source: tensor<2048xf32>) -> tensor<2048xf32>
 // 1. Innermost dim (16) < minElementsPerTransfer (64)
 // 2. Output is a function argument, not tensor.empty, so we can't linearize
 
-#gpu_target_no_linearize = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_no_linearize = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [64],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
@@ -395,7 +395,7 @@ func.func @copy_small_innermost_no_linearize(%source: tensor<128x16xf32>, %dest:
 // The copy should be converted to coalesced DMA when the input comes from an
 // extract_slice with contiguous innermost dimensions.
 
-#gpu_target_extract_input = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_extract_input = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [64],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
@@ -451,7 +451,7 @@ func.func @copy_with_extract_slice_input(%large_source: tensor<256x128xf32>) ->
 // When linalg.copy reads from tensor.pad, trace through to the original source
 // and set in_bounds attribute based on padding.
 
-#gpu_target_pad = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_pad = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [64],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
@@ -467,24 +467,24 @@ func.func @copy_with_extract_slice_input(%large_source: tensor<256x128xf32>) ->
 // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<4x64xf32>
 func.func @copy_with_tensor_pad_fusion(%source: tensor<121x64xf32>, %init: tensor<4x64xf32>, %off: index, %sz: index, %high: index) -> tensor<4x64xf32>
   attributes {hal.executable.target = #exec_target_pad, translation_info = #translation_pad} {
-  // Extract a dynamic slice
+  // Extract a dynamic slice.
   %extracted = tensor.extract_slice %source[%off, 0] [%sz, 64] [1, 1]
     : tensor<121x64xf32> to tensor<?x64xf32>
 
-  // Pad to static size (only M dimension has padding)
+  // Pad to static size (only M dimension has padding).
   %cst = arith.constant 0.0 : f32
   %padded = tensor.pad %extracted low[0, 0] high[%high, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
   } : tensor<?x64xf32> to tensor<4x64xf32>
 
-  // Copy from padded tensor
+  // Copy from padded tensor.
   %result = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
     ins(%padded : tensor<4x64xf32>)
     outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
 
-  // Key check: tensor.pad is fused - source is the extract_slice result, not the padded tensor
-  // in_bounds = [false, true] because M dim has dynamic padding, K dim has no padding
+  // Key check: tensor.pad is fused - source is the extract_slice result, not the padded tensor.
+  // in_bounds = [false, true] because M dim has dynamic padding, K dim has no padding.
   // CHECK: %[[EXTRACTED:.+]] = tensor.extract_slice %[[SRC]]
   // CHECK: scf.forall {{.*}} shared_outs(%[[OUTER_INIT:.+]] = %[[INIT]])
   // CHECK: scf.forall (%[[LANE:.+]]) in (64) shared_outs(%[[INNER_INIT:.+]] = %[[OUTER_INIT]])
@@ -504,7 +504,7 @@ func.func @copy_with_tensor_pad_fusion(%source: tensor<121x64xf32>, %init: tenso
 // operates on the full padded buffer shape, not on smaller subviews.
 // This is critical for correct delinearization in the lowering pass.
 
-#gpu_target_pad_multi_warp = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_pad_multi_warp = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [64],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
@@ -520,18 +520,18 @@ func.func @copy_with_tensor_pad_fusion(%source: tensor<121x64xf32>, %init: tenso
 // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<4x64xf32>
 func.func @copy_with_tensor_pad_fusion_multi_warp(%source: tensor<121x64xf32>, %init: tensor<4x64xf32>, %off: index, %sz: index, %high: index) -> tensor<4x64xf32>
   attributes {hal.executable.target = #exec_target_pad_multi_warp, translation_info = #translation_pad_multi_warp} {
-  // Extract a dynamic slice
+  // Extract a dynamic slice.
   %extracted = tensor.extract_slice %source[%off, 0] [%sz, 64] [1, 1]
     : tensor<121x64xf32> to tensor<?x64xf32>
 
-  // Pad to static size (only M dimension has padding)
+  // Pad to static size (only M dimension has padding).
   %cst = arith.constant 0.0 : f32
   %padded = tensor.pad %extracted low[0, 0] high[%high, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
   } : tensor<?x64xf32> to tensor<4x64xf32>
 
-  // Copy from padded tensor with 4 warps (256/64=4)
+  // Copy from padded tensor with 4 warps (256/64=4).
   %result = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
     ins(%padded : tensor<4x64xf32>)
     outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
@@ -570,7 +570,7 @@ func.func @copy_with_tensor_pad_fusion_multi_warp(%source: tensor<121x64xf32>, %
 // If a DWORD is partially out-of-bounds, the entire DWORD returns zero,
 // causing incorrect results. We bail out to avoid the slow path.
 
-#gpu_target_pad_unaligned = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
+#gpu_target_pad_unaligned = #iree_gpu.target<arch = "gfx950", features = "", wgp = <
   compute = fp32, storage = b32, subgroup = shuffle,
   max_load_instruction_bits = 128, subgroup_size_choices = [64],
   max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp

Lines changed: 4 additions & 5 deletions
@@ -218,7 +218,7 @@ void CoalescedGatherDMAOp::getEffects(
   Value source = getSource();
   Value init = getInit();
 
-  // The operation reads from the source
+  // The operation reads from the source.
   if (isa<MemRefType>(source.getType())) {
     effects.emplace_back(MemoryEffects::Read::get(),
                          &getOperation()->getOpOperand(sourceOperandIdx),
@@ -235,7 +235,7 @@
                          SideEffects::DefaultResource::get());
   } else if (isa<RankedTensorType>(init.getType()) &&
              getOperation()->getNumResults() == 0) {
-    // Tensor combiner case: declare write effect to prevent DCE
+    // Tensor combiner case: declare write effect to prevent DCE.
     effects.emplace_back(MemoryEffects::Write::get(),
                          &getOperation()->getOpOperand(initOperandIdx),
                          SideEffects::DefaultResource::get());
@@ -339,9 +339,8 @@ LogicalResult CoalescedGatherDMAOp::verify() {
     }
 
     // If in_bounds is present and this dimension allows OOB (in_bounds=false),
-    // skip the size matching check. For non-outermost dimensions, the lowering
-    // adds explicit bounds checks since raw buffer OOB only provides 1D
-    // (linear) clamping, not per-dimension clamping.
+    // skip the size matching check. The source may be smaller than init along
+    // this dimension, and reads beyond the source extent return zero.
    if (inBoundsAttr) {
      auto inBoundsArray = *inBoundsAttr;
      if (dim < inBoundsArray.size()) {

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td

Lines changed: 13 additions & 0 deletions
@@ -300,6 +300,19 @@ def IREEGPU_CoalescedGatherDMAOp : Op<IREEGPU_Dialect, "coalesced_gather_dma", [
     * `lane`: The lane that specifies the coalescing store's offset within the
       workgroup/shared memory.
 
+    ## In-Bounds Attribute
+
+    The optional `in_bounds` attribute is a boolean array with one entry per
+    dimension of `init`. When not present, all dimensions are treated as
+    in-bounds (source and init must have matching sizes for non-indexed dims).
+
+    When present, `in_bounds[i] = false` indicates that the source may be
+    smaller than init along dimension `i`. Reads beyond the source extent
+    return zero (padding semantics). This enables fusion of `tensor.pad`
+    with zero padding into the DMA operation.
+
+    `in_bounds[i] = true` means the source and init sizes match along that
+    dimension, and no padding is needed.
 
     ## Example of a single subgroup using coalesced_gather_dma in copy mode
     for transferring tensor<4x128xf32>, with an intended DMA width of 128 bits