@@ -20,11 +20,37 @@ namespace triton::nvidia_gpu {
2020// Given an MMAv5 operation in a loop, determine if its accumulator can be
2121// multibuffered.
2222bool isAccMultibufferingPossible (MMAv5OpInterface mma, scf::ForOp forOp);
23- // Only pipeline the loops where the MMA happens before the tmem_load, or is in
24- // the same stage as the tmem_load. Lowering does not support the case where the
25- // MMA is in a different stage as the tmem_load and happens after it.
26- bool mmav5DominatesTmemLoads (
27- scf::ForOp forOp, function_ref<bool (MMAv5OpInterface)> isMmaPipelineable);
23+
24+ // Returns true if the MMA operation requires acc multi-buffering when
25+ // pipelined.
26+ bool requiresAccMultiBuffering (MMAv5OpInterface mma, scf::ForOp forOp);
27+
28+ // Returns true if there are loads from tmem after the MMA operation.
29+ bool hasLoadsAfterMMA (MMAv5OpInterface mma, scf::ForOp forOp);
30+
31+ // Helper class that determines whether the operands of an MMA operation are
32+ // pipelineable. The analysis is started from the constructor via run().
33+ class MMAv5PipelineableOperandsHelper {
34+ public:
35+ MMAv5PipelineableOperandsHelper (
36+ MMAv5OpInterface mmaOp, scf::ForOp forOp,
37+ std::function<bool (Operation *)> isLoadToBePipelined)
38+ : mmaOp(mmaOp), forOp(forOp), isLoadToBePipelined(isLoadToBePipelined) {
39+ run (); // Invoked once at construction; run() is only declared here, its definition is not visible.
40+ }
41+ bool isPipelineable = false ; // Starts as false. NOTE(review): presumably updated by run() — confirm.
42+ // If true, all existing operand loads have been found and their
43+ // pipelineability has been determined.
44+ bool isOperandsStateDetermined = false ;
45+ SmallVector<Operation *> unpipelineableOperandLoads; // NOTE(review): presumably the operand loads judged not pipelineable — confirm against run().
46+
47+ private:
48+ MMAv5OpInterface mmaOp;
49+ scf::ForOp forOp;
50+ std::function<bool (Operation *)> isLoadToBePipelined; // Caller-supplied predicate stored by the constructor.
51+ bool comesFromLoadOrOutsideLoop (Value v, Operation *&foundLoad); // foundLoad is a pointer passed by reference, so it can act as an output.
52+ void run ();
53+ };
2854
2955// ===----------------------------------------------------------------------===//
3056// MMA Pipeline Rewriters
@@ -35,12 +61,6 @@ bool mmav5DominatesTmemLoads(
3561TMEMAllocOp createTMemAlloc (OpBuilder &builder, TMEMAllocOp oldTMemAllocOp,
3662 bool multiBufferred, int numStages);
3763
38- // Return true if operands of the MMA operation are/are going to be pipelined
39- // and multibuffered, enabling the MMA operation to be pipelined.
40- bool mmaHasPipelineableOperands (
41- MMAv5OpInterface mma, scf::ForOp forOp,
42- std::function<bool (Operation *)> isLoadPipelineable);
43-
4464// Return true if the accumulator of an mma in subsequent iterations is either
4565// independent from the previous iteration (overwritten) or completely reused,
4666// without read-modify-write.
0 commit comments