diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp index aae1a70130ac..2c9be9d73d46 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp @@ -288,6 +288,58 @@ struct LowerCoalescedGatherDMAPattern final } SmallVector segments = std::move(*segmentsOrFailure); + // OOB padding requires fat_raw_buffer for hardware OOB clamping. + if (std::optional inBounds = dmaOp.getInBounds()) { + auto srcType = cast(source.getType()); + if (!hasAMDGPUFatRawBufferAddressSpace(srcType)) { + for (Attribute attr : *inBounds) { + if (!cast(attr).getValue()) { + return rewriter.notifyMatchFailure( + dmaOp, "in_bounds with OOB dimensions requires " + "fat_raw_buffer address space on source"); + } + } + } + + // For non-outermost dims with OOB (in_bounds=false), the vector read + // must not cross row boundaries. Each lane reads `elementsPerLane` + // contiguous elements from the source buffer. If the source dim size is + // not a multiple of elementsPerLane, a vector read near the end of a row + // will wrap into the next row instead of returning zeros. + // Example: source 64x62xf32, dest 64x64xf32, vector<4xf32>: + // Lane at [0, 60] reads 4 elements at flat offsets 60..63. + // Offset 62 wraps to [1, 0] instead of returning 0. + ArrayRef sourceShape = srcType.getShape(); + for (int64_t dim = 1; dim < srcType.getRank(); ++dim) { + if (dim >= static_cast(inBounds->size())) { + break; + } + bool dimInBounds = cast((*inBounds)[dim]).getValue(); + if (dimInBounds) { + continue; + } + // This non-outermost dim has padding. Check that source dim size is + // a multiple of elementsPerLane for every segment to prevent row + // crossing. + if (ShapedType::isDynamic(sourceShape[dim])) { + return rewriter.notifyMatchFailure( + dmaOp, "non-outermost OOB dim " + Twine(dim) + + " has dynamic source size; cannot verify vector " + "reads do not cross row boundaries"); + } + for (const TransferSegment &segment : segments) { + if (sourceShape[dim] % segment.elementsPerLane != 0) { + return rewriter.notifyMatchFailure( + dmaOp, "non-outermost OOB dim " + Twine(dim) + + " has source size " + Twine(sourceShape[dim]) + + " not divisible by elementsPerLane " + + Twine(segment.elementsPerLane) + + "; vector reads would cross row boundaries"); + } + } + } + } + // Set up for code generation. 
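+    // Worked numbers for the divisibility check above (illustrative only;
+    // elementsPerLane = 4 corresponds to a 128-bit transfer of f32):
+    //   source 64x62, dest 64x64: 62 % 4 == 2 -> rejected. A lane covering
+    //     dest columns 60..63 starts at flat source offset 60, and its last
+    //     two elements (flat offsets 62, 63) belong to the next source row.
+    //   source 64x60, dest 64x64: 60 % 4 == 0 -> accepted. In-range reads
+    //     never straddle a row, and reads at padded columns are redirected
+    //     past the buffer end by the outer-index select emitted below, so
+    //     the hardware clamp returns zeros for them.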
rewriter.setInsertionPoint(dmaOp); TypedValue laneId = dmaOp.getLane(); @@ -304,7 +356,8 @@ struct LowerCoalescedGatherDMAPattern final } emitTransfers(rewriter, loc, source, dest, destShape, numLinearDims, - elementType, indices, segments, segmentLaneOffsets); + elementType, indices, segments, segmentLaneOffsets, + dmaOp.getInBounds()); rewriter.eraseOp(dmaOp); return success(); @@ -337,7 +390,8 @@ struct LowerCoalescedGatherDMAPattern final Value dest, ArrayRef destShape, int64_t numLinearDims, Type elementType, OperandRange indices, ArrayRef segments, - ArrayRef segmentLaneOffsets) const { + ArrayRef segmentLaneOffsets, + std::optional inBoundsAttr) const { int64_t destRank = destShape.size(); int64_t numOuterDims = destRank - numLinearDims; LDBG() << "Emitting transfers: " << numOuterDims << " outer dims, " @@ -400,6 +454,55 @@ struct LowerCoalescedGatherDMAPattern final auto [srcIndices, dstIndices] = generateGatherIndices( rewriter, loc, srcDimOffsets, dstDimOffsets, indices); + // Raw buffer OOB clamping is 1D (linear): it returns 0 only when the + // byte offset >= total buffer size. For non-outermost dimensions, + // an OOB index wraps into the next row instead of returning 0. + // Fix: when any non-outermost source index exceeds its dimension, + // replace the outermost index with sourceShape[0] to force the + // linearized offset past the buffer end → hardware returns 0. + auto sourceType = cast(source.getType()); + if (inBoundsAttr && hasAMDGPUFatRawBufferAddressSpace(sourceType)) { + ArrayRef sourceShape = sourceType.getShape(); + Value anyNonOutermostOOB = arith::ConstantOp::create( + rewriter, loc, rewriter.getBoolAttr(false)); + + for (int64_t dim = 1; dim < sourceType.getRank(); ++dim) { + if (dim >= static_cast(inBoundsAttr->size())) { + break; + } + bool dimInBounds = + cast((*inBoundsAttr)[dim]).getValue(); + if (dimInBounds) { + continue; + } + + Value dimSize; + if (ShapedType::isDynamic(sourceShape[dim])) { + dimSize = memref::DimOp::create(rewriter, loc, source, dim); + } else { + dimSize = arith::ConstantIndexOp::create(rewriter, loc, + sourceShape[dim]); + } + + Value isOOB = arith::CmpIOp::create(rewriter, loc, + arith::CmpIPredicate::uge, + srcIndices[dim], dimSize); + + anyNonOutermostOOB = arith::OrIOp::create( + rewriter, loc, anyNonOutermostOOB, isOOB); + } + + Value oobOuterIdx; + if (ShapedType::isDynamic(sourceShape[0])) { + oobOuterIdx = memref::DimOp::create(rewriter, loc, source, 0); + } else { + oobOuterIdx = + arith::ConstantIndexOp::create(rewriter, loc, sourceShape[0]); + } + srcIndices[0] = arith::SelectOp::create( + rewriter, loc, anyNonOutermostOOB, oobOuterIdx, srcIndices[0]); + } + amdgpu::GatherToLDSOp::create(rewriter, loc, source, srcIndices, dest, dstIndices, TypeAttr::get(transferType)); @@ -438,18 +541,20 @@ struct AMDGPULowerCoalescedDMAToGatherLDSPass final walkAndApplyPatterns(funcOp, std::move(patterns)); -#ifndef NDEBUG // Verify all CoalescedGatherDMAOps were lowered. Currently, we require all // ops to be successfully lowered. In the future, a fallback lowering path // (e.g., using global_load) could handle ops that don't match the pattern. 
WalkResult result = funcOp.walk([&](IREE::GPU::CoalescedGatherDMAOp op) { - op.emitOpError("failed to lower coalesced_gather_dma op"); + op.emitOpError( + "failed to lower to gather_to_lds; possible causes: source " + "lacks fat_raw_buffer address space for OOB padding, destination " + "is not contiguous, or element sizes are incompatible with " + "dma_sizes"); return WalkResult::interrupt(); }); if (result.wasInterrupted()) { return signalPassFailure(); } -#endif // NDEBUG } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp index bfa9a36950dc..cde1948449b8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp @@ -7,11 +7,13 @@ #include #include #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "llvm/Support/Debug.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" @@ -23,6 +25,7 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -58,6 +61,15 @@ static SmallVector getThreadMapping(MLIRContext *ctx) { return mapping; } +/// Trace through extract_slice operations to find an underlying tensor.pad. +/// Returns the PadOp if found, nullptr otherwise. +static tensor::PadOp traceToTensorPad(Value source) { + while (auto extractSlice = source.getDefiningOp()) { + source = extractSlice.getSource(); + } + return source.getDefiningOp(); +} + /// Check if a value traces back to tensor.empty (possibly through forall args). static bool tracesToTensorEmpty(Value value) { // Direct tensor.empty. @@ -97,6 +109,56 @@ static bool tracesToTensorEmpty(Value value) { return initValue.getDefiningOp() != nullptr; } +/// Check if the source of a copy traces to a fat_raw_buffer source. +/// Traces through extract_slice and pad ops to find the originating op. +/// Returns true if source is a block arg (opaque, allow DMA) or if it +/// traces to a LoadFromBufferOp with fat_raw_buffer address space. +/// Returns false if source traces to a LoadFromBufferOp without +/// fat_raw_buffer, or to any other concrete op (e.g. dispatch.tensor.load). +static bool sourceIsFromFatRawBuffer(Value source) { + // Trace through extract_slice and pad ops. + while (true) { + if (auto extractSlice = source.getDefiningOp()) { + source = extractSlice.getSource(); + continue; + } + if (auto pad = source.getDefiningOp()) { + source = pad.getSource(); + continue; + } + break; + } + + // Block args are opaque; conservatively allow DMA. + if (isa(source)) { + return true; + } + + // Check if source comes from a LoadFromBufferOp with fat_raw_buffer. 
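+  // Typical producer chain this helper walks (sketch based on the lit tests
+  // in this patch; SSA names are illustrative):
+  //   %slice  = tensor.extract_slice %padded ...
+  //   %padded = tensor.pad %src low[0, 0] high[%h, 0] { ... }
+  //   %src    = iree_codegen.load_from_buffer %buf
+  //             : memref<..., #amdgpu.address_space<fat_raw_buffer>> -> tensor<...>
+  // Only the final producer matters: a fat_raw_buffer-backed load_from_buffer
+  // (or an opaque block argument) permits the DMA rewrite.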
+ auto loadOp = source.getDefiningOp(); + if (!loadOp) { + return false; + } + + auto memrefType = cast(loadOp.getBuffer().getType()); + return hasAMDGPUFatRawBufferAddressSpace(memrefType); +} + +/// Check if the target architecture supports global load DMA. +/// Returns true only for CDNA4+ (gfx950+) architectures. +static bool targetSupportsGlobalLoadDMA(IREE::GPU::TargetAttr target) { + if (!target) { + return false; + } + FailureOr chipset = amdgpu::Chipset::parse(target.getArch()); + if (failed(chipset)) { + return false; + } + // CDNA4 is gfx950+ (major=9, minor>=5). Other major versions (RDNA, etc.) + // do not support global load DMA. + return chipset->majorVersion == 9 && chipset->minorVersion >= 5; +} + /// Helper to compute thread number of threads based on translation_info. /// Uses the subgroup_size from translation_info for thread-level tiling. static SmallVector @@ -135,7 +197,7 @@ computeThreadNumThreadsImpl(OpBuilder &builder, Operation *op, // Get DMA sizes from target to compute minimum transfer size. IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); - if (!target) { + if (!target || !targetSupportsGlobalLoadDMA(target)) { return {}; } @@ -300,14 +362,77 @@ static LogicalResult createDMAInForall(scf::ForallOp threadForallOp, Location loc = innerOp.getLoc(); Value source, indices; + SmallVector inBoundsVec; // Extract source and indices based on op type. if constexpr (std::is_same_v) { Value input = innerOp.getInputs()[0]; - if (auto extractSlice = input.getDefiningOp()) { - source = extractSlice.getSource(); - } else { - return failure(); + + // After tiling, the input is typically: + // tensor.extract_slice %padded[...] [...] [1, 1] + // We need to trace through extract_slice to find if source is tensor.pad. + if (tensor::PadOp pad = traceToTensorPad(input)) { + // Verify pad constraints: low padding must be all zeros, pad value must + // be 0. + // TODO(#23365): Support non-zero pad values (e.g., -inf, 1) by emitting + // a select on the loaded values from LDS to replace OOB zeros with the + // desired padding element. + bool validPad = true; + for (OpFoldResult low : pad.getMixedLowPad()) { + if (!isConstantIntValue(low, 0)) { + validPad = false; + break; + } + } + Value padVal = pad.getConstantPaddingValue(); + if (!padVal || !(matchPattern(padVal, m_AnyZeroFloat()) || + matchPattern(padVal, m_Zero()))) { + validPad = false; + } + + if (validPad) { + // Use pad.getSource() directly as the DMA source. + // This is the tensor.extract_slice result (e.g., tensor). + source = pad.getSource(); + + // Check if source tensor's innermost row size is DWORD (4-byte) + // aligned. On AMD CDNA, per-component range checking is performed for + // each DWORD. If a DWORD is partially out-of-bounds, the entire DWORD + // returns zero, causing incorrect results. Additionally, partial OOB + // triggers the slow path with multi-cycling and instruction issue + // penalties. + auto sourceType = cast(source.getType()); + int64_t innermostDim = sourceType.getShape().back(); + if (!ShapedType::isDynamic(innermostDim)) { + Type elemType = sourceType.getElementType(); + int64_t elemBytes = elemType.getIntOrFloatBitWidth() / 8; + int64_t rowBytes = innermostDim * elemBytes; + if (rowBytes % 4 != 0) { + LLVM_DEBUG(llvm::dbgs() + << "Skipping DMA: row size " << rowBytes + << " bytes not DWORD-aligned (slow path)\n"); + return failure(); + } + } + + // Compute in_bounds based on whether padding was added per dimension. 
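+        // Example (mirroring the pad-fusion lit test in this patch): for
+        //   tensor.pad %src low[0, 0] high[%h, 0] { ... }
+        //       : tensor<?x64xf32> to tensor<4x64xf32>
+        // dim 0 has a dynamic high pad, so it is not provably in bounds
+        // (in_bounds[0] = false); dim 1 has zero low and high pad, so
+        // in_bounds[1] = true. The resulting attribute is [false, true].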
+ for (auto [low, high] : + llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) { + bool isInBounds = + isConstantIntValue(low, 0) && isConstantIntValue(high, 0); + inBoundsVec.push_back(isInBounds); + } + } + } + + // Fallback: no tensor.pad fusion. The input is an extract_slice from + // tiling; trace through it to get the actual source. + if (!source) { + if (auto extractSlice = input.getDefiningOp()) { + source = extractSlice.getSource(); + } else { + return failure(); + } } } else if constexpr (std::is_same_v) { source = innerOp.getSource(); @@ -356,15 +481,22 @@ static LogicalResult createDMAInForall(scf::ForallOp threadForallOp, // Create the DMA op in the in_parallel region. rewriter.setInsertionPointToStart(&inParallelBlock); - SmallVector indicesVec; + SmallVector indicesOperands; if (indices) { - indicesVec.push_back(indices); + indicesOperands.push_back(indices); + } + + // Create in_bounds attribute if we fused a tensor.pad. + ArrayAttr inBoundsAttr; + if (!inBoundsVec.empty()) { + inBoundsAttr = rewriter.getBoolArrayAttr(inBoundsVec); } // When used in forall.in_parallel, the op doesn't return a result // as it performs an in-place update to the shared_outs tensor. IREE::GPU::CoalescedGatherDMAOp::create(rewriter, loc, Type(), source, - indicesVec, sharedOut, laneId); + indicesOperands, sharedOut, laneId, + inBoundsAttr); // Erase the parallel_insert_slice ops and inner operation. for (tensor::ParallelInsertSliceOp &insertOp : toErase) { @@ -416,11 +548,71 @@ struct ConvertCopyToCoalescedDMA SmallVector computeThreadNumThreads(OpBuilder &builder, linalg::CopyOp copyOp) const override { + if (!sourceIsFromFatRawBuffer(copyOp.getInputs()[0])) { + return {}; + } auto outputType = cast(copyOp.getOutputs()[0].getType()); return computeThreadNumThreadsImpl(builder, copyOp, outputType); } }; +/// Pattern to convert tensor.pad fusion cases directly without requiring +/// warp-mapped forall parent. +struct ConvertPadFusionCopyToCoalescedDMA + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::CopyOp copyOp, + PatternRewriter &rewriter) const override { + // Only match copies with use_global_load_dma config. + auto config = getLoweringConfig(copyOp); + if (!config) { + return failure(); + } + + // Skip if source is not from fat_raw_buffer. + if (!sourceIsFromFatRawBuffer(copyOp.getInputs()[0])) { + return failure(); + } + + // Check if this is a tensor.pad fusion case. + tensor::PadOp pad = traceToTensorPad(copyOp.getInputs()[0]); + if (!pad) { + return failure(); // Not a pad fusion case + } + + // Check if padding exists (non-zero low/high pad). + bool hasPadding = false; + for (auto [low, high] : + llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) { + if (!isConstantIntValue(low, 0) || !isConstantIntValue(high, 0)) { + hasPadding = true; + break; + } + } + if (!hasPadding) { + return failure(); // No actual padding + } + + // This is a tensor.pad fusion case. Convert directly to + // coalesced_gather_dma without requiring warp-mapped forall. 
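+    // Resulting structure (sketch; see the multi-warp pad-fusion lit test):
+    //   scf.forall (...) = (0, 0) to (4, 64) step (4, 64)   // single iteration
+    //     scf.forall (%lane) in (64)                        // lane-mapped
+    //       iree_gpu.coalesced_gather_dma %src into %init lane(%lane)
+    //           in_bounds [false, true]
+    // The single-iteration wrapper forall keeps the expected nesting while the
+    // DMA still sees the full padded destination shape.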
+ auto outputType = cast(copyOp.getOutputs()[0].getType()); + SmallVector threadNumThreads = + computeThreadNumThreadsImpl(rewriter, copyOp, outputType); + if (threadNumThreads.empty()) { + return failure(); + } + + scf::ForallOp threadForallOp = + tileToThreadLevel(copyOp, rewriter, threadNumThreads); + if (!threadForallOp) { + return failure(); + } + + return createDMAInForall(threadForallOp, rewriter); + } +}; + struct ConvertGatherToCoalescedDMA : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -465,7 +657,7 @@ struct ConvertGatherToCoalescedDMA int64_t elementBits = elementType.getIntOrFloatBitWidth(); IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); - if (!target) { + if (!target || !targetSupportsGlobalLoadDMA(target)) { return failure(); } @@ -574,7 +766,8 @@ struct ConvertGatherToCoalescedDMA rewriter.setInsertionPointToStart(&inParallelBlock); IREE::GPU::CoalescedGatherDMAOp::create(rewriter, loc, Type(), source, - indicesVec, sharedOut, laneId); + indicesVec, sharedOut, laneId, + /*in_bounds=*/nullptr); // Erase parallel_insert_slice ops and gather op. SmallVector toErase; @@ -605,9 +798,11 @@ struct GPUConvertToCoalescedDMAPass final } // Only tile and convert ops within forall ops with warp mapping. + // Also handle tensor.pad fusion cases that don't have warp mapping. RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); + patterns.add(context); walkAndApplyPatterns(funcOp, std::move(patterns)); } @@ -758,9 +953,44 @@ struct GPUConvertToCoalescedDMAPass final return failure(); } - // Compute tile sizes for subgroup-level distribution. - auto [tileSizes, numTiledDims] = - computeSubgroupTileSizes(rewriter, shape, numWarps); + // Check if this is a tensor.pad fusion case. + bool isPadFusion = false; + if (auto copyOp = dyn_cast(op.getOperation())) { + if (tensor::PadOp pad = traceToTensorPad(copyOp.getInputs()[0])) { + // Check if padding exists (non-zero low/high pad). + for (auto [low, high] : + llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) { + if (!isConstantIntValue(low, 0) || !isConstantIntValue(high, 0)) { + isPadFusion = true; + break; + } + } + } + } + + SmallVector tileSizes; + int64_t numTiledDims = 0; + + if (isPadFusion) { + // TODO(#23365): Tile to subgroups for pad fusion by propagating source + // offsets through tiling. Currently, after subgroup tiling each warp's + // DMA gets the full pre-pad source but a sub-tiled init, and the DMA + // lowering has no way to offset into the source. This requires adding + // source offset support to CoalescedGatherDMAOp. For now, create a + // single-iteration wrapper forall so the DMA sees the full buffer. + // Bail out if any dimension is dynamic since we need static tile sizes. + if (llvm::any_of(shape, ShapedType::isDynamic)) { + return failure(); + } + for (int64_t i = 0; i < rank; ++i) { + tileSizes.push_back(rewriter.getIndexAttr(shape[i])); + ++numTiledDims; + } + } else { + // Compute tile sizes for subgroup-level distribution. + std::tie(tileSizes, numTiledDims) = + computeSubgroupTileSizes(rewriter, shape, numWarps); + } if (numTiledDims == 0) { return failure(); @@ -780,13 +1010,28 @@ struct GPUConvertToCoalescedDMAPass final } LogicalResult applySubgroupTiling(FunctionOpInterface funcOp) { + // Check if the target supports global load DMA (gfx950+). 
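+    // Examples of the gate implemented by targetSupportsGlobalLoadDMA(),
+    // assuming MLIR's amdgpu::Chipset parsing of the arch string:
+    //   "gfx950"  -> major 9, minor 5 -> supported (CDNA4)
+    //   "gfx942"  -> major 9, minor 4 -> rejected (minor < 5)
+    //   "gfx1100" -> major 11         -> rejected (not a CDNA major version)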
+ IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); + if (!targetSupportsGlobalLoadDMA(target)) { + return success(); + } + MLIRContext *context = &getContext(); SmallVector opsToTile; // Collect all ops with iree_gpu.use_global_load_dma lowering config. // Skip ops that are already inside a warp-mapped forall. funcOp->walk([&](Operation *op) { - if (isa(op)) { + if (auto copyOp = dyn_cast(op)) { + auto config = getLoweringConfig(op); + if (!config || !sourceIsFromFatRawBuffer(copyOp.getInputs()[0])) { + return; + } + auto parentForall = op->getParentOfType(); + if (!hasWarpMapping(parentForall)) { + opsToTile.push_back(op); + } + } else if (isa(op)) { auto config = getLoweringConfig(op); if (config) { auto parentForall = op->getParentOfType(); @@ -798,6 +1043,9 @@ struct GPUConvertToCoalescedDMAPass final }); // Apply subgroup-level tiling to each op. + // For tensor.pad fusion cases, tileAtSubgroupLevel creates a + // single-iteration wrapper forall to maintain the expected structure while + // allowing the DMA to operate on the full buffer. IRRewriter rewriter(context); for (Operation *op : opsToTile) { FailureOr tilingResult = diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp index 352ff94ade0e..e687891e7c96 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp @@ -47,7 +47,8 @@ bool isDefinitelyShared(bufferization::AllocTensorOp alloc) { auto forallOp = dyn_cast(user); if (!forallOp || !forallOpHasMappingType(forallOp)) { + gpu::GPUWarpMappingAttr, IREE::GPU::LaneIdAttr>( + forallOp)) { return false; } } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir index 5c7929f8ca5f..3f3979f0a50b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir @@ -1025,3 +1025,319 @@ func.func @lower_coalesced_dma_lane_offset_regression( } {mapping = [#gpu.thread]} return } + +// ----- + +// Test: coalesced_gather_dma with in_bounds attribute (tensor.pad fusion case). +// When in_bounds = [false, true], the source dim 0 can differ from dest dim 0. +// This happens when tensor.pad is fused - source is the pre-padded tensor, +// dest is the padded shape. Hardware OOB returns 0 for reads beyond source bounds. 
+ +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", + "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target< + arch = "gfx950", features = "", wgp = < + compute = fp32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32, 32], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, simds_per_wgp = 4, + vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}> + +#translation_32 = #iree_codegen.translation_info + +// CHECK-LABEL: func.func @lower_coalesced_dma_with_in_bounds +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<2x128xf32, #amdgpu.address_space> +// CHECK-SAME: %[[DST:[a-zA-Z0-9]+]]: memref<4x128xf32, #gpu.address_space> +func.func @lower_coalesced_dma_with_in_bounds( + %source: memref<2x128xf32, #amdgpu.address_space>, + %dest: memref<4x128xf32, #gpu.address_space>) + attributes { + hal.executable.target = #executable_target_rocm_hsaco_fb, + translation_info = #translation_32} { + // Source is 2x128 (pre-padded), dest is 4x128 (padded). + // in_bounds = [false, true]: dim 0 may OOB (padding), dim 1 is in-bounds. + // Lowering uses dest shape (4x128) to compute transfer pattern. + // Reads beyond source row 2 will return 0 via hardware OOB. + // + // Since only the outermost dim (dim 0) is OOB, no non-outermost bounds check + // is needed. The identity select (select false, oobIdx, srcIdx) is emitted + // and will be folded away by canonicalization. + // + // CHECK: scf.forall (%[[LANE_ID:[a-zA-Z0-9]+]]) in (32) + scf.forall (%arg6) in (32) { + // 4 rows * 128 cols = 512 elements total, 4 elements per lane = 4 transfers + // CHECK: %[[C4:[a-zA-Z0-9_]+]] = arith.constant 4 + // CHECK: %[[LANE_OFFSET:[a-zA-Z0-9_]+]] = arith.muli %[[LANE_ID]], %[[C4]] + // + // Transfer 1: linearOffset = 0 + // CHECK: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %[[SRC_LIN0:.+]] = arith.addi %[[C0]], %[[LANE_OFFSET]] + // CHECK: %[[SRC_DELIN0:.+]]:2 = affine.delinearize_index %[[SRC_LIN0]] into (4, 128) + // CHECK: %[[DST_DELIN0:.+]]:2 = affine.delinearize_index %[[C0]] into (4, 128) + // No non-outermost OOB dims, so select is identity (false → original index). + // CHECK: %[[FALSE0:.+]] = arith.constant false + // CHECK: %[[C2:.+]] = arith.constant 2 : index + // CHECK: %[[FIXED0:.+]] = arith.select %[[FALSE0]], %[[C2]], %[[SRC_DELIN0]]#0 + // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[FIXED0]], %[[SRC_DELIN0]]#1], %[[DST]][%[[DST_DELIN0]]#0, %[[DST_DELIN0]]#1] : vector<4xf32> + // + // Transfers 2-4: same pattern for remaining rows + // CHECK-COUNT-3: amdgpu.gather_to_lds {{.+}} : vector<4xf32> + // CHECK-NOT: amdgpu.gather_to_lds + // CHECK-NOT: iree_gpu.coalesced_gather_dma + iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) in_bounds [false, true] : + memref<2x128xf32, #amdgpu.address_space>, + memref<4x128xf32, #gpu.address_space>, index + } {mapping = [#gpu.thread]} + return +} + +// ----- + +// Test: coalesced_gather_dma with in_bounds for unaligned matmul tensor.pad fusion. 
+// This tests the exact pattern from unaligned matmul (65x64x121): +// - RHS slice shape: 4x64 (K-tile x N-dim) +// - 64 lanes (one subgroup) +// - in_bounds = [false, true]: K-dim may OOB (last tile 121 % 4 = 1), N-dim is aligned +// +// With 64 lanes, 4x64 dest shape, and dma_sizes = [32, 128]: +// - Elements per lane = 256 / 64 = 4 (each lane reads 4xf32 = 128 bits) +// - Delinearization basis = (4, 64) +// - 1 transfer covers all 256 elements + +#executable_target_rocm_hsaco_fb_unaligned = #hal.executable.target<"rocm", + "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target< + arch = "gfx950", features = "", wgp = < + compute = fp32, storage = b32, subgroup = shuffle, dot = none, mma = [], + subgroup_size_choices = [64, 64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, simds_per_wgp = 4, + vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}> + +#translation_64_unaligned = #iree_codegen.translation_info + +// CHECK-LABEL: func.func @lower_coalesced_dma_4x64_tensor_pad_fusion +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref, #amdgpu.address_space> +// CHECK-SAME: %[[DST:[a-zA-Z0-9]+]]: memref<4x64xf32, #gpu.address_space> +func.func @lower_coalesced_dma_4x64_tensor_pad_fusion( + %source: memref, #amdgpu.address_space>, + %dest: memref<4x64xf32, #gpu.address_space>) + attributes { + hal.executable.target = #executable_target_rocm_hsaco_fb_unaligned, + translation_info = #translation_64_unaligned} { + // CHECK: scf.forall (%[[LANE_ID:[a-zA-Z0-9]+]]) in (64) + scf.forall (%arg6) in (64) { + // Each lane reads 4 elements (256 elements / 64 lanes = 4). + // CHECK: %[[C4:[a-zA-Z0-9_]+]] = arith.constant 4 : index + // CHECK: %[[LANE_OFFSET:[a-zA-Z0-9_]+]] = arith.muli %[[LANE_ID]], %[[C4]] + // + // 1 transfer with delinearization basis (4, 64): + // Transfer 1: linearOffset = 0 + // CHECK: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %[[SRC_LIN0:.+]] = arith.addi %[[C0]], %[[LANE_OFFSET]] + // CHECK: %[[SRC_DELIN0:.+]]:2 = affine.delinearize_index %[[SRC_LIN0]] into (4, 64) + // CHECK: %[[DST_DELIN0:.+]]:2 = affine.delinearize_index %[[C0]] into (4, 64) + // in_bounds = [false, true]: no non-outermost OOB dims, select is identity. + // CHECK: %[[FALSE0:.+]] = arith.constant false + // CHECK: %[[DIM0:.+]] = memref.dim %[[SRC]], %{{.+}} + // CHECK: %[[FIXED0:.+]] = arith.select %[[FALSE0]], %[[DIM0]], %[[SRC_DELIN0]]#0 + // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[FIXED0]], %[[SRC_DELIN0]]#1], %[[DST]][%[[DST_DELIN0]]#0, %[[DST_DELIN0]]#1] : vector<4xf32> + // CHECK-NOT: amdgpu.gather_to_lds + // CHECK-NOT: iree_gpu.coalesced_gather_dma + iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) in_bounds [false, true] : + memref, #amdgpu.address_space>, + memref<4x64xf32, #gpu.address_space>, index + } {mapping = [#gpu.thread]} + return +} + +// ----- + +// Test: Non-outermost dimension padding with in_bounds = [false, false]. +// Source: 4x6, dest: 4x8. Dim 1 has padding (6 → 8). +// Raw buffer OOB is linear/1D, so for non-outermost dim OOB, we must +// replace the outermost index with sourceShape[0] to force hardware OOB. +// +// Without the fix: reading at [0, 6] computes a byte offset within the +// buffer and wraps to [1, 0] instead of returning 0. +// With the fix: when srcIndices[1] >= 6, srcIndices[0] is replaced with 4 +// (source dim 0 size), guaranteeing linear offset >= buffer size → returns 0. 
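+// Worked numbers for the 4x6xf32 source below (buffer = 4 * 6 * 4 = 96 bytes):
+//   logical [0, 6]  -> linear element 0*6 + 6 = 6 -> byte offset 24, which is
+//                      still inside the buffer and aliases [1, 0] without the fix.
+//   after the fix   -> indices become [4, 6] -> linear element 4*6 + 6 = 30
+//                      -> byte offset 120 >= 96, so the hardware clamp returns 0.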
+ +#executable_target_rocm_hsaco_fb_pad = #hal.executable.target<"rocm", + "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target< + arch = "gfx950", features = "", wgp = < + compute = fp32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32, 32], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, simds_per_wgp = 4, + vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}> + +#translation_32_pad = #iree_codegen.translation_info + +// CHECK-LABEL: func.func @gather_dma_non_outermost_oob_check +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<4x6xf32, #amdgpu.address_space> +// CHECK-SAME: %[[DST:[a-zA-Z0-9]+]]: memref<4x8xf32, #gpu.address_space> +func.func @gather_dma_non_outermost_oob_check( + %source: memref<4x6xf32, #amdgpu.address_space>, + %dest: memref<4x8xf32, #gpu.address_space>) + attributes { + hal.executable.target = #executable_target_rocm_hsaco_fb_pad, + translation_info = #translation_32_pad} { + // CHECK: scf.forall (%[[LANE_ID:[a-zA-Z0-9]+]]) in (32) + scf.forall (%arg6) in (32) { + // CHECK: %[[C1:[a-zA-Z0-9_]+]] = arith.constant 1 + // CHECK: %[[LANE_OFFSET:[a-zA-Z0-9_]+]] = arith.muli %[[LANE_ID]], %[[C1]] + // + // Transfer 1: linearOffset = 0 + // CHECK: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %[[SRC_LIN0:.+]] = arith.addi %[[C0]], %[[LANE_OFFSET]] + // CHECK: %[[SRC_DELIN0:.+]]:2 = affine.delinearize_index %[[SRC_LIN0]] into (4, 8) + // CHECK: %[[DST_DELIN0:.+]]:2 = affine.delinearize_index %[[C0]] into (4, 8) + // + // Bounds check: compare srcIndices[1] >= 6 (source dim 1 size) + // CHECK: %[[FALSE:.+]] = arith.constant false + // CHECK: %[[C6:.+]] = arith.constant 6 : index + // CHECK: %[[CMP:.+]] = arith.cmpi uge, %[[SRC_DELIN0]]#1, %[[C6]] : index + // CHECK: %[[OOB:.+]] = arith.ori %[[FALSE]], %[[CMP]] : i1 + // Replace outermost index with 4 (source dim 0 size) to force hardware OOB + // CHECK: %[[C4_OOB:.+]] = arith.constant 4 : index + // CHECK: %[[FIXED_IDX:.+]] = arith.select %[[OOB]], %[[C4_OOB]], %[[SRC_DELIN0]]#0 : index + // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[FIXED_IDX]], %[[SRC_DELIN0]]#1], %[[DST]][%[[DST_DELIN0]]#0, %[[DST_DELIN0]]#1] : vector<1xf32> + // CHECK-NOT: iree_gpu.coalesced_gather_dma + iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) in_bounds [false, false] : + memref<4x6xf32, #amdgpu.address_space>, + memref<4x8xf32, #gpu.address_space>, index + } {mapping = [#gpu.thread]} + return +} + +// ----- + +// Test: Inner-dim padding where source dim is aligned with vector width. +// Source 64x60 padded to 64x64, in_bounds = [true, false]. +// elementsPerLane = 4 (128-bit DMA with f32), 60 % 4 = 0, so no vector +// read crosses a row boundary. The OOB check correctly handles dim 1. 
+ +#executable_target_rocm_hsaco_fb_inner_pad = #hal.executable.target<"rocm", + "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target< + arch = "gfx950", features = "", wgp = < + compute = fp32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [64, 64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, simds_per_wgp = 4, + vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}> + +#translation_64_inner_pad = #iree_codegen.translation_info + +// CHECK-LABEL: func.func @gather_dma_inner_dim_oob_64x60 +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<64x60xf32, #amdgpu.address_space> +// CHECK-SAME: %[[DST:[a-zA-Z0-9]+]]: memref<64x64xf32, #gpu.address_space> +func.func @gather_dma_inner_dim_oob_64x60( + %source: memref<64x60xf32, #amdgpu.address_space>, + %dest: memref<64x64xf32, #gpu.address_space>) + attributes { + hal.executable.target = #executable_target_rocm_hsaco_fb_inner_pad, + translation_info = #translation_64_inner_pad} { + // CHECK: scf.forall (%[[LANE_ID:[a-zA-Z0-9]+]]) in (64) + scf.forall (%arg6) in (64) { + // Each lane transfers vector<4xf32> (dma_sizes [128] = 128 bits = 4 x f32). + // CHECK: %[[C4:[a-zA-Z0-9_]+]] = arith.constant 4 + // CHECK: %[[LANE_OFFSET:[a-zA-Z0-9_]+]] = arith.muli %[[LANE_ID]], %[[C4]] + // + // Transfer 1: linearOffset = 0 + // CHECK: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %[[SRC_LIN0:.+]] = arith.addi %[[C0]], %[[LANE_OFFSET]] + // CHECK: %[[SRC_DELIN0:.+]]:2 = affine.delinearize_index %[[SRC_LIN0]] into (64, 64) + // CHECK: %[[DST_DELIN0:.+]]:2 = affine.delinearize_index %[[C0]] into (64, 64) + // + // Bounds check: compare srcIndices[1] >= 60 (source inner dim size). + // CHECK: %[[FALSE:.+]] = arith.constant false + // CHECK: %[[C60:.+]] = arith.constant 60 : index + // CHECK: %[[CMP:.+]] = arith.cmpi uge, %[[SRC_DELIN0]]#1, %[[C60]] : index + // CHECK: %[[OOB:.+]] = arith.ori %[[FALSE]], %[[CMP]] : i1 + // Replace outermost index with 64 (source dim 0 size) to force hardware OOB. + // CHECK: %[[C64_OOB:.+]] = arith.constant 64 : index + // CHECK: %[[FIXED_IDX:.+]] = arith.select %[[OOB]], %[[C64_OOB]], %[[SRC_DELIN0]]#0 : index + // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[FIXED_IDX]], %[[SRC_DELIN0]]#1], %[[DST]][%[[DST_DELIN0]]#0, %[[DST_DELIN0]]#1] : vector<4xf32> + // CHECK-NOT: iree_gpu.coalesced_gather_dma + iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) in_bounds [true, false] : + memref<64x60xf32, #amdgpu.address_space>, + memref<64x64xf32, #gpu.address_space>, index + } {mapping = [#gpu.thread]} + return +} + +// ----- + +// Test: Inner-dim padding rejected when source dim is not aligned with vector width. +// Source 64x62 padded to 64x64, in_bounds = [true, false]. +// elementsPerLane = 4 (128-bit DMA with f32), 62 % 4 != 0. +// A vector read at [0, 60] would span elements 60..63 in the flat buffer, +// wrapping into the next row (offset 62 = [1, 0]) instead of returning 0. 
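+// Worked numbers: the source row is 62 elements wide, so a lane assigned to
+// dest columns 60..63 of row r reads 4 contiguous elements starting at flat
+// source element r*62 + 60; the last two land at columns 0..1 of row r+1
+// instead of returning zeros, hence the pattern refuses to lower this op.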
+ +#executable_target_rocm_hsaco_fb_inner_pad_unaligned = #hal.executable.target<"rocm", + "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target< + arch = "gfx950", features = "", wgp = < + compute = fp32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [64, 64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, simds_per_wgp = 4, + vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}> + +#translation_64_inner_pad_unaligned = #iree_codegen.translation_info + +func.func @gather_dma_inner_dim_oob_64x62_rejected( + %source: memref<64x62xf32, #amdgpu.address_space>, + %dest: memref<64x64xf32, #gpu.address_space>) + attributes { + hal.executable.target = #executable_target_rocm_hsaco_fb_inner_pad_unaligned, + translation_info = #translation_64_inner_pad_unaligned} { + scf.forall (%arg6) in (64) { + // expected-error @+1 {{failed to lower to gather_to_lds; possible causes: source lacks fat_raw_buffer address space for OOB padding, destination is not contiguous, or element sizes are incompatible with dma_sizes}} + iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) in_bounds [true, false] : + memref<64x62xf32, #amdgpu.address_space>, + memref<64x64xf32, #gpu.address_space>, index + } {mapping = [#gpu.thread]} + return +} + +// ----- + +// Test: in_bounds with OOB dimensions on non-fat_raw_buffer source should +// not be lowered (pattern fails because hardware OOB clamping is unavailable). + +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", + "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target< + arch = "gfx950", features = "", wgp = < + compute = fp32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32, 32], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, simds_per_wgp = 4, + vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}> + +#translation_64 = #iree_codegen.translation_info + +func.func @no_lower_oob_without_fat_raw_buffer( + %source: memref<2x128xf32>, + %dest: memref<4x128xf32, #gpu.address_space>) + attributes {hal.executable.target = #executable_target_rocm_hsaco_fb, + translation_info = #translation_64} { + scf.forall (%arg6) in (64) { + // expected-error @+1 {{failed to lower to gather_to_lds; possible causes: source lacks fat_raw_buffer address space for OOB padding, destination is not contiguous, or element sizes are incompatible with dma_sizes}} + iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) in_bounds [false, true] : + memref<2x128xf32>, + memref<4x128xf32, #gpu.address_space>, index + } {mapping = [#gpu.thread]} + return +} diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_convert_to_coalesced_dma.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_convert_to_coalesced_dma.mlir index f5210ea7e7be..5f7a4938b285 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_convert_to_coalesced_dma.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_convert_to_coalesced_dma.mlir @@ -1,6 +1,6 @@ // RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-convert-to-coalesced-dma,canonicalize))" %s --split-input-file | FileCheck %s -#gpu_target_copy = #iree_gpu.target, 
%init: tensor<64x512xf32>) -> tenso // With 16 warps (128*512/64/64) and 64 rows: step = ceil(64/16) = 4 rows, 512 cols (whole) // CHECK: %[[WARP_RESULT:.+]] = scf.forall (%[[IV0:.+]], %[[IV1:.+]]) = (0, 0) to (64, 512) step (4, 512) // CHECK-SAME: shared_outs(%[[INIT_TILE:.+]] = %[[INIT]]) -> (tensor<64x512xf32>) { - // CHECK: %[[SLICE_SRC:.+]] = tensor.extract_slice %[[SRC]][%[[IV0]], 0] [4, 512] [1, 1] - // CHECK-SAME: : tensor<64x512xf32> to tensor<4x512xf32> - // CHECK: %[[SLICE_DST:.+]] = tensor.extract_slice %[[INIT_TILE]][%[[IV0]], 0] [4, 512] [1, 1] - // CHECK-SAME: : tensor<64x512xf32> to tensor<4x512xf32> + // CHECK-DAG: %[[SLICE_SRC:.+]] = tensor.extract_slice %[[SRC]][%[[IV0]], 0] [4, 512] [1, 1] + // CHECK-DAG: %[[SLICE_DST:.+]] = tensor.extract_slice %[[INIT_TILE]][%[[IV0]], 0] [4, 512] [1, 1] // Thread-level forall: // CHECK: %[[THREAD_RESULT:.+]] = scf.forall (%[[LANE:.+]]) in (64) @@ -36,14 +34,12 @@ func.func @copy(%source: tensor<64x512xf32>, %init: tensor<64x512xf32>) -> tenso // CHECK: iree_gpu.coalesced_gather_dma %[[SLICE_SRC]] into %[[THREAD_INIT]] lane(%[[LANE]]) // CHECK-SAME: : tensor<4x512xf32>, tensor<4x512xf32>, index // CHECK: } - // CHECK: } {mapping = [#iree_gpu.lane_id<0>]} // CHECK: scf.forall.in_parallel { // CHECK: tensor.parallel_insert_slice %[[THREAD_RESULT]] into %[[INIT_TILE]][%[[IV0]], 0] [4, 512] [1, 1] - // CHECK-SAME: : tensor<4x512xf32> into tensor<64x512xf32> // CHECK: } - // CHECK: } + // CHECK: } {mapping = [#gpu.warp, #gpu.warp]} // CHECK: return %[[WARP_RESULT]] // CHECK-NOT: linalg.copy @@ -53,7 +49,7 @@ func.func @copy(%source: tensor<64x512xf32>, %init: tensor<64x512xf32>) -> tenso // ----- -#gpu_target_gather = #iree_gpu.target, %indices: tensor<64xi32>, %init: // With 64 warps and 64 rows: step = ceil(64/64) = 1 row, 512 cols (whole) // CHECK: %[[WARP_RESULT:.+]] = scf.forall (%[[IV0:.+]], %[[IV1:.+]]) = (0, 0) to (64, 512) step (1, 512) // CHECK-SAME: shared_outs(%[[INIT_TILE:.+]] = %[[INIT]]) -> (tensor<64x512xf32>) { - // CHECK: %[[SLICE_DST:.+]] = tensor.extract_slice %[[INIT_TILE]][%[[IV0]], 0] [1, 512] [1, 1] - // CHECK-SAME: : tensor<64x512xf32> to tensor<1x512xf32> - // CHECK: %[[SLICE_INDICES:.+]] = tensor.extract_slice %[[INDICES]][%[[IV0]]] [1] [1] - // CHECK-SAME: : tensor<64xi32> to tensor<1xi32> + // CHECK-DAG: %[[SLICE_DST:.+]] = tensor.extract_slice %[[INIT_TILE]][%[[IV0]], 0] [1, 512] [1, 1] + // CHECK-DAG: %[[SLICE_INDICES:.+]] = tensor.extract_slice %[[INDICES]][%[[IV0]]] [1] [1] // Thread-level forall: // CHECK: %[[THREAD_RESULT:.+]] = scf.forall (%[[LANE:.+]]) in (64) @@ -95,9 +89,8 @@ func.func @gather(%source: tensor<64x512xf32>, %indices: tensor<64xi32>, %init: // CHECK: scf.forall.in_parallel { // CHECK: tensor.parallel_insert_slice %[[THREAD_RESULT]] into %[[INIT_TILE]][%[[IV0]], 0] [1, 512] [1, 1] - // CHECK-SAME: : tensor<1x512xf32> into tensor<64x512xf32> // CHECK: } - // CHECK: } + // CHECK: } {mapping = [#gpu.warp, #gpu.warp]} // CHECK: return %[[WARP_RESULT]] // CHECK-NOT: iree_linalg_ext.gather @@ -110,7 +103,7 @@ func.func @gather(%source: tensor<64x512xf32>, %indices: tensor<64xi32>, %init: // Negative test: Skip coalesced DMA when innermost dimension < subgroup size. This is to ensure we do not go down // the slow path (which is not implemented yet). 
-#gpu_target_small_inner = #iree_gpu.target, %init: te // CHECK-SAME: shared_outs(%[[INIT_TILE:.+]] = %[[INIT]]) -> (tensor<64x128xf32>) { // Key check: subviews are 16x128 (contiguous) not 64x64 (non-contiguous) - // CHECK: %[[SLICE_SRC:.+]] = tensor.extract_slice %[[SRC]][%[[IV0]], 0] [16, 128] [1, 1] + // CHECK-DAG: %[[SLICE_SRC:.+]] = tensor.extract_slice %[[SRC]][%[[IV0]], 0] [16, 128] [1, 1] // CHECK-SAME: : tensor<64x128xf32> to tensor<16x128xf32> - // CHECK: %[[SLICE_DST:.+]] = tensor.extract_slice %[[INIT_TILE]][%[[IV0]], 0] [16, 128] [1, 1] + // CHECK-DAG: %[[SLICE_DST:.+]] = tensor.extract_slice %[[INIT_TILE]][%[[IV0]], 0] [16, 128] [1, 1] // CHECK-SAME: : tensor<64x128xf32> to tensor<16x128xf32> // Thread-level forall distributes across lanes: @@ -229,7 +222,7 @@ func.func @copy_prefer_contiguous_subview(%source: tensor<64x128xf32>, %init: te // CHECK: tensor.parallel_insert_slice %[[THREAD_RESULT]] into %[[INIT_TILE]][%[[IV0]], 0] [16, 128] [1, 1] // CHECK-SAME: : tensor<16x128xf32> into tensor<64x128xf32> // CHECK: } - // CHECK: } + // CHECK: } {mapping = [#gpu.warp, #gpu.warp]} // CHECK: return %[[WARP_RESULT]] // CHECK-NOT: linalg.copy @@ -243,7 +236,7 @@ func.func @copy_prefer_contiguous_subview(%source: tensor<64x128xf32>, %init: te // When output comes from tensor.empty(), we can use total elements instead of // innermost dimension for the size check, enabling coalesced DMA. -#gpu_target_linearize = #iree_gpu.target) -> tenso // Warp-level forall: step (32, 16) distributes 128 rows across 4 warps // CHECK: %[[WARP_RESULT:.+]] = scf.forall (%[[IV0:.+]], %[[IV1:.+]]) = (0, 0) to (128, 16) step (32, 16) // CHECK-SAME: shared_outs(%[[INIT_TILE:.+]] = %[[EMPTY]]) -> (tensor<128x16xf32>) { - // CHECK: %[[SLICE_SRC:.+]] = tensor.extract_slice %[[SRC]][%[[IV0]], 0] [32, 16] [1, 1] + // CHECK-DAG: %[[SLICE_SRC:.+]] = tensor.extract_slice %[[SRC]][%[[IV0]], 0] [32, 16] [1, 1] // CHECK-SAME: : tensor<128x16xf32> to tensor<32x16xf32> - // CHECK: %[[SLICE_DST:.+]] = tensor.extract_slice %[[INIT_TILE]][%[[IV0]], 0] [32, 16] [1, 1] + // CHECK-DAG: %[[SLICE_DST:.+]] = tensor.extract_slice %[[INIT_TILE]][%[[IV0]], 0] [32, 16] [1, 1] // CHECK-SAME: : tensor<128x16xf32> to tensor<32x16xf32> // Thread-level forall with 64 lanes @@ -290,7 +283,7 @@ func.func @copy_small_innermost_linearized(%source: tensor<128x16xf32>) -> tenso // CHECK: tensor.parallel_insert_slice %[[THREAD_RESULT]] into %[[INIT_TILE]][%[[IV0]], 0] [32, 16] [1, 1] // CHECK-SAME: : tensor<32x16xf32> into tensor<128x16xf32> // CHECK: } - // CHECK: } + // CHECK: } {mapping = [#gpu.warp, #gpu.warp]} // CHECK: return %[[WARP_RESULT]] // CHECK-NOT: linalg.copy @@ -303,7 +296,7 @@ func.func @copy_small_innermost_linearized(%source: tensor<128x16xf32>) -> tenso // Test: 1D tensor copy distributes warps across the single dimension. // This tests the 1D tile size computation logic for flattened copies. -#gpu_target_1d = #iree_gpu.target) -> tensor<2048xf32> // 1. Innermost dim (16) < minElementsPerTransfer (64) // 2. Output is a function argument, not tensor.empty, so we can't linearize -#gpu_target_no_linearize = #iree_gpu.target, %dest: // The copy should be converted to coalesced DMA when the input comes from an // extract_slice with contiguous innermost dimensions. -#gpu_target_extract_input = #iree_gpu.target) -> return %result : tensor<64x128xf32> } + +// ----- + +// Test: tensor.pad fusion into coalesced_gather_dma. 
+// When linalg.copy reads from tensor.pad, trace through to the original source +// and set in_bounds attribute based on padding. + +#gpu_target_pad = #iree_gpu.target> + +#exec_target_pad = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #gpu_target_pad}> +#translation_pad = #iree_codegen.translation_info}> + +// CHECK-LABEL: func.func @copy_with_tensor_pad_fusion +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<121x64xf32> +// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<4x64xf32> +func.func @copy_with_tensor_pad_fusion(%source: tensor<121x64xf32>, %init: tensor<4x64xf32>, %off: index, %sz: index, %high: index) -> tensor<4x64xf32> + attributes {hal.executable.target = #exec_target_pad, translation_info = #translation_pad} { + // Extract a dynamic slice. + %extracted = tensor.extract_slice %source[%off, 0] [%sz, 64] [1, 1] + : tensor<121x64xf32> to tensor + + // Pad to static size (only M dimension has padding). + %cst = arith.constant 0.0 : f32 + %padded = tensor.pad %extracted low[0, 0] high[%high, 0] { + ^bb0(%arg0: index, %arg1: index): + tensor.yield %cst : f32 + } : tensor to tensor<4x64xf32> + + // Copy from padded tensor. + %result = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma} + ins(%padded : tensor<4x64xf32>) + outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32> + + // Key check: tensor.pad is fused - source is the extract_slice result, not the padded tensor. + // in_bounds = [false, true] because M dim has dynamic padding, K dim has no padding. + // CHECK: %[[EXTRACTED:.+]] = tensor.extract_slice %[[SRC]] + // CHECK: scf.forall {{.*}} shared_outs(%[[OUTER_INIT:.+]] = %[[INIT]]) + // CHECK: scf.forall (%[[LANE:.+]]) in (64) shared_outs(%[[INNER_INIT:.+]] = %[[OUTER_INIT]]) + // CHECK: scf.forall.in_parallel { + // CHECK: iree_gpu.coalesced_gather_dma %[[EXTRACTED]] into %[[INNER_INIT]] lane(%[[LANE]]) in_bounds [false, true] + // CHECK-SAME: : tensor, tensor<4x64xf32>, index + // CHECK: } + // CHECK-NOT: tensor.pad + + return %result : tensor<4x64xf32> +} + +// ----- + +// Test: tensor.pad fusion with multiple warps creates single-iteration wrapper forall. +// When tensor.pad is fused, subgroup-level tiling is skipped to ensure the DMA +// operates on the full padded buffer shape, not on smaller subviews. +// This is critical for correct delinearization in the lowering pass. + +#gpu_target_pad_multi_warp = #iree_gpu.target> + +#exec_target_pad_multi_warp = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #gpu_target_pad_multi_warp}> +#translation_pad_multi_warp = #iree_codegen.translation_info}> + +// CHECK-LABEL: func.func @copy_with_tensor_pad_fusion_multi_warp +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<121x64xf32> +// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<4x64xf32> +func.func @copy_with_tensor_pad_fusion_multi_warp(%source: tensor<121x64xf32>, %init: tensor<4x64xf32>, %off: index, %sz: index, %high: index) -> tensor<4x64xf32> + attributes {hal.executable.target = #exec_target_pad_multi_warp, translation_info = #translation_pad_multi_warp} { + // Extract a dynamic slice. + %extracted = tensor.extract_slice %source[%off, 0] [%sz, 64] [1, 1] + : tensor<121x64xf32> to tensor + + // Pad to static size (only M dimension has padding). + %cst = arith.constant 0.0 : f32 + %padded = tensor.pad %extracted low[0, 0] high[%high, 0] { + ^bb0(%arg0: index, %arg1: index): + tensor.yield %cst : f32 + } : tensor to tensor<4x64xf32> + + // Copy from padded tensor with 4 warps (256/64=4). 
+ %result = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma} + ins(%padded : tensor<4x64xf32>) + outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32> + + // Key check: With 4 warps available, normal tiling would create a warp-level + // forall with step (1, 64) producing 4 iterations with 1x64 subviews. + // For tensor.pad fusion, we instead create a single-iteration wrapper forall + // with step (4, 64) - the full shape - so the DMA operates on 4x64 directly. + // After canonicalization, identity extract_slices are eliminated. + // + // CHECK: %[[EXTRACTED:.+]] = tensor.extract_slice %[[SRC]] + // CHECK: %[[WARP_RESULT:.+]] = scf.forall (%[[IV0:.+]], %[[IV1:.+]]) = (0, 0) to (4, 64) step (4, 64) + // CHECK-SAME: shared_outs(%[[INIT_TILE:.+]] = %[[INIT]]) -> (tensor<4x64xf32>) { + // + // Thread-level forall with 64 lanes (uses outer forall's shared_out directly): + // CHECK: %[[THREAD_RESULT:.+]] = scf.forall (%[[LANE:.+]]) in (64) shared_outs(%[[INNER_INIT:.+]] = %[[INIT_TILE]]) + // CHECK: scf.forall.in_parallel { + // CHECK: iree_gpu.coalesced_gather_dma %[[EXTRACTED]] into %[[INNER_INIT]] lane(%[[LANE]]) in_bounds [false, true] + // CHECK-SAME: : tensor, tensor<4x64xf32>, index + // CHECK: } + // CHECK: } {mapping = [#iree_gpu.lane_id<0>]} + // + // CHECK: scf.forall.in_parallel { + // CHECK: tensor.parallel_insert_slice %[[THREAD_RESULT]] into %[[INIT_TILE]][0, 0] [4, 64] [1, 1] + // CHECK: } + // CHECK: } {mapping = [#gpu.warp, #gpu.warp]} + // CHECK-NOT: tensor.pad + + return %result : tensor<4x64xf32> +} + +// ----- + +// Test: tensor.pad fusion bails out when source row size is not DWORD-aligned. +// On AMD CDNA, per-component range checking is performed for each DWORD. +// If a DWORD is partially out-of-bounds, the entire DWORD returns zero, +// causing incorrect results. We bail out to avoid the slow path. + +#gpu_target_pad_unaligned = #iree_gpu.target> + +#exec_target_pad_unaligned = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #gpu_target_pad_unaligned}> +#translation_pad_unaligned = #iree_codegen.translation_info}> + +// CHECK-LABEL: func.func @copy_with_tensor_pad_unaligned_row +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<65x121xf16> +// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<4x124xf16> +func.func @copy_with_tensor_pad_unaligned_row(%source: tensor<65x121xf16>, %init: tensor<4x124xf16>, %off: index, %sz: index, %high_m: index) -> tensor<4x124xf16> + attributes {hal.executable.target = #exec_target_pad_unaligned, translation_info = #translation_pad_unaligned} { + // Extract a dynamic slice: tensor. + // Row size = 121 * 2 bytes = 242 bytes, NOT 4-byte aligned. + %extracted = tensor.extract_slice %source[%off, 0] [%sz, 121] [1, 1] + : tensor<65x121xf16> to tensor + + // Pad to static size. + %cst = arith.constant 0.0 : f16 + %padded = tensor.pad %extracted low[0, 0] high[%high_m, 3] { + ^bb0(%arg0: index, %arg1: index): + tensor.yield %cst : f16 + } : tensor to tensor<4x124xf16> + + // Copy from padded tensor. + %result = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma} + ins(%padded : tensor<4x124xf16>) + outs(%init : tensor<4x124xf16>) -> tensor<4x124xf16> + + // Source row size (121 * 2 = 242 bytes) is not DWORD-aligned. + // Coalesced DMA bails out to avoid partial OOB in per-DWORD range checking. + // The linalg.copy should remain unchanged. 
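+  // Worked numbers: 121 elements * 2 bytes (f16) = 242 bytes per source row,
+  // and 242 % 4 == 2, so the last DWORD of each row straddles the row end.
+  // A 4-byte element type (e.g. f32) or an even f16 row length would keep
+  // every row a whole number of DWORDs and allow the conversion.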
+ // CHECK: tensor.pad + // CHECK: linalg.copy + // CHECK-NOT: iree_gpu.coalesced_gather_dma + + return %result : tensor<4x124xf16> +} + +// ----- + +// Test: Copy from load_from_buffer with fat_raw_buffer address space. +// DMA should be applied because fat_raw_buffer indicates the binding fits +// within the 2GB limit required for buffer instructions. + +#gpu_target_fat_raw = #iree_gpu.target> + +#exec_target_fat_raw = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #gpu_target_fat_raw}> +#translation_fat_raw = #iree_codegen.translation_info}> + +// CHECK-LABEL: func.func @copy_from_fat_raw_buffer +// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]]: memref<64x512xbf16, #amdgpu.address_space> +// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<64x512xbf16> +func.func @copy_from_fat_raw_buffer( + %buf: memref<64x512xbf16, #amdgpu.address_space>, + %init: tensor<64x512xbf16>) -> tensor<64x512xbf16> + attributes {hal.executable.target = #exec_target_fat_raw, translation_info = #translation_fat_raw} { + %source = iree_codegen.load_from_buffer %buf + : memref<64x512xbf16, #amdgpu.address_space> -> tensor<64x512xbf16> + %result = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma} + ins(%source : tensor<64x512xbf16>) + outs(%init : tensor<64x512xbf16>) -> tensor<64x512xbf16> + + // fat_raw_buffer source allows DMA. + // 2 warps (128/64), 64 rows → step = 32 rows, 512 cols whole. + // CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUF]] + // CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) {{.*}} shared_outs(%[[WARP_INIT:.+]] = %[[INIT]]) + // CHECK: %[[SLICE_SRC:.+]] = tensor.extract_slice %[[SOURCE]][%[[IV0]], 0] + // CHECK: %[[SLICE_DST:.+]] = tensor.extract_slice %[[WARP_INIT]][%[[IV0]], 0] + // CHECK: scf.forall (%[[LANE:.+]]) in (64) shared_outs(%{{.+}} = %[[SLICE_DST]]) + // CHECK: iree_gpu.coalesced_gather_dma %[[SLICE_SRC]] into %{{.+}} lane(%[[LANE]]) + // CHECK-NOT: linalg.copy + + return %result : tensor<64x512xbf16> +} + +// ----- + +// Test: Small tensor copy from load_from_buffer with fat_raw_buffer. +// Even small tensors should get DMA when innermost dim >= min transfer size. + +#gpu_target_fat_raw_small = #iree_gpu.target> + +#exec_target_fat_raw_small = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #gpu_target_fat_raw_small}> +#translation_fat_raw_small = #iree_codegen.translation_info}> + +// CHECK-LABEL: func.func @copy_from_fat_raw_buffer_small +// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]]: memref<4x256xbf16, #amdgpu.address_space> +// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<4x256xbf16> +func.func @copy_from_fat_raw_buffer_small( + %buf: memref<4x256xbf16, #amdgpu.address_space>, + %init: tensor<4x256xbf16>) -> tensor<4x256xbf16> + attributes {hal.executable.target = #exec_target_fat_raw_small, translation_info = #translation_fat_raw_small} { + %source = iree_codegen.load_from_buffer %buf + : memref<4x256xbf16, #amdgpu.address_space> -> tensor<4x256xbf16> + %result = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma} + ins(%source : tensor<4x256xbf16>) + outs(%init : tensor<4x256xbf16>) -> tensor<4x256xbf16> + + // Small tensor (4x256 bf16) from fat_raw_buffer. + // Innermost dim 256 >= min transfer (64*32/16=128), so DMA is applied. + // 1 warp (64/64), 4 rows → step = 4 rows, 256 cols whole. 
+  // CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUF]]
+  // CHECK: scf.forall {{.*}} shared_outs(%{{.+}} = %[[INIT]])
+  // CHECK: scf.forall (%[[LANE:.+]]) in (64) shared_outs(%[[INNER_INIT:.+]] =
+  // CHECK: iree_gpu.coalesced_gather_dma %[[SOURCE]] into %[[INNER_INIT]] lane(%[[LANE]])
+  // CHECK-NOT: linalg.copy
+
+  return %result : tensor<4x256xbf16>
+}
+
+// -----
+
+// Test: Copy from load_from_buffer with storage_buffer (non-fat_raw_buffer).
+// DMA should NOT be applied because the source binding was not converted to
+// fat_raw_buffer, indicating it exceeds the 2GB limit.
+
+#gpu_target_storage_buf = #iree_gpu.target>
+
+#exec_target_storage_buf = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #gpu_target_storage_buf}>
+#translation_storage_buf = #iree_codegen.translation_info}>
+
+// CHECK-LABEL: func.func @copy_from_non_fat_raw_buffer
+// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]]: memref<64x512xbf16, #hal.descriptor_type>
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<64x512xbf16>
+func.func @copy_from_non_fat_raw_buffer(
+    %buf: memref<64x512xbf16, #hal.descriptor_type>,
+    %init: tensor<64x512xbf16>) -> tensor<64x512xbf16>
+    attributes {hal.executable.target = #exec_target_storage_buf, translation_info = #translation_storage_buf} {
+  %source = iree_codegen.load_from_buffer %buf
+      : memref<64x512xbf16, #hal.descriptor_type> -> tensor<64x512xbf16>
+  %result = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
+      ins(%source : tensor<64x512xbf16>)
+      outs(%init : tensor<64x512xbf16>) -> tensor<64x512xbf16>
+
+  // storage_buffer source: sourceIsFromFatRawBuffer returns false, DMA skipped.
+  // The linalg.copy should remain unchanged.
+  // CHECK: linalg.copy
+  // CHECK-NOT: iree_gpu.coalesced_gather_dma
+
+  return %result : tensor<64x512xbf16>
+}
+
+// -----
+
+// Test: Copy from dispatch.tensor.load source.
+// DMA should NOT be applied because dispatch.tensor.load indicates the binding
+// was not bufferized to a memref (e.g., >2GB binding loaded via dispatch tensor
+// path), so fat_raw_buffer is not available.
+
+#pipeline_layout_dtl = #hal.pipeline.layout,
+  #hal.pipeline.binding
+]>
+
+#gpu_target_dtl = #iree_gpu.target>
+
+#exec_target_dtl = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #gpu_target_dtl}>
+#translation_dtl = #iree_codegen.translation_info}>
+
+// CHECK-LABEL: func.func @copy_from_dispatch_tensor_load
+func.func @copy_from_dispatch_tensor_load(%init: tensor<64x512xbf16>) -> tensor<64x512xbf16>
+    attributes {hal.executable.target = #exec_target_dtl, translation_info = #translation_dtl} {
+  %c0 = arith.constant 0 : index
+  %binding = hal.interface.binding.subspan layout(#pipeline_layout_dtl) binding(0) alignment(64) offset(%c0)
+      : !iree_tensor_ext.dispatch.tensor>
+  %source = iree_tensor_ext.dispatch.tensor.load %binding, offsets = [0, 0], sizes = [64, 512], strides = [1, 1]
+      : !iree_tensor_ext.dispatch.tensor> -> tensor<64x512xbf16>
+  %result = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
+      ins(%source : tensor<64x512xbf16>)
+      outs(%init : tensor<64x512xbf16>) -> tensor<64x512xbf16>
+
+  // dispatch.tensor.load source: sourceIsFromFatRawBuffer returns false,
+  // DMA skipped. The linalg.copy should remain unchanged.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
index a3326e9b39ae..f68225712967 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
@@ -205,6 +205,43 @@ Operation *CoalescedGatherDMAOp::getIteratingParent() {
   return getOperation()->getParentOfType<...>();
 }
 
+void CoalescedGatherDMAOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // Get the OpOperand pointers for source and init.
+  // Operand layout: source, indices (variadic), init, lane.
+  unsigned numOperands = getOperation()->getNumOperands();
+  unsigned laneOperandIdx = numOperands - 1;
+  unsigned initOperandIdx = laneOperandIdx - 1;
+  unsigned sourceOperandIdx = 0;
+
+  Value source = getSource();
+  Value init = getInit();
+
+  // The operation reads from the source.
+  if (isa<MemRefType>(source.getType())) {
+    effects.emplace_back(MemoryEffects::Read::get(),
+                         &getOperation()->getOpOperand(sourceOperandIdx),
+                         SideEffects::DefaultResource::get());
+  }
+
+  // For memref form, the operation writes to init (side effect).
+  // For tensor form with result, the write is captured in the result value.
+  // For tensor form without result (combiner case in forall.in_parallel),
+  // we must declare a write effect to prevent DCE from eliminating the op.
+  if (isa<MemRefType>(init.getType())) {
+    effects.emplace_back(MemoryEffects::Write::get(),
+                         &getOperation()->getOpOperand(initOperandIdx),
+                         SideEffects::DefaultResource::get());
+  } else if (isa<RankedTensorType>(init.getType()) &&
+             getOperation()->getNumResults() == 0) {
+    // Tensor combiner case: declare write effect to prevent DCE.
+    effects.emplace_back(MemoryEffects::Write::get(),
+                         &getOperation()->getOpOperand(initOperandIdx),
+                         SideEffects::DefaultResource::get());
+  }
+}
+
 LogicalResult CoalescedGatherDMAOp::verify() {
   TypedValue<...> init = getInit();
   auto initType = init.getType();
@@ -287,7 +324,8 @@ LogicalResult CoalescedGatherDMAOp::verify() {
   }
 
   // Verify the contiguous (non-indexed) dimensions match between source and
-  // dest.
+  // dest, unless in_bounds allows OOB reads for that dimension.
+  std::optional<ArrayAttr> inBoundsAttr = getInBounds();
   for (auto [dim, size] : llvm::enumerate(initShape)) {
     if (dim >= sourceShape.size()) {
       return emitOpError("expected source to have at least ")
@@ -300,6 +338,19 @@ LogicalResult CoalescedGatherDMAOp::verify() {
       continue;
     }
 
+    // If in_bounds is present and this dimension allows OOB (in_bounds=false),
+    // skip the size matching check. The source may be smaller than init along
+    // this dimension, and reads beyond the source extent return zero.
+    if (inBoundsAttr) {
+      auto inBoundsArray = *inBoundsAttr;
+      if (dim < inBoundsArray.size()) {
+        bool dimInBounds = cast<BoolAttr>(inBoundsArray[dim]).getValue();
+        if (!dimInBounds) {
+          continue; // OOB allowed, skip size check.
+        }
+      }
+    }
+
     // Check the suffix (hidden) gathering dimensions are the same in `source`
     // and `init`.
     int64_t sourceDim = sourceShape[dim];
@@ -310,6 +361,16 @@ LogicalResult CoalescedGatherDMAOp::verify() {
     }
   }
 
+  // Validate in_bounds attribute if present.
+  if (std::optional<ArrayAttr> inBoundsAttr = getInBounds()) {
+    int64_t initRank = initShapedType.getRank();
+    if (static_cast<int64_t>(inBoundsAttr->size()) != initRank) {
+      return emitOpError("in_bounds array size (")
+             << inBoundsAttr->size() << ") must match init rank (" << initRank
+             << ")";
+    }
+  }
+
   return success();
 }
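The combiner case called out in getEffects is easiest to see on a small example. Everything below is illustrative (shapes, names, and the surrounding loop structure are assumptions, not taken from the patch): the op in combiner position produces no SSA value, so only the declared write on the init operand tells the rest of the compiler it is not dead.

func.func @keep_dma_alive(%src: tensor<64x128xf32>, %idx: tensor<4xi32>,
                          %init: tensor<4x128xf32>) -> tensor<4x128xf32> {
  %0 = scf.forall (%lane) in (64) shared_outs(%out = %init) -> (tensor<4x128xf32>) {
    scf.forall.in_parallel {
      // Tensor form without a result: kept alive by the declared write effect.
      iree_gpu.coalesced_gather_dma %src[%idx] into %out lane(%lane)
          : tensor<64x128xf32>, tensor<4xi32>, tensor<4x128xf32>, index
    }
  }
  return %0 : tensor<4x128xf32>
}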
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
index 4733a70fa383..8d85f3a63e29 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
@@ -228,6 +228,7 @@ def IREEGPU_BufferResourceCastOp : Op
 def IREEGPU_CoalescedGatherDMAOp : Op<IREEGPU_Dialect, "coalesced_gather_dma", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<...>]> {
   let summary = "Coalesced gather DMA operation for efficient GPU memory access";
@@ -299,6 +300,19 @@ def IREEGPU_CoalescedGatherDMAOp : Op
     ..., with an intended DMA width of 128 bits
@@ -319,13 +333,15 @@ def IREEGPU_CoalescedGatherDMAOp : Op
     Variadic<...>:$indices,
     AnyRankedTensorOrMemRef:$init,
-    Index:$lane
+    Index:$lane,
+    OptionalAttr<BoolArrayAttr>:$in_bounds
   );
 
   let results = (outs Optional<...>:$result);
 
   let assemblyFormat = [{
     $source (`[` $indices^ `]`)? `into` $init `lane` `(` $lane `)`
+    (`in_bounds` $in_bounds^)?
     attr-dict `:` type(operands) ( `->` type($result)^ )?
   }];
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/bufferize_coalesced_gather_dma.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/bufferize_coalesced_gather_dma.mlir
index dd91b9ed2b10..f8ef24e79839 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/bufferize_coalesced_gather_dma.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/bufferize_coalesced_gather_dma.mlir
@@ -73,3 +73,18 @@ func.func @bufferize_coalesced_gather_dma_multiple_indices(%idx0: tensor<4xi32>,
 
 // CHECK-LABEL: func @bufferize_coalesced_gather_dma_multiple_indices
 // CHECK: iree_gpu.coalesced_gather_dma %{{.+}}[%{{.+}}, %{{.+}}] into %{{.+}} lane(%{{.+}}) : memref<64x128xf32{{.+}}>, memref<4xi32{{.+}}>, memref<4xi32{{.+}}>, memref<4x128xf32{{.+}}>, index
+
+// -----
+
+// Test bufferization with in_bounds attribute (for fused tensor.pad).
+func.func @bufferize_coalesced_gather_dma_in_bounds(%source: tensor<4x32xf32>,
+                                                    %dest: tensor<4x64xf32>,
+                                                    %lane: index) -> tensor<4x64xf32> {
+  %result = iree_gpu.coalesced_gather_dma %source into %dest lane(%lane)
+      in_bounds [true, false]
+      : tensor<4x32xf32>, tensor<4x64xf32>, index -> tensor<4x64xf32>
+  return %result : tensor<4x64xf32>
+}
+
+// CHECK-LABEL: func @bufferize_coalesced_gather_dma_in_bounds
+// CHECK: iree_gpu.coalesced_gather_dma %{{.+}} into %{{.+}} lane(%{{.+}}) in_bounds [true, false] : memref<4x32xf32{{.+}}>, memref<4x64xf32{{.+}}>, index
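The new attribute is what lets a zero-padding producer fold into the DMA instead of being materialized. The before/after below is a sketch only (shapes, function names, and the pre-fusion form are assumptions, not taken from the patch); the in_bounds form reads the narrower source directly and relies on out-of-bounds reads returning zero on a fat_raw_buffer source.

// Before fusion: the padding is materialized, then copied.
func.func @pad_then_dma(%src: tensor<4x30xf32>, %init: tensor<4x32xf32>,
                        %lane: index) -> tensor<4x32xf32> {
  %cst = arith.constant 0.0 : f32
  %padded = tensor.pad %src low[0, 0] high[0, 2] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<4x30xf32> to tensor<4x32xf32>
  %r = iree_gpu.coalesced_gather_dma %padded into %init lane(%lane)
      : tensor<4x32xf32>, tensor<4x32xf32>, index -> tensor<4x32xf32>
  return %r : tensor<4x32xf32>
}

// After fusion: the pad is encoded as in_bounds on the DMA itself.
func.func @dma_with_in_bounds(%src: tensor<4x30xf32>, %init: tensor<4x32xf32>,
                              %lane: index) -> tensor<4x32xf32> {
  %r = iree_gpu.coalesced_gather_dma %src into %init lane(%lane)
      in_bounds [true, false]
      : tensor<4x30xf32>, tensor<4x32xf32>, index -> tensor<4x32xf32>
  return %r : tensor<4x32xf32>
}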
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
index b1e72d596ba7..cccf381066d0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
@@ -432,12 +432,22 @@ struct CoalescedGatherDMAOpBufferizationInterface
       }
     }
 
-    rewriter.setInsertionPoint(gatherOp);
+    // Insert the memref operation in the forall body, before the in_parallel
+    // terminator (not inside the in_parallel region which will be removed).
+    auto inParallelOp = gatherOp->getParentOfType<scf::InParallelOp>();
+    if (inParallelOp) {
+      // Insert before the in_parallel terminator (in the forall body).
+      rewriter.setInsertionPoint(inParallelOp);
+    } else {
+      // Not in in_parallel, just insert at the current location.
+      rewriter.setInsertionPoint(gatherOp);
+    }
 
     // Create the bufferized DMA operation with no results (memref form).
     IREE::GPU::CoalescedGatherDMAOp::create(
         rewriter, gatherOp.getLoc(), TypeRange{}, *sourceBuffer,
-        bufferizedIndices, *initBuffer, gatherOp.getLane());
+        bufferizedIndices, *initBuffer, gatherOp.getLane(),
+        gatherOp.getInBoundsAttr());
 
     // Replace the tensor op. If it has a result, replace with the init buffer.
     // If it has no result (inside scf.forall.in_parallel), just erase it.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 12c7838bac22..1dc926372c0e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -317,7 +317,8 @@ LogicalResult fuseForallIntoConsumer(RewriterBase &rewriter,
     auto newGatherOp = IREE::GPU::CoalescedGatherDMAOp::create(
         rewriter, loc, coalescedGather.getInit().getType(),
         coalescedGather.getSource(), coalescedGather.getIndices(),
-        coalescedGather.getInit(), coalescedGather.getLane());
+        coalescedGather.getInit(), coalescedGather.getLane(),
+        coalescedGather.getInBoundsAttr());
     Value gatherResult = newGatherOp.getResult();
 
     // Use a tensor.insert_slice to insert the gather result back into the
@@ -455,7 +456,7 @@ static void composeCoalescedGatherDMA(
   auto dmaOp = IREE::GPU::CoalescedGatherDMAOp::create(
       rewriter, warpInsert.getLoc(), destSlice.getType(),
       laneInsert.getSource(), laneInsert.getIndices(), destSlice,
-      laneInsert.getLane());
+      laneInsert.getLane(), laneInsert.getInBoundsAttr());
 
   // Replace the warp parallel_insert_slice with one that inserts the DMA
   // result.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 38a1a213ab28..6f78c1ee6732 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -482,6 +482,9 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
     funcPassManager.addPass(createCSEPass());
   }
 
+  // Convert global load DMAs after reduction tiling but before pack
+  // decomposition. DecomposePackUnPackOps introduces linalg.transpose which
+  // breaks the source tracing in the coalesced DMA conversion.
   funcPassManager.addPass(createGPUConvertToCoalescedDMAPass());
 
   // Step 3. Decompose pack and unpack ops and propagate the resulting reshapes.
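Putting the bufferization and placement changes together, the memref-form op ends up in the scf.forall body rather than inside the scf.forall.in_parallel region that bufferization discards. A minimal sketch of the resulting IR, with shapes and names assumed:

func.func @after_bufferization(%src: memref<64x128xf32>, %idx: memref<4xi32>,
                               %dst: memref<4x128xf32>) {
  scf.forall (%lane) in (64) {
    // Created just before the terminator of the lane forall; the write to
    // %dst is the op's only observable effect in this form.
    iree_gpu.coalesced_gather_dma %src[%idx] into %dst lane(%lane)
        : memref<64x128xf32>, memref<4xi32>, memref<4x128xf32>, index
  }
  return
}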
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index b641388b68f8..c843bab3f432 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1534,6 +1534,97 @@ iree_generated_e2e_runner_test(
     "requires-gpu-${_CDNA_ARCH}"
 )
 
+# Unaligned matmul tests for coalesced DMA with tensor.pad fusion
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_${_CDNA_ARCH}_coalesced_dma_f32_unaligned_65x64x121
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--acc_type=f32"
+    "--shapes=custom_mnk"
+    "--mnk=65,64,121"
+    "--mnk_dynamicities=static,static,static"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+    "--iree-llvmgpu-use-direct-load"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-${_CDNA_ARCH}"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_${_CDNA_ARCH}_coalesced_dma_f32_unaligned_133x97x65
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--acc_type=f32"
+    "--shapes=custom_mnk"
+    "--mnk=133,97,65"
+    "--mnk_dynamicities=static,static,static"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+    "--iree-llvmgpu-use-direct-load"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-${_CDNA_ARCH}"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_${_CDNA_ARCH}_coalesced_dma_f32_unaligned_123x321x231
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--acc_type=f32"
+    "--shapes=custom_mnk"
+    "--mnk=123,321,231"
+    "--mnk_dynamicities=static,static,static"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+    "--iree-llvmgpu-use-direct-load"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-${_CDNA_ARCH}"
+)
+
 iree_generated_e2e_runner_test(
   NAME
     e2e_matmul_${_CDNA_ARCH}_vecdistmfma_f16
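Once a build is configured with the HIP driver and a matching CDNA GPU, each of these targets can be run in isolation through CTest, for example with ctest -R coalesced_dma_f32_unaligned_65x64x121 from the build directory; the exact test name depends on the configured ${_CDNA_ARCH} value.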