Skip to content

Commit 753cd7f

Browse files
committed
Address comments
1 parent 64c5041 commit 753cd7f

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp

Lines changed: 14 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -408,10 +408,11 @@ struct LowerCoalescedGatherDMAPattern final
408408
// Fix: when any non-outermost source index exceeds its dimension,
409409
// replace the outermost index with sourceShape[0] to force the
410410
// linearized offset past the buffer end → hardware returns 0.
411-
if (inBoundsAttr) {
412-
auto sourceType = cast<MemRefType>(source.getType());
411+
auto sourceType = cast<MemRefType>(source.getType());
412+
if (inBoundsAttr && hasAMDGPUFatRawBufferAddressSpace(sourceType)) {
413413
ArrayRef<int64_t> sourceShape = sourceType.getShape();
414-
Value anyNonOutermostOOB;
414+
Value anyNonOutermostOOB = arith::ConstantOp::create(
415+
rewriter, loc, rewriter.getBoolAttr(false));
415416

416417
for (int64_t dim = 1; dim < sourceType.getRank(); ++dim) {
417418
if (dim >= static_cast<int64_t>(inBoundsAttr->size())) {
@@ -435,26 +436,19 @@ struct LowerCoalescedGatherDMAPattern final
435436
arith::CmpIPredicate::uge,
436437
srcIndices[dim], dimSize);
437438

438-
if (anyNonOutermostOOB) {
439-
anyNonOutermostOOB = arith::OrIOp::create(
440-
rewriter, loc, anyNonOutermostOOB, isOOB);
441-
} else {
442-
anyNonOutermostOOB = isOOB;
443-
}
439+
anyNonOutermostOOB = arith::OrIOp::create(
440+
rewriter, loc, anyNonOutermostOOB, isOOB);
444441
}
445442

446-
if (anyNonOutermostOOB) {
447-
Value oobOuterIdx;
448-
if (ShapedType::isDynamic(sourceShape[0])) {
449-
oobOuterIdx = memref::DimOp::create(rewriter, loc, source, 0);
450-
} else {
451-
oobOuterIdx = arith::ConstantIndexOp::create(rewriter, loc,
452-
sourceShape[0]);
453-
}
454-
srcIndices[0] =
455-
arith::SelectOp::create(rewriter, loc, anyNonOutermostOOB,
456-
oobOuterIdx, srcIndices[0]);
443+
Value oobOuterIdx;
444+
if (ShapedType::isDynamic(sourceShape[0])) {
445+
oobOuterIdx = memref::DimOp::create(rewriter, loc, source, 0);
446+
} else {
447+
oobOuterIdx =
448+
arith::ConstantIndexOp::create(rewriter, loc, sourceShape[0]);
457449
}
450+
srcIndices[0] = arith::SelectOp::create(
451+
rewriter, loc, anyNonOutermostOOB, oobOuterIdx, srcIndices[0]);
458452
}
459453

460454
amdgpu::GatherToLDSOp::create(rewriter, loc, source, srcIndices, dest,

compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1134,28 +1134,41 @@ func.func @lower_coalesced_dma_4x64_tensor_pad_fusion(
11341134
// CHECK: %[[SRC_LIN0:.+]] = arith.addi %[[C0]], %[[LANE_OFFSET]]
11351135
// CHECK: %[[SRC_DELIN0:.+]]:2 = affine.delinearize_index %[[SRC_LIN0]] into (4, 64)
11361136
// CHECK: %[[DST_DELIN0:.+]]:2 = affine.delinearize_index %[[C0]] into (4, 64)
1137-
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_DELIN0]]#0, %[[SRC_DELIN0]]#1], %[[DST]][%[[DST_DELIN0]]#0, %[[DST_DELIN0]]#1] : vector<1xf32>
1137+
// in_bounds = [false, true]: no non-outermost OOB dims, select is identity.
1138+
// CHECK: %[[FALSE0:.+]] = arith.constant false
1139+
// CHECK: %[[DIM0:.+]] = memref.dim %[[SRC]], %{{.+}}
1140+
// CHECK: %[[FIXED0:.+]] = arith.select %[[FALSE0]], %[[DIM0]], %[[SRC_DELIN0]]#0
1141+
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[FIXED0]], %[[SRC_DELIN0]]#1], %[[DST]][%[[DST_DELIN0]]#0, %[[DST_DELIN0]]#1] : vector<1xf32>
11381142
//
11391143
// Transfer 2: linearOffset = 64, accesses row 1
11401144
// CHECK: %[[C64:.+]] = arith.constant 64 : index
11411145
// CHECK: %[[SRC_LIN64:.+]] = arith.addi %[[C64]], %[[LANE_OFFSET]]
11421146
// CHECK: %[[SRC_DELIN64:.+]]:2 = affine.delinearize_index %[[SRC_LIN64]] into (4, 64)
11431147
// CHECK: %[[DST_DELIN64:.+]]:2 = affine.delinearize_index %[[C64]] into (4, 64)
1144-
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_DELIN64]]#0, %[[SRC_DELIN64]]#1], %[[DST]][%[[DST_DELIN64]]#0, %[[DST_DELIN64]]#1] : vector<1xf32>
1148+
// CHECK: %[[FALSE1:.+]] = arith.constant false
1149+
// CHECK: %[[DIM1:.+]] = memref.dim %[[SRC]], %{{.+}}
1150+
// CHECK: %[[FIXED1:.+]] = arith.select %[[FALSE1]], %[[DIM1]], %[[SRC_DELIN64]]#0
1151+
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[FIXED1]], %[[SRC_DELIN64]]#1], %[[DST]][%[[DST_DELIN64]]#0, %[[DST_DELIN64]]#1] : vector<1xf32>
11451152
//
11461153
// Transfer 3: linearOffset = 128, accesses row 2
11471154
// CHECK: %[[C128:.+]] = arith.constant 128 : index
11481155
// CHECK: %[[SRC_LIN128:.+]] = arith.addi %[[C128]], %[[LANE_OFFSET]]
11491156
// CHECK: %[[SRC_DELIN128:.+]]:2 = affine.delinearize_index %[[SRC_LIN128]] into (4, 64)
11501157
// CHECK: %[[DST_DELIN128:.+]]:2 = affine.delinearize_index %[[C128]] into (4, 64)
1151-
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_DELIN128]]#0, %[[SRC_DELIN128]]#1], %[[DST]][%[[DST_DELIN128]]#0, %[[DST_DELIN128]]#1] : vector<1xf32>
1158+
// CHECK: %[[FALSE2:.+]] = arith.constant false
1159+
// CHECK: %[[DIM2:.+]] = memref.dim %[[SRC]], %{{.+}}
1160+
// CHECK: %[[FIXED2:.+]] = arith.select %[[FALSE2]], %[[DIM2]], %[[SRC_DELIN128]]#0
1161+
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[FIXED2]], %[[SRC_DELIN128]]#1], %[[DST]][%[[DST_DELIN128]]#0, %[[DST_DELIN128]]#1] : vector<1xf32>
11521162
//
11531163
// Transfer 4: linearOffset = 192, accesses row 3
11541164
// CHECK: %[[C192:.+]] = arith.constant 192 : index
11551165
// CHECK: %[[SRC_LIN192:.+]] = arith.addi %[[C192]], %[[LANE_OFFSET]]
11561166
// CHECK: %[[SRC_DELIN192:.+]]:2 = affine.delinearize_index %[[SRC_LIN192]] into (4, 64)
11571167
// CHECK: %[[DST_DELIN192:.+]]:2 = affine.delinearize_index %[[C192]] into (4, 64)
1158-
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_DELIN192]]#0, %[[SRC_DELIN192]]#1], %[[DST]][%[[DST_DELIN192]]#0, %[[DST_DELIN192]]#1] : vector<1xf32>
1168+
// CHECK: %[[FALSE3:.+]] = arith.constant false
1169+
// CHECK: %[[DIM3:.+]] = memref.dim %[[SRC]], %{{.+}}
1170+
// CHECK: %[[FIXED3:.+]] = arith.select %[[FALSE3]], %[[DIM3]], %[[SRC_DELIN192]]#0
1171+
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[FIXED3]], %[[SRC_DELIN192]]#1], %[[DST]][%[[DST_DELIN192]]#0, %[[DST_DELIN192]]#1] : vector<1xf32>
11591172
// CHECK-NOT: amdgpu.gather_to_lds
11601173
// CHECK-NOT: iree_gpu.coalesced_gather_dma
11611174
iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) in_bounds [false, true] :
@@ -1211,8 +1224,10 @@ func.func @gather_dma_non_outermost_oob_check(
12111224
// CHECK: %[[DST_DELIN0:.+]]:2 = affine.delinearize_index %[[C0]] into (4, 8)
12121225
//
12131226
// Bounds check: compare srcIndices[1] >= 6 (source dim 1 size)
1227+
// CHECK: %[[FALSE:.+]] = arith.constant false
12141228
// CHECK: %[[C6:.+]] = arith.constant 6 : index
1215-
// CHECK: %[[OOB:.+]] = arith.cmpi uge, %[[SRC_DELIN0]]#1, %[[C6]] : index
1229+
// CHECK: %[[CMP:.+]] = arith.cmpi uge, %[[SRC_DELIN0]]#1, %[[C6]] : index
1230+
// CHECK: %[[OOB:.+]] = arith.ori %[[FALSE]], %[[CMP]] : i1
12161231
// Replace outermost index with 4 (source dim 0 size) to force hardware OOB
12171232
// CHECK: %[[C4_OOB:.+]] = arith.constant 4 : index
12181233
// CHECK: %[[FIXED_IDX:.+]] = arith.select %[[OOB]], %[[C4_OOB]], %[[SRC_DELIN0]]#0 : index
@@ -1269,8 +1284,10 @@ func.func @gather_dma_inner_dim_oob_64x62(
12691284
// CHECK: %[[DST_DELIN0:.+]]:2 = affine.delinearize_index %[[C0]] into (64, 64)
12701285
//
12711286
// Bounds check: compare srcIndices[1] >= 62 (source inner dim size).
1287+
// CHECK: %[[FALSE:.+]] = arith.constant false
12721288
// CHECK: %[[C62:.+]] = arith.constant 62 : index
1273-
// CHECK: %[[OOB:.+]] = arith.cmpi uge, %[[SRC_DELIN0]]#1, %[[C62]] : index
1289+
// CHECK: %[[CMP:.+]] = arith.cmpi uge, %[[SRC_DELIN0]]#1, %[[C62]] : index
1290+
// CHECK: %[[OOB:.+]] = arith.ori %[[FALSE]], %[[CMP]] : i1
12741291
// Replace outermost index with 64 (source dim 0 size) to force hardware OOB.
12751292
// CHECK: %[[C64_OOB:.+]] = arith.constant 64 : index
12761293
// CHECK: %[[FIXED_IDX:.+]] = arith.select %[[OOB]], %[[C64_OOB]], %[[SRC_DELIN0]]#0 : index

0 commit comments

Comments
 (0)