|
23 | 23 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
24 | 24 | #include "mlir/IR/Builders.h" |
25 | 25 | #include "mlir/IR/BuiltinAttributes.h" |
| 26 | +#include "mlir/IR/Matchers.h" |
26 | 27 | #include "mlir/IR/PatternMatch.h" |
27 | 28 | #include "mlir/Pass/Pass.h" |
28 | 29 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
@@ -58,6 +59,15 @@ static SmallVector<Attribute> getThreadMapping(MLIRContext *ctx) { |
58 | 59 | return mapping; |
59 | 60 | } |
60 | 61 |
|
| 62 | +/// Trace through extract_slice operations to find an underlying tensor.pad. |
| 63 | +/// Returns the PadOp if found, nullptr otherwise. |
| 64 | +static tensor::PadOp traceToTensorPad(Value source) { |
| 65 | + while (auto extractSlice = source.getDefiningOp<tensor::ExtractSliceOp>()) { |
| 66 | + source = extractSlice.getSource(); |
| 67 | + } |
| 68 | + return source.getDefiningOp<tensor::PadOp>(); |
| 69 | +} |
| 70 | + |
61 | 71 | /// Check if a value traces back to tensor.empty (possibly through forall args). |
62 | 72 | static bool tracesToTensorEmpty(Value value) { |
63 | 73 | // Direct tensor.empty. |
@@ -300,14 +310,74 @@ static LogicalResult createDMAInForall(scf::ForallOp threadForallOp, |
300 | 310 |
|
301 | 311 | Location loc = innerOp.getLoc(); |
302 | 312 | Value source, indices; |
| 313 | + SmallVector<bool> inBoundsVec; |
303 | 314 |
|
304 | 315 | // Extract source and indices based on op type. |
305 | 316 | if constexpr (std::is_same_v<OpTy, linalg::CopyOp>) { |
306 | 317 | Value input = innerOp.getInputs()[0]; |
307 | | - if (auto extractSlice = input.getDefiningOp<tensor::ExtractSliceOp>()) { |
308 | | - source = extractSlice.getSource(); |
309 | | - } else { |
310 | | - return failure(); |
| 318 | + |
| 319 | + // After tiling, the input is typically: |
| 320 | + // tensor.extract_slice %padded[...] [...] [1, 1] |
| 321 | +    // We trace through extract_slice to see whether the source is a tensor.pad. |
| 322 | + if (auto pad = traceToTensorPad(input)) { |
| 323 | + // Verify pad constraints: low padding must be all zeros, pad value must |
| 324 | + // be 0. |
| 325 | + bool validPad = true; |
| 326 | + for (OpFoldResult low : pad.getMixedLowPad()) { |
| 327 | + if (!isConstantIntValue(low, 0)) { |
| 328 | + validPad = false; |
| 329 | + break; |
| 330 | + } |
| 331 | + } |
| 332 | + Value padVal = pad.getConstantPaddingValue(); |
| 333 | + if (!padVal || !(matchPattern(padVal, m_AnyZeroFloat()) || |
| 334 | + matchPattern(padVal, m_Zero()))) { |
| 335 | + validPad = false; |
| 336 | + } |
| 337 | + |
| 338 | + if (validPad) { |
| 339 | + // Use pad.getSource() directly as the DMA source. |
| 340 | + // This is the tensor.extract_slice result (e.g., tensor<?x64xf32>). |
| 341 | + source = pad.getSource(); |
| 342 | + |
| 343 | + // Check if source tensor's innermost row size is DWORD (4-byte) |
| 344 | + // aligned. On AMD CDNA, per-component range checking is performed for |
| 345 | + // each DWORD. If a DWORD is partially out-of-bounds, the entire DWORD |
| 346 | + // returns zero, causing incorrect results. Additionally, partial OOB |
| 347 | + // triggers the slow path with multi-cycling and instruction issue |
| 348 | + // penalties. |
| 349 | + auto sourceType = cast<RankedTensorType>(source.getType()); |
| 350 | + int64_t innermostDim = sourceType.getShape().back(); |
| 351 | + if (!ShapedType::isDynamic(innermostDim)) { |
| 352 | + Type elemType = sourceType.getElementType(); |
| 353 | + int64_t elemBytes = elemType.getIntOrFloatBitWidth() / 8; |
| 354 | + int64_t rowBytes = innermostDim * elemBytes; |
| 355 | + if (rowBytes % 4 != 0) { |
| 356 | + LLVM_DEBUG(llvm::dbgs() |
| 357 | + << "Skipping DMA: row size " << rowBytes |
| 358 | + << " bytes not DWORD-aligned (slow path)\n"); |
| 359 | + return failure(); |
| 360 | + } |
| 361 | + } |
| 362 | + |
| 363 | + // Compute in_bounds based on whether padding was added per dimension. |
| 364 | + for (auto [low, high] : |
| 365 | + llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) { |
| 366 | + bool isInBounds = |
| 367 | + isConstantIntValue(low, 0) && isConstantIntValue(high, 0); |
| 368 | + inBoundsVec.push_back(isInBounds); |
| 369 | + } |
| 370 | + } |
| 371 | + } |
| 372 | + |
| 373 | + // Fallback: original behavior without tensor.pad fusion. |
| 374 | + // Only trace through ONE level of extract_slice (the immediate input). |
| 375 | + if (!source) { |
| 376 | + if (auto extractSlice = input.getDefiningOp<tensor::ExtractSliceOp>()) { |
| 377 | + source = extractSlice.getSource(); |
| 378 | + } else { |
| 379 | + return failure(); |
| 380 | + } |
311 | 381 | } |
312 | 382 | } else if constexpr (std::is_same_v<OpTy, IREE::LinalgExt::GatherOp>) { |
313 | 383 | source = innerOp.getSource(); |
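
The DWORD-alignment bail-out in the hunk above is easiest to check with concrete numbers. Below is a minimal standalone sketch of the same arithmetic (illustration only, not part of the patch), assuming byte-sized element types just as the pass does when it divides the bit width by 8:

```cpp
#include <cassert>
#include <cstdint>

// The fast DMA path requires the innermost row to cover whole DWORDs
// (4 bytes); this mirrors the rowBytes % 4 check in the pass above.
static bool rowIsDwordAligned(int64_t innermostDim, int64_t elemBits) {
  int64_t rowBytes = innermostDim * (elemBits / 8);
  return rowBytes % 4 == 0;
}

int main() {
  assert(rowIsDwordAligned(64, 32));  // 64 x f32 -> 256 bytes: aligned.
  assert(rowIsDwordAligned(66, 16));  // 66 x f16 -> 132 bytes: aligned.
  assert(!rowIsDwordAligned(3, 16));  // 3 x f16 -> 6 bytes: a partial DWORD,
                                      // i.e. the CDNA slow path being avoided.
  return 0;
}
```

The in_bounds flags computed right after that check follow directly from the pad amounts: a dimension is in bounds only when both its low and high padding are zero, so a pad with low = [0, 0] and high = [%h, 0] yields in_bounds = [false, true].
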
@@ -356,15 +426,22 @@ static LogicalResult createDMAInForall(scf::ForallOp threadForallOp, |
356 | 426 |
|
357 | 427 | // Create the DMA op in the in_parallel region. |
358 | 428 | rewriter.setInsertionPointToStart(&inParallelBlock); |
359 | | - SmallVector<Value, 1> indicesVec; |
| 429 | + SmallVector<Value, 1> indicesOperands; |
360 | 430 | if (indices) { |
361 | | - indicesVec.push_back(indices); |
| 431 | + indicesOperands.push_back(indices); |
| 432 | + } |
| 433 | + |
| 434 | + // Create in_bounds attribute if we fused a tensor.pad. |
| 435 | + ArrayAttr inBoundsAttr; |
| 436 | + if (!inBoundsVec.empty()) { |
| 437 | + inBoundsAttr = rewriter.getBoolArrayAttr(inBoundsVec); |
362 | 438 | } |
363 | 439 |
|
364 | 440 | // When used in forall.in_parallel, the op doesn't return a result |
365 | 441 | // as it performs an in-place update to the shared_outs tensor. |
366 | 442 | IREE::GPU::CoalescedGatherDMAOp::create(rewriter, loc, Type(), source, |
367 | | - indicesVec, sharedOut, laneId); |
| 443 | + indicesOperands, sharedOut, laneId, |
| 444 | + inBoundsAttr); |
368 | 445 |
|
369 | 446 | // Erase the parallel_insert_slice ops and inner operation. |
370 | 447 | for (tensor::ParallelInsertSliceOp &insertOp : toErase) { |
@@ -421,6 +498,58 @@ struct ConvertCopyToCoalescedDMA |
421 | 498 | } |
422 | 499 | }; |
423 | 500 |
|
| 501 | +/// Pattern to convert tensor.pad fusion cases directly without requiring |
| 502 | +/// a warp-mapped forall parent. |
| 503 | +struct ConvertPadFusionCopyToCoalescedDMA |
| 504 | + : public OpRewritePattern<linalg::CopyOp> { |
| 505 | + using OpRewritePattern::OpRewritePattern; |
| 506 | + |
| 507 | + LogicalResult matchAndRewrite(linalg::CopyOp copyOp, |
| 508 | + PatternRewriter &rewriter) const override { |
| 509 | +    // Only match copies with the use_global_load_dma config. |
| 510 | + auto config = getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(copyOp); |
| 511 | + if (!config) { |
| 512 | + return failure(); |
| 513 | + } |
| 514 | + |
| 515 | +    // Check if this is a tensor.pad fusion case. |
| 516 | + auto pad = traceToTensorPad(copyOp.getInputs()[0]); |
| 517 | + if (!pad) { |
| 518 | +      return failure(); // Not a pad fusion case. |
| 519 | + } |
| 520 | + |
| 521 | +    // Check if padding exists (non-zero low/high pad). |
| 522 | + bool hasPadding = false; |
| 523 | + for (auto [low, high] : |
| 524 | + llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) { |
| 525 | + if (!isConstantIntValue(low, 0) || !isConstantIntValue(high, 0)) { |
| 526 | + hasPadding = true; |
| 527 | + break; |
| 528 | + } |
| 529 | + } |
| 530 | + if (!hasPadding) { |
| 531 | +      return failure(); // No actual padding. |
| 532 | + } |
| 533 | + |
| 534 | + // This is a tensor.pad fusion case. Convert directly to |
| 535 | +    // coalesced_gather_dma without requiring a warp-mapped forall. |
| 536 | + auto outputType = cast<RankedTensorType>(copyOp.getOutputs()[0].getType()); |
| 537 | + SmallVector<OpFoldResult> threadNumThreads = |
| 538 | + computeThreadNumThreadsImpl(rewriter, copyOp, outputType); |
| 539 | + if (threadNumThreads.empty()) { |
| 540 | + return failure(); |
| 541 | + } |
| 542 | + |
| 543 | + scf::ForallOp threadForallOp = |
| 544 | + tileToThreadLevel(copyOp, rewriter, threadNumThreads); |
| 545 | + if (!threadForallOp) { |
| 546 | + return failure(); |
| 547 | + } |
| 548 | + |
| 549 | + return createDMAInForall<linalg::CopyOp>(threadForallOp, rewriter); |
| 550 | + } |
| 551 | +}; |
| 552 | + |
424 | 553 | struct ConvertGatherToCoalescedDMA |
425 | 554 | : public OpRewritePattern<IREE::LinalgExt::GatherOp> { |
426 | 555 | using OpRewritePattern<IREE::LinalgExt::GatherOp>::OpRewritePattern; |
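
The "does any dimension actually carry padding" loop in ConvertPadFusionCopyToCoalescedDMA above reappears almost verbatim in the subgroup-tiling change further down. As a hedged sketch (a hypothetical consolidation, not something this patch adds), the shared predicate could be factored out using only APIs that already appear in the diff:

```cpp
// Hypothetical helper with the same semantics as the duplicated loops:
// returns true if any dimension of the pad carries non-zero low or high
// padding. Assumes the same includes as the file above (Tensor dialect,
// StaticValueUtils, and llvm/ADT/STLExtras for llvm::zip).
static bool padHasNonZeroPadding(tensor::PadOp pad) {
  for (auto [low, high] :
       llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) {
    if (!isConstantIntValue(low, 0) || !isConstantIntValue(high, 0))
      return true;
  }
  return false;
}
```

Both the matcher above and the isPadFusion check below would then reduce to tracing to a tensor.pad with traceToTensorPad and calling this predicate on the result.
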
@@ -574,7 +703,8 @@ struct ConvertGatherToCoalescedDMA |
574 | 703 | rewriter.setInsertionPointToStart(&inParallelBlock); |
575 | 704 |
|
576 | 705 | IREE::GPU::CoalescedGatherDMAOp::create(rewriter, loc, Type(), source, |
577 | | - indicesVec, sharedOut, laneId); |
| 706 | + indicesVec, sharedOut, laneId, |
| 707 | + /*in_bounds=*/nullptr); |
578 | 708 |
|
579 | 709 | // Erase parallel_insert_slice ops and gather op. |
580 | 710 | SmallVector<tensor::ParallelInsertSliceOp> toErase; |
@@ -605,9 +735,11 @@ struct GPUConvertToCoalescedDMAPass final |
605 | 735 | } |
606 | 736 |
|
607 | 737 | // Only tile and convert ops within forall ops with warp mapping. |
| 738 | + // Also handle tensor.pad fusion cases that don't have warp mapping. |
608 | 739 | RewritePatternSet patterns(context); |
609 | 740 | patterns.add<ConvertGatherToCoalescedDMA>(context); |
610 | 741 | patterns.add<ConvertCopyToCoalescedDMA>(context); |
| 742 | + patterns.add<ConvertPadFusionCopyToCoalescedDMA>(context); |
611 | 743 |
|
612 | 744 | walkAndApplyPatterns(funcOp, std::move(patterns)); |
613 | 745 | } |
@@ -758,9 +890,42 @@ struct GPUConvertToCoalescedDMAPass final |
758 | 890 | return failure(); |
759 | 891 | } |
760 | 892 |
|
761 | | - // Compute tile sizes for subgroup-level distribution. |
762 | | - auto [tileSizes, numTiledDims] = |
763 | | - computeSubgroupTileSizes(rewriter, shape, numWarps); |
| 893 | + // Check if this is a tensor.pad fusion case. |
| 894 | + bool isPadFusion = false; |
| 895 | + if (auto copyOp = dyn_cast<linalg::CopyOp>(op.getOperation())) { |
| 896 | + if (auto pad = traceToTensorPad(copyOp.getInputs()[0])) { |
| 897 | +        // Check if padding exists (non-zero low/high pad). |
| 898 | + for (auto [low, high] : |
| 899 | + llvm::zip(pad.getMixedLowPad(), pad.getMixedHighPad())) { |
| 900 | + if (!isConstantIntValue(low, 0) || !isConstantIntValue(high, 0)) { |
| 901 | + isPadFusion = true; |
| 902 | + break; |
| 903 | + } |
| 904 | + } |
| 905 | + } |
| 906 | + } |
| 907 | + |
| 908 | + SmallVector<OpFoldResult> tileSizes; |
| 909 | + int64_t numTiledDims = 0; |
| 910 | + |
| 911 | + if (isPadFusion) { |
| 912 | + // For tensor.pad fusion, create a single-iteration wrapper forall |
| 913 | + // by setting tile sizes to the full shape. This allows the DMA to |
| 914 | + // operate on the full buffer while satisfying the warp-mapped parent |
| 915 | + // requirement. |
| 916 | + // Bail out if any dimension is dynamic since we need static tile sizes. |
| 917 | + if (llvm::any_of(shape, ShapedType::isDynamic)) { |
| 918 | + return failure(); |
| 919 | + } |
| 920 | + for (int64_t i = 0; i < rank; ++i) { |
| 921 | + tileSizes.push_back(rewriter.getIndexAttr(shape[i])); |
| 922 | + ++numTiledDims; |
| 923 | + } |
| 924 | + } else { |
| 925 | + // Compute tile sizes for subgroup-level distribution. |
| 926 | + std::tie(tileSizes, numTiledDims) = |
| 927 | + computeSubgroupTileSizes(rewriter, shape, numWarps); |
| 928 | + } |
764 | 929 |
|
765 | 930 | if (numTiledDims == 0) { |
766 | 931 | return failure(); |
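
To see why full-shape tile sizes produce the single-iteration wrapper forall described in the comment above: the trip count along each dimension is ceilDiv(dim, tileSize), so a tile size equal to the dimension gives exactly one iteration. A standalone sketch of that arithmetic (illustration only; the actual tiling goes through the SCF tiling utilities):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Total forall trip count for static dims and tile sizes, using ceil
// division per dimension.
static int64_t forallTripCount(const std::vector<int64_t> &shape,
                               const std::vector<int64_t> &tileSizes) {
  int64_t trips = 1;
  for (size_t i = 0; i < shape.size(); ++i)
    trips *= (shape[i] + tileSizes[i] - 1) / tileSizes[i];
  return trips;
}

int main() {
  // Pad-fusion case: tile sizes equal to the full shape -> one iteration.
  assert(forallTripCount({2, 64}, {2, 64}) == 1);
  // Regular distribution, e.g. the outer dim split across 4 warps.
  assert(forallTripCount({8, 64}, {2, 64}) == 4);
  return 0;
}
```
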
@@ -798,6 +963,9 @@ struct GPUConvertToCoalescedDMAPass final |
798 | 963 | }); |
799 | 964 |
|
800 | 965 | // Apply subgroup-level tiling to each op. |
| 966 | + // For tensor.pad fusion cases, tileAtSubgroupLevel creates a |
| 967 | + // single-iteration wrapper forall to maintain the expected structure while |
| 968 | + // allowing the DMA to operate on the full buffer. |
801 | 969 | IRRewriter rewriter(context); |
802 | 970 | for (Operation *op : opsToTile) { |
803 | 971 | FailureOr<scf::SCFTilingResult> tilingResult = |
|