@@ -122,10 +122,10 @@ class StreamPipeliner {
122122public:
123123 StreamPipeliner (scf::ForOp _forOp, int _numStages, int _globalPrefetch,
124124 int _localPrefetch, bool _useAsyncCopy,
125- bool _useF16BlockPingpong)
125+ bool _useF16BlockPingpong, bool _useAsyncCopyOverlap )
126126 : forOp(_forOp), numStages(_numStages), numBuffers(1 ),
127127 useAsyncCopy (_useAsyncCopy), useF16BlockPingpong(_useF16BlockPingpong),
128- schedule(numStages),
128+ useAsyncCopyOverlap(_useAsyncCopyOverlap), schedule(numStages),
129129 axisInfoAnalysis(forOp->getParentOfType<ModuleOp>()) {
130130 int lastStage = numStages - 1 ;
131131 stages[SCHED_GLOBAL_LOAD] = 0 ;
@@ -181,6 +181,9 @@ class StreamPipeliner {
181181 // Whether or not we intend to ping-pong.
182182 bool useF16BlockPingpong;
183183
184+ // Whether to move AsyncCopy before AsyncWait.
185+ bool useAsyncCopyOverlap;
186+
184187 // Stage for each SchedType Op
185188 int stages[SCHED_SIZE];
186189 // Cluster for each SchedType Op
@@ -297,6 +300,14 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
297300 computeCluster = localLoadCluster;
298301 }
299302
303+ if (useAsyncCopyOverlap) {
304+ globalLoadCluster = 0 ;
305+ localStoreCluster = 1 ;
306+ asyncWaitCluster = 2 ;
307+ localLoadCluster = 3 ;
308+ computeCluster = 3 ;
309+ }
310+
300311 // Make assignments
301312 std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> clusterVec;
302313 std::generate (clusterVec.begin (), clusterVec.end (),
@@ -1072,6 +1083,9 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
10721083 // between MXFP4 and FP16.
10731084 bool useF16BlockPingpong =
10741085 triton::tools::getBoolEnv (" TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG" );
1086+ bool useAsyncCopyOverlap =
1087+ triton::tools::getBoolEnv (" TRITON_HIP_ASYNC_COPY_OVERLAP" ) &
1088+ useAsyncCopy;
10751089 SmallVector<scf::ForOp> loops;
10761090 getOperation ()->walk ([&](scf::ForOp forOp) {
10771091 labelLoadOpsForTritonDot (forOp);
@@ -1092,7 +1106,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
10921106 } else {
10931107 StreamPipeliner sp (forOp, tt::getNumStagesOrDefault (forOp, numStages),
10941108 globalPrefetch, localPrefetch, useAsyncCopy,
1095- useF16BlockPingpong);
1109+ useF16BlockPingpong, useAsyncCopyOverlap );
10961110 (void )sp.pipelineLoop ();
10971111 }
10981112 }
0 commit comments