diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 107bf3f6fc..5e2fb4a630 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -40,7 +40,10 @@ def is_blackwell():
     return is_cuda() and torch.cuda.get_device_capability()[0] == 10
 
 
-# FIXME: Revert temporary source code modification done in last commit of PR #4399.
+# FIXME: Revert temporary source code modification (only for fp8) done in last commit of PR #4399.
+# Note: Triton will fuse load+trans operations. When the data type is fp8, 2D block reads aren't generated
+# yet because DPAS doesn't natively support fp8. We have to enhance that part of the code generation
+# in order to remove the remaining source code changes.
 
 
 @triton.jit
@@ -68,7 +71,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
     for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = desc_k.load([0, offsetk_y])
+        if dtype == tl.float8e5:
+            k = desc_k.load([0, offsetk_y])
+        else:
+            k = desc_k.load([offsetk_y, 0]).T
         qk = tl.dot(q, k)
         if STAGE == 2:
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
@@ -192,8 +198,12 @@ def _attn_fwd(sm_scale, M,  #
     else:
         desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                          block_shape=[BLOCK_N, HEAD_DIM])
-    desc_k = _maybe_make_tensor_desc(desc_k, shape=[HEAD_DIM, y_dim], strides=[1, HEAD_DIM],
-                                     block_shape=[HEAD_DIM, BLOCK_N])
+    if FP8_OUTPUT:
+        desc_k = _maybe_make_tensor_desc(desc_k, shape=[HEAD_DIM, y_dim], strides=[1, HEAD_DIM],
+                                         block_shape=[HEAD_DIM, BLOCK_N])
+    else:
+        desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
+                                         block_shape=[BLOCK_N, HEAD_DIM])
     desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_M, HEAD_DIM])
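
For reference, the load -> trans -> dot pattern that the new fusion logic targets looks like this at the Triton source level. This is a minimal, hypothetical kernel (not part of the tutorial; it assumes an Intel GPU exposed as PyTorch's `xpu` device), where B is stored N x K row-major and transposed in-kernel before the dot:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def small_dot_bt(a_ptr, b_ptr, c_ptr,
                 M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # A is an MxK row-major block; B is stored NxK row-major, so B^T is
    # materialized with tl.trans before the dot: load -> trans -> dot.
    a_bp = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(K, 1),
                             offsets=(0, 0), block_shape=(M, K), order=(1, 0))
    b_bp = tl.make_block_ptr(b_ptr, shape=(N, K), strides=(K, 1),
                             offsets=(0, 0), block_shape=(N, K), order=(1, 0))
    a = tl.load(a_bp)
    b = tl.trans(tl.load(b_bp))  # the tt.load -> tt.trans chain feeding tt.dot
    c = tl.dot(a, b)
    c_bp = tl.make_block_ptr(c_ptr, shape=(M, N), strides=(N, 1),
                             offsets=(0, 0), block_shape=(M, N), order=(1, 0))
    tl.store(c_bp, c)


a = torch.randn(64, 32, device="xpu", dtype=torch.float16)
b = torch.randn(64, 32, device="xpu", dtype=torch.float16)  # N x K storage
c = torch.empty(64, 64, device="xpu", dtype=torch.float32)
small_dot_bt[(1, )](a, b, c, M=64, N=64, K=32)
```

The lit tests below check that the backend rewrites exactly this shape of IR.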
diff --git a/test/TritonIntelGPU/dot-operands.mlir b/test/TritonIntelGPU/dot-operands.mlir
index babdf2e34a..7da3394309 100644
--- a/test/TritonIntelGPU/dot-operands.mlir
+++ b/test/TritonIntelGPU/dot-operands.mlir
@@ -70,8 +70,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
   // COM: tt.load -> tt.trans -> tt.dot chain, in a loop,
   // COM: where the 'make_tensor_ptr' result is loop carried.
-  tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) {
-    %c4_i32 = arith.constant 4 : i32
+  tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr, %arg1: !tt.ptr) {
     %c1024_i32 = arith.constant 1024 : i32
     %c0_i32 = arith.constant 0 : i32
     %c32_i32 = arith.constant 32 : i32
@@ -80,15 +79,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
     %c1_i64 = arith.constant 1 : i64
     %c1024_i64 = arith.constant 1024 : i64
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
-    %0 = tt.get_program_id x : i32
-    %1 = arith.divsi %0, %c16_i32 : i32
-    %2 = arith.muli %1, %c4_i32 : i32
-    %3 = arith.subi %c4_i32, %2 : i32
-    %4 = arith.minsi %3, %c4_i32 : i32
-    %5 = arith.remsi %0, %c16_i32 : i32
-    %6 = arith.remsi %5, %4 : i32
-    %7 = arith.addi %2, %6 : i32
-    %8 = arith.divsi %5, %4 : i32
     %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >>
     %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : >
     %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr>) : i32 {
@@ -116,13 +106,61 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
 
 // -----
 
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
+  // COM: tt.load -> tt.trans -> tt.dot chain, in 2 loops,
+  // COM: where the block ptr used by the loads in the 2 loops is created by the same make_tensor_ptr operation.
+  tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
+    %c32_i32 = arith.constant 32 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i64 = arith.constant 64 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
+    %7 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
+    %9 = tt.make_tensor_ptr %arg2, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #linear>>
+    %24 = tt.advance %7, [%arg0, %c0_i32] : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
+    %25 = tt.load %24 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
+    %29:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
+      %adv1 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
+      %load1 = tt.load %adv1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
+      %trans1 = tt.trans %load1 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %dot1 = tt.dot %25, %trans1, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
+      %76 = arith.addi %arg13, %c32_i32 : i32
+      scf.yield %76 : i32
+    }
+    %38:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
+      %adv2 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
+      %load2 = tt.load %adv2 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
+      %trans2 = tt.trans %load2 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %dot2 = tt.dot %25, %trans2, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
+      %81 = arith.addi %arg13, %c32_i32 : i32
+      scf.yield %81 : i32
+    }
+    tt.return
+  }
+  // CHECK-LABEL: fuseLoadWithTrans4
+  // CHECK-NOT: tt.trans
+  // CHECK-COUNT-2: tt.make_tensor_ptr %arg2, [%c64_i64, %c1_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+  // CHECK: scf.for {{.*}}
+  // CHECK:   [[ADV1:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+  // CHECK:   [[LOAD_B1:%.*]] = tt.load [[ADV1]] {boundaryCheck = array<i32: 1, 0>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+  // CHECK:   tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
+  // CHECK:   scf.yield {{.*}}
+  // CHECK: scf.for {{.*}}
+  // CHECK:   [[ADV2:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+  // CHECK:   [[LOAD_B2:%.*]] = tt.load [[ADV2]] {boundaryCheck = array<i32: 1, 0>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+  // CHECK:   tt.dot {{.*}}, [[LOAD_B2]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
+  // CHECK:   scf.yield {{.*}}
+}
+
+// -----
+
 #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}>
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
-  // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load
-  // COM: that 'feeds' the transpose operation is used.
-  tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) {
-    %c4_i32 = arith.constant 4 : i32
+  // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load that 'feeds' the transpose operation is used.
+  tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr, %arg1: !tt.ptr) {
     %c1024_i32 = arith.constant 1024 : i32
     %c0_i32 = arith.constant 0 : i32
     %c32_i32 = arith.constant 32 : i32
@@ -131,15 +169,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
     %c1_i64 = arith.constant 1 : i64
     %c1024_i64 = arith.constant 1024 : i64
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
-    %0 = tt.get_program_id x : i32
-    %1 = arith.divsi %0, %c16_i32 : i32
-    %2 = arith.muli %1, %c4_i32 : i32
-    %3 = arith.subi %c4_i32, %2 : i32
-    %4 = arith.minsi %3, %c4_i32 : i32
-    %5 = arith.remsi %0, %c16_i32 : i32
-    %6 = arith.remsi %5, %4 : i32
-    %7 = arith.addi %2, %6 : i32
-    %8 = arith.divsi %5, %4 : i32
     %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >>
     %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : >
     %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr>) : i32 {
@@ -166,7 +195,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
   // COM: Ensure load is not fused with transpose if there are multiple users of an operation in the def-use chain containing the load + transpose.
   // COM: In this case `%19` is used by the load that feeds the transpose and by a store operation.
-  tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) {
+  tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr, %arg1: !tt.ptr) {
     %c4_i32 = arith.constant 4 : i32
     %c1024_i32 = arith.constant 1024 : i32
     %c0_i32 = arith.constant 0 : i32
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp
index bff9f17673..feea7b947d 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp
@@ -6,6 +6,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
@@ -16,6 +17,7 @@
 #include "triton/Dialect/TritonGPU/IR/LayoutUtility.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include <optional>
@@ -34,6 +36,103 @@ namespace mlir::triton::gpu::intel {
 
 namespace {
 
+// Represent a def-use chain rooted at 'start' and terminating at 'end'.
+class Chain {
+  friend raw_ostream &operator<<(raw_ostream &, const Chain &);
+  using Operations = llvm::SmallSetVector<Operation *, 8>;
+
+public:
+  Chain(Operation *start, Operation *end) : start(start), end(end) {
+    assert(start && end && "Expecting valid operations");
+    assert(start != end && "Expecting distinct operations");
+    assert(
+        isTransitivelyUsedBy(start, end) &&
+        "'end' operation should (transitively) use the result of the 'start' "
+        "operation");
+  }
+  bool operator<(const Chain &other) const {
+    if (start == other.start)
+      return end < other.end;
+    return start < other.start;
+  }
+  bool operator==(const Chain &other) const {
+    return start == other.start && end == other.end;
+  }
+
+  Operation *getStart() const { return start; }
+  Operation *getEnd() const { return end; }
+
+  // Returns true if this chain and \p other contain any common operation.
+  bool overlap(const Chain &other) const {
+    if (other.getStart() == start || other.getEnd() == end)
+      return true;
+
+    return isTransitivelyUsedBy(other.getStart(), end);
+  }
+
+  // Returns true if \p producer yields a result that is used (directly or
+  // indirectly) by \p consumer.
+  static bool isTransitivelyUsedBy(Operation *producer, Operation *consumer) {
+    assert(producer && consumer && "Expecting valid operations");
+
+    auto addUsers = [](Operation *op, Operations &users) {
+      assert(op && "Expecting valid operation");
+
+      auto addUsers = [&](Operation *op) {
+        // Add users of the block arguments in the 'after' region of a while
+        // loop.
+        if (auto condOp = dyn_cast<scf::ConditionOp>(op)) {
+          if (auto whileOp = condOp->getParentOfType<scf::WhileOp>()) {
+            for (BlockArgument arg : whileOp.getAfterArguments())
+              for (Operation *user : arg.getUsers())
+                users.insert(user);
+          }
+        }
+
+        for (Operation *user : op->getUsers())
+          users.insert(user);
+      };
+
+      auto addInitArgsUsers = [&](LoopLikeOpInterface loopOp) {
+        for (Value val : loopOp.getRegionIterArgs())
+          for (Operation *user : val.getUsers())
+            addUsers(user);
+      };
+
+      if (auto loopOp = dyn_cast<LoopLikeOpInterface>(op))
+        addInitArgsUsers(loopOp);
+      else
+        addUsers(op);
+    };
+
+    Operations users;
+    addUsers(producer, users);
+
+    while (!users.contains(consumer)) {
+      unsigned currentSize = users.size();
+      Operations copyUsers = users;
+      for (Operation *user : copyUsers)
+        addUsers(user, users);
+
+      if (users.size() == currentSize)
+        break;
+    }
+
+    return users.contains(consumer);
+  }
+
+private:
+  Operation *start = nullptr;
+  Operation *end = nullptr;
+};
+
+raw_ostream &operator<<(raw_ostream &os, const Chain &chain) {
+  os << "[" << chain.start << ", " << chain.end << "]\n";
+  os.indent(2) << "start: " << *chain.start << "\n";
+  os.indent(2) << "end: " << *chain.end << "\n";
+  return os;
+}
+
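
`Chain::isTransitivelyUsedBy` above is a fixpoint computation: it keeps growing the set of direct and indirect users of `producer` until the set either contains `consumer` or stops growing. A Python toy model of the same walk, with a plain dict standing in for MLIR use-def edges and the loop-carried special cases ignored:

```python
def transitively_used_by(producer, consumer, users):
    """users maps an op to the ops consuming one of its results."""
    seen = set(users.get(producer, ()))
    while consumer not in seen:
        frontier = {u for op in seen for u in users.get(op, ())}
        if frontier <= seen:  # fixpoint reached without seeing `consumer`
            return False
        seen |= frontier
    return True


# make_tensor_ptr -> advance -> load -> trans, as in the tests above.
users = {"make_tensor_ptr": ["advance"], "advance": ["load"], "load": ["trans"]}
assert transitively_used_by("make_tensor_ptr", "trans", users)
assert not transitively_used_by("trans", "make_tensor_ptr", users)
```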
 // Transform:
 //   %ptr = make_block_ptr [shapeN, shapeK], [strideN, strideK], [offN, offK]
 //          : tt.ptr
@@ -49,31 +148,169 @@
 //   tt.dot(%a, %load)
 class FuseTransWithLoad {
 private:
-  tt::FuncOp funcOp;
   SmallPtrSet<Operation *, 8> cleanUp;
 
 public:
-  FuseTransWithLoad() = default;
+  using Chains = std::set<Chain>;
 
   void run(ModuleOp moduleOp) {
+    Chains chains;
+
+    // Collect def-use chains originating at a `MakeTensorPtrOp` operation
+    // and terminating at a candidate `tt::TransOp` operation.
+    // Note: a candidate `TransOp` must use the result of a `LoadOp` using a ptr
+    // created by the `MakeTensorPtrOp` rooting the def-use chain.
     moduleOp.walk([&](tt::TransOp transOp) {
-      if (isCandidate(transOp))
-        fuse(transOp);
+      if (isCandidate(transOp)) {
+        auto loadOp = cast<tt::LoadOp>(transOp.getSrc().getDefiningOp());
+        tt::MakeTensorPtrOp makeTensorPtrOp =
+            *triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
+        Chain chain(makeTensorPtrOp, transOp);
+        chains.insert(chain);
+      }
     });
 
+    if (chains.empty())
+      return;
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "[Initial set of chains]:\n";
+      for (const Chain &chain : chains)
+        llvm::dbgs() << chain << "\n";
+    });
+
+    // Attempt to duplicate the root operation of chains that overlap (at the
+    // root); give up if an overlap still exists after duplication.
+    duplicateIfOverlap(chains);
+    if (overlap(chains))
+      return;
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "[Before Pruning]:\n";
+      for (const Chain &chain : chains)
+        llvm::dbgs() << chain << "\n";
+    });
+
+    // Prune candidate chains containing load/trans operations that cannot be
+    // safely fused.
+    prune(chains);
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "[After Pruning]:\n";
+      for (const Chain &chain : chains)
+        llvm::dbgs() << chain << "\n";
+    });
+
+    // Fuse operations.
+    fuse(chains);
+
+    // Remove operations that are no longer used.
     if (!cleanUp.empty())
       tt::intel::eraseOperations(cleanUp);
 
     assert(succeeded(verify(moduleOp)) && "Module verification failed");
   }
 
-  void fuse(tt::TransOp transOp) {
-    LLVM_DEBUG(llvm::dbgs() << "Found candidate:\n\t" << transOp << "\n");
+  bool overlap(const Chains &chains) const {
+    assert(!chains.empty() && "Expecting at least one chain");
+    if (chains.size() < 2)
+      return false;
+
+    for (auto it1 = chains.begin(); it1 != chains.end(); ++it1) {
+      for (auto it2 = it1; it2 != chains.end(); ++it2) {
+        if (it2 == it1)
+          continue;
+        if (it2->overlap(*it1))
+          return true;
+      }
+    }
+
+    return false;
+  }
+
+  // Attempt to duplicate operations in the given \p chains if there is an
+  // overlap.
+  // Limitation: currently this member function handles overlap at the root
+  // operation only.
+  void duplicateIfOverlap(Chains &chains) const {
+    assert(!chains.empty() && "Expecting at least one chain");
+    if (!overlap(chains))
+      return;
+
+    LLVM_DEBUG(llvm::dbgs() << "Detected overlap\n";);
+
+    // If the same operation is the root of multiple chains, duplicate it to
+    // make each chain disjoint from the others.
+    std::map<Operation *, Chains> rootToChains;
+    for (const Chain &chain : chains) {
+      Operation *start = chain.getStart();
+      if (!rootToChains[start].empty())
+        continue;
+
+      Chains sameRootChains{chain};
+      rootToChains[start] = sameRootChains;
+      for (const Chain &otherChain : chains) {
+        if (otherChain == chain || otherChain.getStart() != start)
+          continue;
+
+        rootToChains[start].insert(otherChain);
+      }
+    }
+
+    for (auto &entry : rootToChains) {
+      Chains &sameRootChains = entry.second;
+      if (sameRootChains.size() == 1)
+        continue;
+
+      duplicateRoot(sameRootChains, chains);
+    }
+  }
+
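
Conceptually, the duplication step gives every chain after the first one its own clone of a shared root, so the chains become disjoint; the `duplicateRoot` helper that follows does this with `Operation::clone` and `replaceUsesWithIf`. A Python toy model of the rewiring (hypothetical op names, not MLIR):

```python
def duplicate_shared_roots(chains):
    """chains: list of (root, end) def-use chain endpoints. When several
    chains share the same root op, the first chain keeps the original root
    and every later one gets a fresh clone, so the chains no longer overlap
    at the root (a toy version of duplicateRoot below)."""
    seen = {}
    rewired = []
    for root, end in chains:
        copies = seen.get(root, 0)
        seen[root] = copies + 1
        new_root = root if copies == 0 else f"{root}.clone{copies}"
        rewired.append((new_root, end))
    return rewired


chains = [("make_tensor_ptr", "trans1"), ("make_tensor_ptr", "trans2")]
print(duplicate_shared_roots(chains))
# -> [('make_tensor_ptr', 'trans1'), ('make_tensor_ptr.clone1', 'trans2')]
```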
+  // Duplicate the root operation of \p sameRootChains and update \p chains.
+  void duplicateRoot(Chains &sameRootChains, Chains &chains) const {
+    assert(sameRootChains.size() > 1 && "expecting at least 2 chains");
+    assert(llvm::all_of(sameRootChains, [&](const Chain &chain) {
+      const Chain &firstChain = *sameRootChains.begin();
+      return firstChain.getStart() == chain.getStart();
+    }));
+
+    for (auto it = ++sameRootChains.begin(); it != sameRootChains.end(); ++it) {
+      const Chain &chain = *it;
+      Operation *start = chain.getStart();
+      OpBuilder builder(start);
+      Operation *duplicate = builder.insert(start->clone());
+      assert(start->getNumResults() == 1);
+
+      Value res = start->getResult(0);
+      Value dupRes = duplicate->getResult(0);
+      res.replaceUsesWithIf(dupRes, [&](OpOperand &operand) {
+        return Chain::isTransitivelyUsedBy(operand.getOwner(), chain.getEnd());
+      });
+
+      // Remove the chain and insert a new one, rooted at the new operation.
+      Chain newChain(duplicate, chain.getEnd());
+      chains.insert(newChain);
+      chains.erase(chain);
+    }
+  }
+
+  void fuse(const Chains &chains) {
+    for (const Chain &chain : chains)
+      fuseTransOpInChain(chain);
+  }
+
+  void fuseTransOpInChain(const Chain &chain) {
+    assert(
+        isa<tt::MakeTensorPtrOp>(chain.getStart()) &&
+        "Expecting 'chain' to be rooted by a 'tt.make_tensor_ptr' operation");
+    assert(isa<tt::TransOp>(chain.getEnd()) &&
+           "Expecting 'chain' to be terminated by a 'tt.trans' operation");
+
+    auto makeTensorPtrOp = cast<tt::MakeTensorPtrOp>(chain.getStart());
+    auto transOp = cast<tt::TransOp>(chain.getEnd());
     auto loadOp = cast<tt::LoadOp>(transOp.getSrc().getDefiningOp());
-    tt::MakeTensorPtrOp makeTensorPtrOp =
-        *triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
 
     LLVM_DEBUG(llvm::dbgs()
-               << "makeTensorPtrOp:\n\t" << makeTensorPtrOp << "\n");
+               << "Fusing:\n  " << transOp << "\nwith:\n  " << loadOp << "\n");
 
     // Create a MakeTensorPtrOp yielding a block pointer to the transposed
     // tensor...
@@ -89,11 +326,11 @@
     Value ptr = builder.create<tt::MakeTensorPtrOp>(
         makeTensorPtrOp.getLoc(), newPtrType, makeTensorPtrOp.getBase(),
         newShape, newStrides, newOffsets, makeTensorPtrOp.getOrderAttr());
-    assert(makeTensorPtrOp->hasOneUse() && "Expecting single user");
-    LLVM_DEBUG(llvm::dbgs() << "newMakeTensorPtrOp:\n\t" << ptr << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "newMakeTensorPtrOp:\n  " << ptr << "\n");
 
     // ... and propagate it through the def-use chain.
-    propagateToUsers(ptr, makeTensorPtrOp, makeTensorPtrOp, transOp);
+    propagateToUsers(ptr, chain);
+    cleanUp.insert(makeTensorPtrOp);
   }
 
 private:
@@ -102,15 +339,11 @@
   // Where:
   //  - the transpose result is used by the dot operation, and
   //  - the transpose operation uses the result of a 2-dim load operation on a
-  //    block pointer (transitively) defined by a `make_tensor_ptr` in the same
-  //    function, and
-  //  - each operation in the def-use chain origination at the `make_tensor_ptr`
-  //    and terminating at the load has a single user.
+  //    block pointer (transitively) defined by a `make_tensor_ptr` operation.
   bool isCandidate(tt::TransOp transOp) const {
     assert(transOp && "Expecting a valid transpose operation");
 
-    // Check whether \p transOp is used by a `dotOp` directly or indirectly
-    // (each operation in the def-use chain need to have a single user).
+    // Check whether \p transOp is used by a `dotOp` (directly or indirectly).
     auto usedByDotOp = [](tt::TransOp transOp) {
       if (!transOp->hasOneUse())
         return false;
@@ -136,12 +369,7 @@
     if (!defOp || !isa<tt::LoadOp>(defOp))
       return false;
 
-    return isCandidate(cast<tt::LoadOp>(defOp));
-  }
-
-  bool isCandidate(tt::LoadOp loadOp) const {
-    assert(loadOp && "Expecting a valid load operation");
-
+    auto loadOp = cast<tt::LoadOp>(defOp);
     bool loadOpHasBlockIOAttr = loadOp->hasAttrOfType<StringAttr>(
         ttgi::TritonIntelGPUDialect::getBlockIOAttrName());
     if (!loadOp->hasOneUse() || !loadOpHasBlockIOAttr)
       return false;
@@ -152,19 +380,30 @@
         cast<RankedTensorType>(ptrType.getPointeeType()).getRank() != 2)
       return false;
 
-    std::optional<tt::MakeTensorPtrOp> defOp =
+    std::optional<tt::MakeTensorPtrOp> makeTensorPtrOp =
         triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
-    if (!defOp || !singleUsersInChain(*defOp, loadOp))
-      return false;
 
-    return true;
+    return makeTensorPtrOp.has_value();
+  }
+
+  // Each operation in the def-use chain must have a single user, except in
+  // special circumstances. Prune chains that do not satisfy this condition.
+  void prune(Chains &chains) const {
+    assert(!chains.empty() && "Expecting at least one candidate chain");
+    for (auto it = chains.begin(); it != chains.end();) {
+      if (!validateChain(*it))
+        it = chains.erase(it);
+      else
+        ++it;
+    }
   }
 
-  // Determine whether all operations in the def-use chain from \p start to
-  // \p end have a single user.
+  // Determine whether all operations in the given def-use chain have a single
+  // user.
   // Note: we allow an operation in the def-use chain to have an additional user
-  // if the operation is in a for loop, and the additional user is the yield
-  // operation, provided that the result yielded is not used after the loop.
+  // if the operation is in a for loop, and the additional user is the loop
+  // yield operation, provided that the result yielded is not used after the
+  // loop.
   // Example:
   //   make_tensor_ptr -> advance -> load (OK)
   //   make_tensor_ptr -> for init_arg -> advance -> load (OK)
   //   make_tensor_ptr -> for init_arg -> advance -> load (OK)
   //                                   -> yield -> load (NOT OK)
   //
-  bool singleUsersInChain(Operation *start, Operation *end) const {
-    assert(start && end && "Expecting valid operations");
-    Operation *currentOp = start;
-
-    auto validate = [](Operation *op, Operation *&nextOp) {
+  bool validateChain(const Chain &chain) const {
+    auto validateOperation = [](Operation *op, Operation *&nextOp) {
       assert(nextOp == nullptr);
-
       if (op->hasOneUse())
         return true;
 
       if (!op->getParentOfType<scf::ForOp>())
         return false;
@@ -214,9 +449,10 @@
       return true;
     };
 
-    while (currentOp != end) {
+    Operation *currentOp = chain.getStart();
+    while (currentOp != chain.getEnd()) {
       Operation *user = nullptr;
-      if (!validate(currentOp, user)) {
+      if (!validateOperation(currentOp, user)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Fails safety checks: " << *currentOp << "\n");
        return false;
@@ -228,6 +464,7 @@
         continue;
       }
 
+      // Current limitation: give up if the use is a branch.
       if (isa<BranchOpInterface>(user))
         return false;
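
The safety rule `validateChain` enforces can be stated as: every operation between the chain's root and its transpose must have exactly one user, except that an operation inside an `scf.for` may additionally feed the loop's yield, provided the yielded result is dead after the loop. A Python sketch of that predicate (toy encoding, hypothetical keys):

```python
def chain_is_fusable(ops):
    """ops: one dict per operation between the chain's root and its tt.trans.
    Mirrors validateChain's rule: exactly one user, or two users when the
    extra user is the enclosing scf.for's yield and the corresponding loop
    result is not used after the loop."""
    for op in ops:
        if op["num_users"] == 1:
            continue
        if (op["num_users"] == 2 and op.get("extra_user_is_yield")
                and not op.get("yield_result_used_after_loop")):
            continue
        return False
    return True


# make_tensor_ptr -> for init_arg -> advance -> load, advance also yielded:
ok = [{"num_users": 1},
      {"num_users": 2, "extra_user_is_yield": True,
       "yield_result_used_after_loop": False}]
assert chain_is_fusable(ok)

# Same chain, but the yielded pointer is loaded after the loop:
bad = [{"num_users": 1},
       {"num_users": 2, "extra_user_is_yield": True,
        "yield_result_used_after_loop": True}]
assert not chain_is_fusable(bad)
```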
@@ -253,6 +490,19 @@
     return true;
   }
 
+  // Propagate \p newVal to operations in the given def-use chain.
+  void propagateToUsers(Value newVal, const Chain &chain) {
+    auto start = cast<tt::MakeTensorPtrOp>(chain.getStart());
+    Operation *end = chain.getEnd();
+    auto it = llvm::find_if(start->getUsers(), [&](Operation *user) {
+      return Chain::isTransitivelyUsedBy(user, end);
+    });
+    assert(it != start->getUsers().end() && "Expecting valid iterator");
+
+    Operation *nextOp = *it;
+    propagateToUser(newVal, start.getResult(), nextOp, end);
+  }
+
   // Propagate \p newVal to users of \p origOp.
   void propagateToUsers(Value newVal, Value origVal, Operation *origOp,
                         Operation *sentinel) {
@@ -271,14 +521,14 @@
 
     LLVM_DEBUG({
       llvm::dbgs() << "In " << __func__ << "\n";
-      llvm::dbgs() << "user of ";
+      llvm::dbgs() << "user of:";
       if (origVal.getDefiningOp()) {
-        llvm::dbgs() << "\n\t" << *origVal.getDefiningOp() << "\n";
+        llvm::dbgs() << "\n  " << *origVal.getDefiningOp() << "\n";
       } else {
         origVal.printAsOperand(llvm::dbgs(), {});
         llvm::dbgs() << " ";
       }
-      llvm::dbgs() << "is:\n\t";
+      llvm::dbgs() << "is:\n  ";
       user->dumpPretty();
     });
 
@@ -295,7 +545,8 @@
       SmallVector<Value> newOffsets(llvm::reverse(advanceOp.getOffsets()));
       auto newAdvanceOp = rewriter.create<tt::AdvanceOp>(
           loc, newVal.getType(), newVal, newOffsets);
-      LLVM_DEBUG(llvm::dbgs() << "\tnewAdvanceOp: " << newAdvanceOp << "\n");
+      LLVM_DEBUG(llvm::dbgs().indent(2)
+                 << "newAdvanceOp: " << newAdvanceOp << "\n");
       cleanUp.insert(advanceOp);
       return propagateToUsers(newAdvanceOp, advanceOp.getResult(), advanceOp,
                               sentinel);
@@ -320,7 +571,7 @@
       assert(newAttr && "Expecting a valid blockIO attribute");
       newLoadOp->setAttr(blockIOAttrName, newAttr);
 
-      LLVM_DEBUG(llvm::dbgs() << "\tnewLoadOp: " << newLoadOp << "\n");
+      LLVM_DEBUG(llvm::dbgs().indent(2) << "newLoadOp: " << newLoadOp << "\n");
       cleanUp.insert(loadOp);
       return propagateToUsers(newLoadOp, loadOp.getResult(), loadOp, sentinel);
     }