
Commit 344d4bd (parent 7956470)

Use another approach to work.

4 files changed: +66 -52 lines

compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp
Lines changed: 55 additions & 27 deletions

@@ -59,17 +59,38 @@ static SmallVector<Attribute> getThreadMapping(MLIRContext *ctx) {
   return mapping;
 }
 
+/// Check if the source of an operation comes directly from global memory.
+/// Returns false if the source goes through tensor.pad or other local
+/// computation that would prevent using global load DMA.
+static bool sourceIsFromGlobalMemory(Operation *op) {
+  Value source;
+  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
+    source = copyOp.getInputs()[0];
+  } else if (auto gatherOp = dyn_cast<IREE::LinalgExt::GatherOp>(op)) {
+    source = gatherOp.getSource();
+  } else {
+    return false;
+  }
+
+  // Trace through extract_slice operations to find the origin.
+  while (auto extractOp = source.getDefiningOp<tensor::ExtractSliceOp>()) {
+    source = extractOp.getSource();
+  }
+
+  // If the source comes from tensor.pad, it's not directly from global memory.
+  if (source.getDefiningOp<tensor::PadOp>()) {
+    return false;
+  }
+
+  // Otherwise, assume it's from global memory (e.g., dispatch tensor load).
+  return true;
+}
+
 /// Helper to compute thread number of threads based on translation_info.
 /// Uses the subgroup_size from translation_info for thread-level tiling.
 static SmallVector<OpFoldResult>
 computeThreadNumThreadsImpl(OpBuilder &builder, Operation *op,
                             RankedTensorType outputType) {
-  // Check that this operation has the use_global_load_dma config.
-  auto dmaConfig = getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(op);
-  if (!dmaConfig) {
-    return {};
-  }
-
   // Get the function containing this operation.
   auto funcOp = op->getParentOfType<FunctionOpInterface>();
   if (!funcOp) {

@@ -341,6 +362,11 @@ struct ConvertToCoalescedDMABase : public OpRewritePattern<OpTy> {
       return failure();
     }
 
+    // Check that source comes from global memory (not tensor.pad).
+    if (!sourceIsFromGlobalMemory(op)) {
+      return failure();
+    }
+
     SmallVector<OpFoldResult> threadNumThreads =
         computeThreadNumThreads(rewriter, op);
     if (threadNumThreads.empty()) {

@@ -386,11 +412,8 @@ struct ConvertGatherToCoalescedDMA
       return failure();
     }
 
-    // For gather ops, tile only the innermost dimension to distribute across
-    // threads.
-    auto dmaConfig =
-        getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(gatherOp);
-    if (!dmaConfig) {
+    // Check that source comes from global memory (not tensor.pad).
+    if (!sourceIsFromGlobalMemory(gatherOp)) {
       return failure();
     }

@@ -400,6 +423,12 @@ struct ConvertGatherToCoalescedDMA
       return failure();
     }
 
+    // Check target supports global load DMA.
+    IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+    if (!target || !IREE::GPU::targetSupportsGlobalLoadDMA(target)) {
+      return failure();
+    }
+
     // Get subgroup size from translation_info.
     std::optional<int64_t> subgroupSize = getSubgroupSize(funcOp);
     if (!subgroupSize) {

@@ -418,11 +447,6 @@ struct ConvertGatherToCoalescedDMA
     Type elementType = outputType.getElementType();
     int64_t elementBits = elementType.getIntOrFloatBitWidth();
 
-    IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
-    if (!target) {
-      return failure();
-    }
-
     ArrayRef<int64_t> dmaSizes;
     if (DenseI64ArrayAttr dmaSizesAttr = target.getWgp().getDmaSizes()) {
       dmaSizes = dmaSizesAttr.asArrayRef();

@@ -617,10 +641,6 @@ struct GPUConvertToCoalescedDMAPass final
   FailureOr<scf::SCFTilingResult> tileAtSubgroupLevel(IRRewriter &rewriter,
                                                       OpTy op) {
     MLIRContext *context = &getContext();
-    auto dmaConfig = getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(op);
-    if (!dmaConfig) {
-      return failure();
-    }
 
     // Get the function containing this operation.
     auto funcOp = op->template getParentOfType<FunctionOpInterface>();

@@ -718,18 +738,26 @@ struct GPUConvertToCoalescedDMAPass final
 
   LogicalResult applySubgroupTiling(FunctionOpInterface funcOp) {
     MLIRContext *context = &getContext();
+
+    // Check if target supports global load DMA.
+    IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+    if (!target || !IREE::GPU::targetSupportsGlobalLoadDMA(target)) {
+      return success();
+    }
+
     SmallVector<Operation *> opsToTile;
 
-    // Collect all ops with iree_gpu.use_global_load_dma lowering config.
+    // Collect copy/gather ops that are eligible for coalesced DMA.
     // Skip ops that are already inside a warp-mapped forall.
     funcOp->walk([&](Operation *op) {
      if (isa<linalg::CopyOp, IREE::LinalgExt::GatherOp>(op)) {
-        auto config = getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(op);
-        if (config) {
-          auto parentForall = op->getParentOfType<scf::ForallOp>();
-          if (!hasWarpMapping(parentForall)) {
-            opsToTile.push_back(op);
-          }
+        // Check that source comes from global memory (not tensor.pad).
+        if (!sourceIsFromGlobalMemory(op)) {
+          return;
+        }
+        auto parentForall = op->getParentOfType<scf::ForallOp>();
+        if (!hasWarpMapping(parentForall)) {
+          opsToTile.push_back(op);
        }
      }
    });
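The new sourceIsFromGlobalMemory helper decides DMA eligibility by peeling tensor.extract_slice producers off the copy/gather source and rejecting chains rooted in tensor.pad. Here is a rough MLIR sketch of the two cases (values, shapes, and the producer of %global are hypothetical and not taken from this commit; %global stands for a value loaded straight from global memory, e.g. a dispatch tensor load result):

  %cst = arith.constant 0.0 : f16

  // Accepted: the extract_slice chain bottoms out at global memory.
  %slice = tensor.extract_slice %global[0, 0] [64, 64] [1, 1]
      : tensor<128x128xf16> to tensor<64x64xf16>
  %copy = linalg.copy ins(%slice : tensor<64x64xf16>)
      outs(%init : tensor<64x64xf16>) -> tensor<64x64xf16>
  // sourceIsFromGlobalMemory(copy) == true

  // Rejected: the chain reaches a tensor.pad, so the data is produced
  // locally and cannot be fetched with a global load DMA.
  %padded = tensor.pad %global low[0, 0] high[2, 2] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f16
  } : tensor<128x128xf16> to tensor<130x130xf16>
  %pslice = tensor.extract_slice %padded[0, 0] [64, 64] [1, 1]
      : tensor<130x130xf16> to tensor<64x64xf16>
  %pcopy = linalg.copy ins(%pslice : tensor<64x64xf16>)
      outs(%init : tensor<64x64xf16>) -> tensor<64x64xf16>
  // sourceIsFromGlobalMemory(pcopy) == false

Note that only tensor.extract_slice is traced through; any other origin falls into the final "assume global" branch, so the check is permissive by design.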

compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp
Lines changed: 6 additions & 6 deletions

@@ -5,9 +5,9 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"

@@ -166,16 +166,16 @@ struct GPUReduceBankConflictsPass final
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
 
-    // Skip bank conflict reduction if coalesced DMA ops are present.
+    // Skip bank conflict reduction if gather_to_lds DMA ops are present.
     // DMA operations have their own optimized memory access patterns that
     // write directly to LDS with hardware-controlled coalescing. Padding
     // shared memory would interfere with the expected DMA memory layout.
-    bool hasCoalescedDMA = false;
-    funcOp.walk([&](IREE::GPU::CoalescedGatherDMAOp) {
-      hasCoalescedDMA = true;
+    bool hasGatherToLDS = false;
+    funcOp.walk([&](amdgpu::GatherToLDSOp) {
+      hasGatherToLDS = true;
       return WalkResult::interrupt();
     });
-    if (hasCoalescedDMA) {
+    if (hasGatherToLDS) {
       return;
     }
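The detection now keys on the upstream amdgpu.gather_to_lds op rather than IREE's CoalescedGatherDMAOp. A rough sketch of the op the walk matches (indices, shapes, and transfer width are hypothetical; assembly syntax approximate):

  %c0 = arith.constant 0 : index
  // Transfers data from the global %src memref directly into the LDS
  // (workgroup) allocation %lds, bypassing registers.
  amdgpu.gather_to_lds %src[%c0, %c0], %lds[%c0, %c0]
      : vector<8xf16>, memref<64x64xf16>,
        memref<64x64xf16, #gpu.address_space<workgroup>>

Because the hardware DMA writes to LDS at addresses computed from the unpadded allocation, padding that allocation to reduce bank conflicts would break the expected layout, hence the early return.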

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
Lines changed: 4 additions & 18 deletions

@@ -940,19 +940,11 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
       {"subgroup", b.getI64ArrayAttr(subgroupTileSizes)},
       {"mma_kind", kind}};
 
-  // Use global load DMA attribute (subgroup sizes will be derived from
-  // translation_info) only on gfx950+.
-  SmallVector<Attribute> promotionArray;
-  if (targetSupportsGlobalLoadDMA(target)) {
-    Attribute useGlobalDma = IREE::GPU::UseGlobalLoadDMAAttr::get(context);
-    promotionArray = {useGlobalDma, useGlobalDma};
-  }
+  // Build promotion list - global load DMA eligibility is determined
+  // dynamically in the GPUConvertToCoalescedDMA pass based on whether
+  // the source comes directly from global memory.
   SmallVector<int64_t> promotionList = {0, 1};
   if (scaled) {
-    // TODO(#22119): We don't use global load DMA for scaled matmuls, because
-    // compilation doesn't support it. Once this is fixed, we should use global
-    // load DMA here when possible.
-    promotionArray = {};
     promotionList.append({2, 3});
   }
   bool cWasPromoted = (!mustBeAligned || couldNeedPadding) && cPromoteIfPadding;

@@ -961,14 +953,8 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
     // and scaled GEMM respectively.
     promotionList.push_back(promotionList.size());
   }
-  // Do not use direct load DMA when padding is needed, as the source will
-  // go through tensor.pad and won't be directly from global memory. Also don't
-  // use DMA types when C is promoted since C is output, not loaded from global.
-  ArrayRef<Attribute> promotionTypes =
-      (couldNeedPadding || cWasPromoted) ? ArrayRef<Attribute>{}
-                                         : ArrayRef<Attribute>(promotionArray);
   GPU::appendPromotedOperandsList(context, attrs, promotionList,
-                                  promotionTypes);
+                                  /*promotionTypes=*/{});
   if (!mustBeAligned || couldNeedPadding) {
     SmallVector<int64_t> paddingTileSizes = workgroupTileSizes;
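Since promotion types are no longer attached at configuration time, the emitted lowering config carries just the operand indices, which is what the gfx950 test below checks (promote_operands = [0, 1]). A hypothetical before/after sketch of the attribute (the promotion_types key spelling is approximate and not verified against this commit):

  // Before: DMA intent baked into the config on gfx950+.
  #iree_gpu.lowering_config<{..., promote_operands = [0, 1],
      promotion_types = [#iree_gpu.use_global_load_dma,
                         #iree_gpu.use_global_load_dma]}>

  // After: indices only; the GPUConvertToCoalescedDMA pass decides DMA
  // eligibility from the IR.
  #iree_gpu.lowering_config<{..., promote_operands = [0, 1]}>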

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
Lines changed: 1 addition & 1 deletion

@@ -211,7 +211,7 @@ module {
 
 // CHECK-LABEL: func.func @data_tiled_scaled_mma_inner_tiled
 // CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
-// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
+// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<no_reduce_shared_memory_bank_conflicts = true, use_igemm_convolution = false>}
 // CHECK: iree_codegen.inner_tiled {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: promote_operands = [0, 1]
 // CHECK-SAME: reduction = [0, 0, 1, 1]
