@@ -216,11 +216,11 @@ static void createTMAAsyncCopy(
216216// If all the transitive uses of the given value have are used by a convert to
217217// the same dot operand encoding, return the shared encoding that needs to be
218218// used to be compatible with users' layouts. If there are imcompatible shared
219- // encodings set `incompatible` to true.
219+ // encodings, raise assertion, since incompatible shared encoding has been
220+ // handled in splitLoadsForIncompatible.
220221static std::optional<ttg::SharedEncodingAttr>
221- getSharedEncIfAllUsersAreDotEnc (Value val, bool &incompatible ) {
222+ getSharedEncIfAllUsersAreDotEnc (Value val) {
222223 ttg::SharedEncodingAttr attr;
223- incompatible = false ;
224224 for (Operation *user : val.getUsers ()) {
225225 ttg::SharedEncodingAttr tempAttr;
226226 if (user->getNumResults () != 1 )
@@ -230,8 +230,7 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
230230 // First time we find a shared encoding in the chain, save it and try to
231231 // use it if it is compatible with the other users.
232232 tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding ());
233- if (!getSharedEncIfAllUsersAreDotEnc (user->getResult (0 ), incompatible)
234- .has_value ())
233+ if (!getSharedEncIfAllUsersAreDotEnc (user->getResult (0 )).has_value ())
235234 return std::nullopt ;
236235 } else {
237236 if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
@@ -249,10 +248,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
249248 bitWidth, /* needTrans=*/ false );
250249 }
251250 // Check that the shared encodings needed by the users are compatible.
252- if (attr != nullptr && attr != tempAttr) {
253- incompatible = true ;
254- return std::nullopt ;
255- }
251+ if (attr != nullptr )
252+ assert (attr == tempAttr && " incompatible shared encoding" );
256253 attr = tempAttr;
257254 }
258255 return attr;
@@ -442,13 +439,9 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
442439 loadInfo.sharedEncoding =
443440 getSharedEncoding (op, /* loadIsMMAv3=*/ true ).value_or (nullptr );
444441 } else if (auto dot = dyn_cast<tt::DotOp>(use)) {
445- bool incompatible = false ;
446442 loadInfo.sharedEncoding =
447- getSharedEncIfAllUsersAreDotEnc (op->getResult (0 ), incompatible)
448- .value_or (nullptr );
449- // If we can't agree on a shared encoding skip pipelinig the load.
450- if (incompatible)
451- continue ;
443+ getSharedEncIfAllUsersAreDotEnc (op->getResult (0 )).value_or (nullptr );
444+
452445 // HACK: Triton LLVM codegen has a bug where local_loads from #shared to
453446 // #mma layout can lead to invalid code if the loaded shape is smaller
454447 // than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with
@@ -514,9 +507,87 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
514507 return loadToInfo;
515508}
516509
510+ // Split users to groups, each group has the same shared encoding.
511+ // If not all users are Dot encoding, return empty vector.
512+ static DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>>
513+ handleIncompatibleSharedEncoding (Operation *loadOp) {
514+ DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> loadGroups;
515+ // Go through transitive uses of the loadOp in the same block.
516+ for (Operation *user : loadOp->getUsers ()) {
517+ if (user->getBlock () != loadOp->getBlock ())
518+ continue ;
519+ if (user->getNumResults () != 1 )
520+ return loadGroups;
521+
522+ ttg::SharedEncodingAttr tempAttr;
523+ if (auto memDesc =
524+ dyn_cast<triton::MemDescType>(user->getResult (0 ).getType ())) {
525+ tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding ());
526+ loadGroups[tempAttr].push_back (user);
527+ } else {
528+ if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
529+ return loadGroups;
530+ auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
531+ cast<TensorOrMemDesc>(user->getResult (0 ).getType ()).getEncoding ());
532+ if (!dotOpEnc)
533+ return loadGroups;
534+ auto srcTy = cast<TensorOrMemDesc>(loadOp->getResult (0 ).getType ());
535+ auto CTALayout = ttg::getCTALayout (srcTy.getEncoding ());
536+ auto order = ttg::getOrder (srcTy.getEncoding ());
537+ unsigned bitWidth = srcTy.getElementType ().getIntOrFloatBitWidth ();
538+ tempAttr = ttg::SharedEncodingAttr::get (
539+ loadOp->getContext (), dotOpEnc, srcTy.getShape (),
540+ ttg::getOrder (srcTy.getEncoding ()),
541+ ttg::getCTALayout (srcTy.getEncoding ()),
542+ srcTy.getElementType ().getIntOrFloatBitWidth (), /* needTrans=*/ false );
543+ loadGroups[tempAttr].push_back (user);
544+ }
545+ }
546+ return loadGroups;
547+ }
548+
549+ // Clone loads so each group of uses with same shared encoding will have a
550+ // corresponding Load.
551+ static void splitLoadsForIncompatible (
552+ OpBuilder &builder, Operation *loadOp,
553+ DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> &lGroups) {
554+ // The first group will use the original load, create new loads for other
555+ // groups.
556+ unsigned idx = 0 ;
557+ builder.setInsertionPointAfter (loadOp);
558+ for (auto pair : lGroups) {
559+ SmallVector<Operation *> &group = pair.second ;
560+ if (idx++ == 0 )
561+ continue ;
562+ Operation *newLoad = builder.clone (*loadOp);
563+ for (auto *user : group) {
564+ user->replaceUsesOfWith (loadOp->getResult (0 ), newLoad->getResult (0 ));
565+ }
566+ }
567+ }
568+
569+ static void splitLoadsWithIncompatibleEncoding (scf::ForOp forOp) {
570+ // Get the list of all loads.
571+ SmallVector<Operation *> loads;
572+ for (Operation &op : forOp.getBody ()->without_terminator ()) {
573+ if (isa<tt::LoadOp, tt::ExperimentalDescriptorLoadOp>(op)) {
574+ loads.push_back (&op);
575+ }
576+ }
577+ OpBuilder builder (forOp);
578+ for (auto *loadOp : loads) {
579+ auto lGroups = handleIncompatibleSharedEncoding (loadOp);
580+ LDBG (" groups with different encoding: " << lGroups.size () << " "
581+ << *loadOp);
582+ if (lGroups.size () > 1 )
583+ splitLoadsForIncompatible (builder, loadOp, lGroups);
584+ }
585+ }
586+
517587static llvm::MapVector<Operation *, LoadInfo>
518588scheduleLoads (scf::ForOp forOp, tt::CoarseSchedule &schedule,
519589 DenseSet<Operation *> &rootUsers, int numStages) {
590+
520591 ModuleOp moduleOp = forOp->getParentOfType <ModuleOp>();
521592 tt::ModuleAxisInfoAnalysis axisInfoAnalysis (moduleOp);
522593
@@ -1054,6 +1125,8 @@ static void invalidateBarriers(OpBuilder &builder,
10541125
10551126bool mlir::triton::preProcessLoopAndGetSchedule (
10561127 scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) {
1128+ splitLoadsWithIncompatibleEncoding (forOp);
1129+
10571130 // Schedule the loads and root ops (dot ops) in the loop. This will give us
10581131 // a scaffold for the final schedule.
10591132 DenseSet<Operation *> rootUsers;
0 commit comments