44// TODO: refactor so that it doesn't fail if Allocation.h
55// is included after Utility.h (due to a conflict between the `store` macro
66// and <atomic>).
7- #include " triton/Analysis/Allocation.h"
7+ #include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
8+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
9+ #include " mlir/Transforms/DialectConversion.h"
810
911//
1012#include " Utility.h"
1113#include " mlir/IR/TypeUtilities.h"
12-
13- #include " intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
14- #include " triton/Analysis/AxisInfo.h"
15- #include < set>
1614#include < type_traits>
1715
1816#define DEBUG_TYPE " ttgpu_to_llvm"
@@ -26,11 +24,95 @@ using ::mlir::triton::gpu::BlockedEncodingAttr;
2624using ::mlir::triton::gpu::CTALayoutAttr;
2725using ::mlir::triton::gpu::DotOperandEncodingAttr;
2826using ::mlir::triton::gpu::SliceEncodingAttr;
29- using ::mlir::triton::gpu::intel::DpasEncodingAttr;
3027
3128namespace mlir ::triton {
3229class ReduceOp ;
3330class ScanOp ;
31+
32+ inline SmallVector<Value>
33+ inlineCombineBlock (ConversionPatternRewriter &rewriter, Block &combineBlock,
34+ Block *insertionBlock, Block::iterator insertionPoint,
35+ ValueRange combineArgs) {
36+ auto returnOp = combineBlock.getTerminator ();
37+ rewriter.inlineBlockBefore (&combineBlock, insertionBlock, insertionPoint,
38+ combineArgs);
39+
40+ auto results = SmallVector<Value>(returnOp->getOperands ());
41+
42+ // Delete the terminator, which is no longer used
43+ rewriter.eraseOp (returnOp);
44+ return results;
45+ }
46+
47+ inline SmallVector<Value> applyCombineOp (Location loc,
48+ ConversionPatternRewriter &rewriter,
49+ Region &combineOp, ValueRange acc,
50+ ValueRange cur, Value pred = {}) {
51+ // Allows for passing an unitialized acc and use cur as the neutral element
52+ if (acc.size () == 0 ) {
53+ return cur;
54+ }
55+ assert (cur.size () == acc.size ());
56+
57+ // Create a new copy of the combine block, and try to speculatively inline it
58+ Block *currentBlock = rewriter.getBlock ();
59+ Region &parent = *currentBlock->getParent ();
60+
61+ rewriter.cloneRegionBefore (combineOp, parent,
62+ std::next (currentBlock->getIterator ()));
63+ Block &newCombine = *currentBlock->getNextNode ();
64+
65+ llvm::SmallVector<Value> combineArgs (2 * acc.size ());
66+ for (unsigned i = 0 ; i < acc.size (); ++i) {
67+ combineArgs[i] = acc[i];
68+ combineArgs[acc.size () + i] = cur[i];
69+ }
70+
71+ auto isRegionSpeculatable =
72+ std::all_of (newCombine.begin (), newCombine.end (),
73+ [](auto &op) { return isSpeculatable (&op); });
74+
75+ if (!pred || isRegionSpeculatable) {
76+ // Fast path, region has no side effects so we can unconditionally execute
77+ return inlineCombineBlock (rewriter, newCombine, currentBlock,
78+ rewriter.getInsertionPoint (), combineArgs);
79+ }
80+
81+ // Slow case, create an if to only execute region when pred is true
82+ // #currentBlock
83+ // if (pred) {
84+ // #newCombine
85+ // results = combineOp(cur, acc)
86+ // yield results
87+ // } else {
88+ // yield undef
89+ // }
90+ // #thenBlock
91+ Block *thenBlock =
92+ rewriter.splitBlock (currentBlock, rewriter.getInsertionPoint ());
93+
94+ auto returnOp = newCombine.getTerminator ();
95+ auto results = SmallVector<Value>(returnOp->getOperands ());
96+
97+ rewriter.setInsertionPointToEnd (currentBlock);
98+ SmallVector<Value> thenBlockArgs;
99+ thenBlockArgs.reserve (results.size ());
100+ for (auto result : results) {
101+ auto ty = result.getType ();
102+ auto undef = rewriter.create <LLVM::UndefOp>(loc, ty);
103+ thenBlockArgs.push_back (undef);
104+ thenBlock->addArgument (ty, loc);
105+ }
106+ rewriter.create <cf::CondBranchOp>(loc, pred, &newCombine, combineArgs,
107+ thenBlock, thenBlockArgs);
108+
109+ // Split a block after the call.
110+ rewriter.setInsertionPointToEnd (&newCombine);
111+ rewriter.replaceOpWithNewOp <cf::BranchOp>(returnOp, thenBlock, results);
112+ rewriter.setInsertionPointToStart (thenBlock);
113+ return SmallVector<Value>(thenBlock->getArguments ());
114+ }
115+
34116} // namespace mlir::triton
35117
36118template <typename SourceOp>
@@ -66,12 +148,13 @@ class ConvertTritonIntelGPUReduceScanToLLVMPattern
66148 });
67149 // Assign base index to each operand in their order in indices
68150 std::map<unsigned , Value> indexToBase;
69- indexToBase[indices[0 ]] = LLVM::intel::getSharedMemoryBase (
70- loc, rewriter, targetInfo, op.getOperation ());
151+ auto basePtr = LLVM::intel::getSharedMemoryBase (loc, rewriter, targetInfo,
152+ op.getOperation ());
153+ indexToBase[indices[0 ]] = basePtr;
71154 for (unsigned i = 1 ; i < op.getNumOperands (); ++i) {
72- indexToBase[indices[i]] = gep (
73- ptr_ty (rewriter. getContext (), 3 ), getElementType (op, indices[i - 1 ]),
74- indexToBase[indices[i - 1 ]], i32_val (elems));
155+ indexToBase[indices[i]] =
156+ gep (basePtr. getType ( ), getElementType (op, indices[i - 1 ]),
157+ indexToBase[indices[i - 1 ]], i32_val (elems));
75158 }
76159 // smemBases[k] is the base pointer for the k-th operand
77160 SmallVector<Value> smemBases (op.getNumOperands ());
0 commit comments