@@ -13,50 +13,6 @@ class ForOp;
1313} // namespace scf
1414namespace triton ::nvidia_gpu {
1515
16- // ===----------------------------------------------------------------------===//
17- // MMAInfo
18- // ===----------------------------------------------------------------------===//
19-
20- // This struct contains analysis information about an MMAv5 operation inside a
21- // loop used for pipelining MMA ops.
22- struct MMAInfo {
23- // This struct contains information about when the MMA's accumulator is
24- // overridden in the loop, if it is at all.
25- struct AccOverridePoint {
26- // The operation which overrides the accumulator.
27- Operation *op;
28- // The condition on which the accumulator is reset.
29- Value condition = nullptr ;
30- // The initial value of the accumulator and the value after a reset.
31- Value initValue = nullptr ;
32- // The number of loop iterations ago the accumulator was reset.
33- int distance = 0 ;
34- // Whether the accumulator is reset via setting the `useAcc` flag to false
35- // or by clearing the accumulator tensor value.
36- bool isFlag = false ;
37- };
38-
39- // The TMEM allocation of the accumulator, which directly precedes the dot op.
40- TMEMAllocOp accAlloc;
41- // The TMEM load of the accumulator value out of TMEM, which directly follows
42- // the dot op.
43- TMEMLoadOp accLoad;
44- // The override point of the accumulator value, if it is overridden in the
45- // loop. E.g. this is typically present for persistent kernels.
46- std::optional<AccOverridePoint> accDef;
47- // If the accumulator is used in future iterations of the loop, this is the
48- // iter arg number.
49- std::optional<int > yieldArgNo;
50- // Whether the accumulator needs to be multibuffered.
51- bool accIsMultiBuffered;
52-
53- Value phase = nullptr ;
54- Value barrierIdx = nullptr ;
55- Value accInsertIdx = nullptr ;
56- Value accExtractIdx = nullptr ;
57- Value barrierAlloc = nullptr ;
58- };
59-
6016// ===----------------------------------------------------------------------===//
6117// MMA Pipeline Analysis
6218// ===----------------------------------------------------------------------===//
@@ -66,12 +22,14 @@ struct MMAInfo {
6622// be in the same region as the MMA operation.
6723std::optional<std::pair<TMEMAllocOp, TMEMLoadOp>>
6824getTMemAllocAndLoad (MMAv5OpInterface mmaOp);
69- // Get immediate users of the accumulator within the current loop iteration.
70- SmallVector<Operation *> getDirectAccUses (TMEMLoadOp accDef);
71- // Analyze an MMA op inside a loop to determine information about how it can be
72- // pipelined. Returns `std::nullopt` if it cannot be pipelined.
73- std::optional<MMAInfo> getMMAInfo (scf::ForOp forOp, MMAv5OpInterface mmaOp,
74- DominanceInfo &domInfo);
25+ // Given an MMAv5 operation in a loop, determine if its accumulator can be
26+ // multibuffered.
27+ bool isAccMultibufferingPossible (MMAv5OpInterface mma, scf::ForOp forOp);
28+ // Only pipeline the loops where the MMA happens before the tmem_load, or is in
29+ // the same stage as the tmem_load. Lowering does not support the case where the
31+ // MMA is in a different stage than the tmem_load and happens after it.
31+ bool mmav5DominatesTmemLoads (
32+ scf::ForOp forOp, function_ref<bool (MMAv5OpInterface)> isMmaPipelineable);
7533
7634// ===----------------------------------------------------------------------===//
7735// MMA Pipeline Rewriters
@@ -82,11 +40,6 @@ std::optional<MMAInfo> getMMAInfo(scf::ForOp forOp, MMAv5OpInterface mmaOp,
8240TMEMAllocOp createTMemAlloc (OpBuilder &builder, TMEMAllocOp oldTMemAllocOp,
8341 bool multiBufferred, int numStages);
8442
85- // Create a store op of the initial value of the accumulator into the
86- // potentially multi-buffered accumulator.
87- void createInitStore (OpBuilder &builder, TMEMAllocOp allocOp, Value initVal,
88- bool multiBufferred);
89-
9043// Return true if operands of the MMA operation are/are going to be pipelined
9144// and multibuffered, enabling the MMA operation to be pipelined.
9245bool mmaHasPipelineableOperands (
0 commit comments