@@ -64,11 +64,7 @@ class OpBuilderWithStage : public OpBuilder {
6464 OpTy createWithStage (Location location, int stage, int cluster,
6565 Args &&...args) {
6666 OpTy op = OpBuilder::create<OpTy>(location, std::forward<Args>(args)...);
67- auto ctx = getContext ();
68- op->setAttr (mlir::triton::kLoopStageAttrName ,
69- IntegerAttr::get (IntegerType::get (ctx, 32 ), stage));
70- op->setAttr (mlir::triton::kLoopClusterAttrName ,
71- IntegerAttr::get (IntegerType::get (ctx, 32 ), cluster));
67+ tt::setStageCluster (op, stage, cluster);
7268 return op;
7369 }
7470 using OpBuilder::create;
@@ -204,9 +200,8 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
204200 // Prefetch load if is not MMAV3 and is used by the dot.
205201 if (loadToInfo[loadOp].usedByDot ) {
206202 assert (stageForFirstUse >= 1 );
207- tt::setStageCluster (forOp, wait, stageForFirstUse - 1 , maxClusterId + 1 );
208- tt::setStageCluster (forOp, viewLoad, stageForFirstUse - 1 ,
209- maxClusterId + 1 );
203+ tt::setStageCluster (wait, stageForFirstUse - 1 , maxClusterId + 1 );
204+ tt::setStageCluster (viewLoad, stageForFirstUse - 1 , maxClusterId + 1 );
210205 retCode = stageForFirstUse - 1 ;
211206 }
212207 }
0 commit comments