
Commit bed7be1

[GPU] Skip subgroup-level tiling for tensor.pad fusion in coalesced DMA
The subgroup-level tiling was creating an outer loop (1, 4, 64) that distributed the padded buffer across multiple iterations, causing each iteration to create 1×64 dest subviews. The lowering pass would then use the dest shape (1×64) for delinearization, causing all iterations to load from source row 0 instead of rows 0-3.

This fix skips subgroup-level tiling for tensor.pad fusion cases by:

1. Detecting tensor.pad in applySubgroupTiling() before calling tileAtSubgroupLevel().
2. Adding a new ConvertPadFusionCopyToCoalescedDMA pattern that converts these operations directly, without requiring a warp-mapped forall parent.

This allows coalesced_gather_dma to operate on full 4×64 buffers with a single lane-mapped forall, letting the lowering pass correctly generate 4 transfers per lane to cover all source rows.

Fixes unaligned matmul tests (65x64x121, 133x97x65).
1 parent 229911e commit bed7be1
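
To make the delinearization failure concrete, here is a minimal standalone C++ toy model, not taken from this commit: it assumes a row-major flat-offset-to-(row, col) mapping with wrap-around, and the constants (64 lanes, 4 transfers per lane) are illustrative; the real lowering pass has its own logic. With a 1×64 dest shape every offset resolves to row 0, while the full 4×64 buffer spreads the offsets across rows 0-3, matching the behavior described above.

// Toy model, not from this commit: 64 lanes each issuing 4 transfers gives
// 256 flat element offsets, delinearized row-major against the dest shape.
// The wrap-around (% rows) and the constants are illustrative assumptions.
#include <cstdio>

int main() {
  const int kFlatOffsets = 64 * 4;              // 64 lanes x 4 transfers per lane
  const int shapes[2][2] = {{1, 64}, {4, 64}};  // dest shapes: {rows, cols}
  for (const auto &shape : shapes) {
    int rows = shape[0], cols = shape[1];
    int minRow = rows, maxRow = -1;
    for (int flat = 0; flat < kFlatOffsets; ++flat) {
      int row = (flat / cols) % rows;  // source row this offset resolves to
      minRow = row < minRow ? row : minRow;
      maxRow = row > maxRow ? row : maxRow;
    }
    std::printf("%dx%d dest shape: offsets resolve to rows %d..%d\n", rows,
                cols, minRow, maxRow);
  }
}

Running this prints rows 0..0 for the 1×64 shape and rows 0..3 for the 4×64 shape, which is the difference between the broken and fixed lowerings described in the message.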

File tree

1 file changed: +88 -0 lines changed


compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp

Lines changed: 88 additions & 0 deletions
@@ -475,6 +475,63 @@ struct ConvertCopyToCoalescedDMA
   }
 };
 
+/// Pattern to convert tensor.pad fusion cases directly without requiring
+/// warp-mapped forall parent.
+struct ConvertPadFusionCopyToCoalescedDMA
+    : public OpRewritePattern<linalg::CopyOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    // Only match copies with use_global_load_dma config.
+    auto config = getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(copyOp);
+    if (!config) {
+      return failure();
+    }
+
+    // Check if this is a tensor.pad fusion case.
+    Value source = copyOp.getInputs()[0];
+    // Trace through extract_slice to find tensor.pad.
+    while (auto extractSlice = source.getDefiningOp<tensor::ExtractSliceOp>()) {
+      source = extractSlice.getSource();
+    }
+    auto pad = source.getDefiningOp<tensor::PadOp>();
+    if (!pad) {
+      return failure(); // Not a pad fusion case.
+    }
+
+    // Check if padding exists (non-zero low/high pad).
+    bool hasPadding = false;
+    for (auto [low, high] :
+         llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) {
+      if (!isConstantIntValue(low, 0) || !isConstantIntValue(high, 0)) {
+        hasPadding = true;
+        break;
+      }
+    }
+    if (!hasPadding) {
+      return failure(); // No actual padding.
+    }
+
+    // This is a tensor.pad fusion case. Convert directly to
+    // coalesced_gather_dma without requiring warp-mapped forall.
+    auto outputType = cast<RankedTensorType>(copyOp.getOutputs()[0].getType());
+    SmallVector<OpFoldResult> threadNumThreads =
+        computeThreadNumThreadsImpl(rewriter, copyOp, outputType);
+    if (threadNumThreads.empty()) {
+      return failure();
+    }
+
+    scf::ForallOp threadForallOp =
+        tileToThreadLevel(copyOp, rewriter, threadNumThreads);
+    if (!threadForallOp) {
+      return failure();
+    }
+
+    return createDMAInForall<linalg::CopyOp>(threadForallOp, rewriter);
+  }
+};
+
 struct ConvertGatherToCoalescedDMA
     : public OpRewritePattern<IREE::LinalgExt::GatherOp> {
   using OpRewritePattern<IREE::LinalgExt::GatherOp>::OpRewritePattern;
@@ -660,9 +717,11 @@ struct GPUConvertToCoalescedDMAPass final
     }
 
     // Only tile and convert ops within forall ops with warp mapping.
+    // Also handle tensor.pad fusion cases that don't have warp mapping.
    RewritePatternSet patterns(context);
    patterns.add<ConvertGatherToCoalescedDMA>(context);
    patterns.add<ConvertCopyToCoalescedDMA>(context);
+    patterns.add<ConvertPadFusionCopyToCoalescedDMA>(context);
 
    walkAndApplyPatterns(funcOp, std::move(patterns));
   }
@@ -854,6 +913,35 @@ struct GPUConvertToCoalescedDMAPass final
     // Apply subgroup-level tiling to each op.
     IRRewriter rewriter(context);
     for (Operation *op : opsToTile) {
+      // Check if this is a tensor.pad fusion case for CopyOp.
+      // If so, skip subgroup-level tiling to avoid creating the outer loop.
+      bool skipSubgroupTiling = false;
+      if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
+        Value source = copyOp.getInputs()[0];
+        // Trace through extract_slice to find tensor.pad.
+        while (auto extractSlice =
+                   source.getDefiningOp<tensor::ExtractSliceOp>()) {
+          source = extractSlice.getSource();
+        }
+        if (auto pad = source.getDefiningOp<tensor::PadOp>()) {
+          // Check if padding exists (non-zero low/high pad).
+          for (auto [low, high] :
+               llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) {
+            if (!isConstantIntValue(low, 0) || !isConstantIntValue(high, 0)) {
+              skipSubgroupTiling = true;
+              break;
+            }
+          }
+        }
+      }
+
+      if (skipSubgroupTiling) {
+        // Skip subgroup-level tiling for tensor.pad fusion.
+        // The operation will be handled at thread-level tiling with full
+        // buffers.
+        continue;
+      }
+
       FailureOr<scf::SCFTilingResult> tilingResult =
           TypeSwitch<Operation *, FailureOr<scf::SCFTilingResult>>(op)
               .Case([&](linalg::CopyOp copyOp) {
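
One observation on the change: the tensor.pad detection (walking through extract_slice ops, then checking for non-zero low/high padding) is now duplicated between ConvertPadFusionCopyToCoalescedDMA and the subgroup-tiling skip check. A possible follow-up, sketched below as a suggestion rather than as part of this commit, would factor it into a shared helper in the same file; the name isPadFusionCopy is made up here, but every call it uses already appears in the diff above.

// Hypothetical helper (not in this commit): returns true if `copyOp` reads,
// possibly through tensor.extract_slice ops, from a tensor.pad with non-zero
// low or high padding. Mirrors the checks duplicated in the two hunks above.
static bool isPadFusionCopy(linalg::CopyOp copyOp) {
  Value source = copyOp.getInputs()[0];
  // Trace through extract_slice ops to find a producing tensor.pad.
  while (auto extractSlice = source.getDefiningOp<tensor::ExtractSliceOp>()) {
    source = extractSlice.getSource();
  }
  auto pad = source.getDefiningOp<tensor::PadOp>();
  if (!pad) {
    return false;
  }
  // Only treat it as a pad fusion case if there is actual padding.
  for (auto [low, high] :
       llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) {
    if (!isConstantIntValue(low, 0) || !isConstantIntValue(high, 0)) {
      return true;
    }
  }
  return false;
}

Both the pattern's match and the skip check in applySubgroupTiling() could then call this single predicate.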
