|
1 | 1 | #ifndef TRITON_TRITONGPU_TRANSFORMS_MMAV5PIPELINEUTILITY_H_ |
2 | 2 | #define TRITON_TRITONGPU_TRANSFORMS_MMAV5PIPELINEUTILITY_H_ |
3 | 3 |
|
4 | | -#include <functional> |
5 | | -#include <optional> |
6 | | -#include <tuple> |
| 4 | +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" |
7 | 5 |
|
8 | 6 | namespace mlir { |
| 7 | + |
9 | 8 | class OpBuilder; |
10 | | -class Operation; |
| 9 | +class DominanceInfo; |
11 | 10 |
|
12 | 11 | namespace scf { |
13 | 12 | class ForOp; |
14 | | -} |
| 13 | +} // namespace scf |
15 | 14 | namespace triton::nvidia_gpu { |
16 | | -class MMAv5OpInterface; |
17 | | -class TMEMAllocOp; |
18 | | -class TMEMLoadOp; |
| 15 | + |
| 16 | +//===----------------------------------------------------------------------===// |
| 17 | +// MMAInfo |
| 18 | +//===----------------------------------------------------------------------===// |
| 19 | + |
| 20 | +// This struct contains analysis information about an MMAv5 operation inside a |
| 21 | +// loop used for pipelining MMA ops. |
| 22 | +struct MMAInfo { |
| 23 | + // This struct contains information about when the MMA's accumulator is |
| 24 | + // overridden in the loop, if it is at all. |
| 25 | + struct AccOverridePoint { |
| 26 | + // The operation which overrides the accumulator. |
| 27 | + Operation *op; |
| 28 | + // The condition on which the accumulator is reset. |
| 29 | + Value condition = nullptr; |
| 30 | + // The initial value of the accumulator and the value after a reset. |
| 31 | + Value initValue = nullptr; |
| 32 | + // The number of loop iterations ago the accumulator was reset. |
| 33 | + int distance = 0; |
| 34 | + // Whether the accumulator is reset via setting the `useAcc` flag to false |
| 35 | + // or by clearing the accumulator tensor value. |
| 36 | + bool isFlag = false; |
| 37 | + }; |
| 38 | + |
| 39 | + // The TMEM allocation of the accumuator, which directly precedes the dot op. |
| 40 | + TMEMAllocOp accAlloc; |
| 41 | + // The TMEM load of the accumulator value out of TMEM, which directly follows |
| 42 | + // the dot op. |
| 43 | + TMEMLoadOp accLoad; |
| 44 | + // The override point of the accumulator value, if it is overriden in the |
| 45 | + // loop. E.g. this is typically present for persistent kernels. |
| 46 | + std::optional<AccOverridePoint> accDef; |
| 47 | + // If the accumulator is used in future iterations of the loop, this is the |
| 48 | + // iter arg number. |
| 49 | + std::optional<int> yieldArgNo; |
| 50 | + // Whether the accumulator needs to be multibuffered. |
| 51 | + bool accIsMultiBuffered; |
| 52 | + |
| 53 | + Value phase = nullptr; |
| 54 | + Value barrierIdx = nullptr; |
| 55 | + Value accInsertIdx = nullptr; |
| 56 | + Value accExtractIdx = nullptr; |
| 57 | + Value barrierAlloc = nullptr; |
| 58 | +}; |
| 59 | + |
| 60 | +//===----------------------------------------------------------------------===// |
| 61 | +// MMA Pipeline Analysis |
| 62 | +//===----------------------------------------------------------------------===// |
19 | 63 |
|
20 | 64 | // Returns the TMEMAllocOp and TMEMLoadOp that are used to allocate and load the |
21 | 65 | // accumulator for the given MMA operation. The TMEMAllocOp and TMEMLoadOp must |
22 | 66 | // be in the same region as the MMA operation. |
23 | 67 | std::optional<std::pair<TMEMAllocOp, TMEMLoadOp>> |
24 | 68 | getTMemAllocAndLoad(MMAv5OpInterface mmaOp); |
| 69 | +// Get immediate users of the accumulator within the current loop iteration. |
| 70 | +SmallVector<Operation *> getDirectAccUses(TMEMLoadOp accDef); |
| 71 | +// Analyze an MMA op inside a loop to determine information about how it can be |
| 72 | +// pipelined. Returns `std::nullopt` if it cannot be pipelined. |
| 73 | +std::optional<MMAInfo> getMMAInfo(scf::ForOp forOp, MMAv5OpInterface mmaOp, |
| 74 | + DominanceInfo &domInfo); |
| 75 | + |
| 76 | +//===----------------------------------------------------------------------===// |
| 77 | +// MMA Pipeline Rewriters |
| 78 | +//===----------------------------------------------------------------------===// |
| 79 | + |
25 | 80 | // Create a new TMEMAllocOp to use for the pipelined MMA operation. It is |
26 | 81 | // optionally multi-buffered based on the number of stages. |
27 | 82 | TMEMAllocOp createTMemAlloc(OpBuilder &builder, TMEMAllocOp oldTMemAllocOp, |
28 | 83 | bool multiBufferred, int numStages); |
29 | 84 |
|
| 85 | +// Create a store op of the initial value of the accumulator into the |
| 86 | +// potentially multi-buffered accumulator. |
| 87 | +void createInitStore(OpBuilder &builder, TMEMAllocOp allocOp, Value initVal, |
| 88 | + bool multiBufferred); |
| 89 | + |
30 | 90 | // Return true if operands of the MMA operation are/are going to be pipelined |
31 | 91 | // and multibuffered, enabling the MMA operation to be pipelined. |
32 | 92 | bool mmaHasPipelineableOperands( |
33 | 93 | MMAv5OpInterface mma, scf::ForOp forOp, |
34 | 94 | std::function<bool(Operation *)> isLoadPipelineable); |
35 | 95 |
|
36 | | -// Return true if the loop has a read-modify-write access to the accumulator. |
| 96 | +// Return true if the accumulator of an mma in subsequent iterations is either |
| 97 | +// independent from the previous iteration (overwritten) or completely reused, |
| 98 | +// without read-modify-write. |
| 99 | +// Otherwise, we can not pipeline the MMA, as we need to insert a wait after the |
| 100 | +// mma to read back the accumulator for RMW. |
37 | 101 | bool hasAccReadModifyWrite(MMAv5OpInterface mma, scf::ForOp forOp); |
| 102 | + |
38 | 103 | } // namespace triton::nvidia_gpu |
39 | 104 | } // namespace mlir |
40 | 105 |
|
|
0 commit comments