@@ -219,8 +219,9 @@ static void createTMAAsyncCopy(
219219// encodings, raise assertion, since incompatible shared encoding has been
220220// handled in splitLoadsForIncompatible.
221221static std::optional<ttg::SharedEncodingAttr>
222- getSharedEncIfAllUsersAreDotEnc (Value val) {
222+ getSharedEncIfAllUsersAreDotEnc (Value val, bool &incompatible ) {
223223 ttg::SharedEncodingAttr attr;
224+ incompatible = false ;
224225 for (Operation *user : val.getUsers ()) {
225226 ttg::SharedEncodingAttr tempAttr;
226227 if (user->getNumResults () != 1 )
@@ -230,7 +231,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
230231 // First time we find a shared encoding in the chain, save it and try to
231232 // use it if it is compatible with the other users.
232233 tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding ());
233- if (!getSharedEncIfAllUsersAreDotEnc (user->getResult (0 )).has_value ())
234+ if (!getSharedEncIfAllUsersAreDotEnc (user->getResult (0 ), incompatible)
235+ .has_value ())
234236 return std::nullopt ;
235237 } else {
236238 if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
@@ -248,8 +250,10 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
248250 bitWidth, /* needTrans=*/ false );
249251 }
250252 // Check that the shared encodings needed by the users are compatible.
251- if (attr != nullptr )
252- assert (attr == tempAttr && " incompatible shared encoding" );
253+ if (attr != nullptr && attr != tempAttr) {
254+ incompatible = true ;
255+ return std::nullopt ;
256+ }
253257 attr = tempAttr;
254258 }
255259 return attr;
@@ -439,8 +443,44 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
439443 loadInfo.sharedEncoding =
440444 getSharedEncoding (op, /* loadIsMMAv3=*/ true ).value_or (nullptr );
441445 } else if (auto dot = dyn_cast<tt::DotOp>(use)) {
446+ bool incompatible = false ;
442447 loadInfo.sharedEncoding =
443- getSharedEncIfAllUsersAreDotEnc (op->getResult (0 )).value_or (nullptr );
448+ getSharedEncIfAllUsersAreDotEnc (op->getResult (0 ), incompatible)
449+ .value_or (nullptr );
450+ // If we can't agree on a shared encoding skip pipelinig the load.
451+ if (incompatible)
452+ continue ;
453+
454+ // HACK: Triton LLVM codegen has a bug where local_loads from #shared to
455+ // #mma layout can lead to invalid code if the loaded shape is smaller
456+ // than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with
457+ // tile {16,8} is bad because 1 < 8). To work around this, don't
458+ // pipeline such loads.
459+ //
460+ // The codegen bug is caught by an assertion, so if you think you've
461+ // fixed it, feel free to delete this code and see if the assert still
462+ // fails. :)
463+ if (!loadInfo.sharedEncoding ) {
464+ if (auto dotEnc = dyn_cast<ttg::NvidiaMmaEncodingAttr>(
465+ dot.getResult ().getType ().getEncoding ())) {
466+ auto loadTy = cast<RankedTensorType>(op->getResultTypes ()[0 ]);
467+ auto mmaInstrShape = dotEnc.getInstrShape ();
468+ if (loadTy.getRank () < mmaInstrShape.size ())
469+ continue ;
470+ bool ok = true ;
471+ for (int i = 0 ; i < mmaInstrShape.size (); i++) {
472+ if (loadTy.getShape ()[loadTy.getRank () - mmaInstrShape.size () +
473+ i] < mmaInstrShape[i]) {
474+ ok = false ;
475+ break ;
476+ }
477+ }
478+ // If this load might trigger the bug, don't do the fallback logic
479+ // below, which might allow the load to be pipelined.
480+ if (!ok)
481+ continue ;
482+ }
483+ }
444484 }
445485 } else if (auto loadOp = dyn_cast<tt::LoadOp>(use)) {
446486 // The use of this loadOp is another loadOp. If the use is not in the
@@ -476,83 +516,6 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
476516 return loadToInfo;
477517}
478518
479- // Split users to groups, each group has the same shared encoding.
480- // If not all users are Dot encoding, return empty vector.
481- static DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>>
482- handleIncompatibleSharedEncoding (Operation *loadOp) {
483- DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> loadGroups;
484- // Go through transitive uses of the loadOp in the same block.
485- for (Operation *user : loadOp->getUsers ()) {
486- if (user->getBlock () != loadOp->getBlock ())
487- continue ;
488- if (user->getNumResults () != 1 )
489- return loadGroups;
490-
491- ttg::SharedEncodingAttr tempAttr;
492- if (auto memDesc =
493- dyn_cast<triton::MemDescType>(user->getResult (0 ).getType ())) {
494- tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding ());
495- loadGroups[tempAttr].push_back (user);
496- } else {
497- if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
498- return loadGroups;
499- auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
500- cast<TensorOrMemDesc>(user->getResult (0 ).getType ()).getEncoding ());
501- if (!dotOpEnc)
502- return loadGroups;
503- auto srcTy = cast<TensorOrMemDesc>(loadOp->getResult (0 ).getType ());
504- auto CTALayout = ttg::getCTALayout (srcTy.getEncoding ());
505- auto order = ttg::getOrder (srcTy.getEncoding ());
506- unsigned bitWidth = srcTy.getElementType ().getIntOrFloatBitWidth ();
507- tempAttr = ttg::SharedEncodingAttr::get (
508- loadOp->getContext (), dotOpEnc, srcTy.getShape (),
509- ttg::getOrder (srcTy.getEncoding ()),
510- ttg::getCTALayout (srcTy.getEncoding ()),
511- srcTy.getElementType ().getIntOrFloatBitWidth (), /* needTrans=*/ false );
512- loadGroups[tempAttr].push_back (user);
513- }
514- }
515- return loadGroups;
516- }
517-
518- // Clone loads so each group of uses with same shared encoding will have a
519- // corresponding Load.
520- static void splitLoadsForIncompatible (
521- OpBuilder &builder, Operation *loadOp,
522- DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> &lGroups) {
523- // The first group will use the original load, create new loads for other
524- // groups.
525- unsigned idx = 0 ;
526- builder.setInsertionPointAfter (loadOp);
527- for (auto pair : lGroups) {
528- SmallVector<Operation *> &group = pair.second ;
529- if (idx++ == 0 )
530- continue ;
531- Operation *newLoad = builder.clone (*loadOp);
532- for (auto *user : group) {
533- user->replaceUsesOfWith (loadOp->getResult (0 ), newLoad->getResult (0 ));
534- }
535- }
536- }
537-
538- static void splitLoadsWithIncompatibleEncoding (scf::ForOp forOp) {
539- // Get the list of all loads.
540- SmallVector<Operation *> loads;
541- for (Operation &op : forOp.getBody ()->without_terminator ()) {
542- if (isa<tt::LoadOp, tt::ExperimentalDescriptorLoadOp>(op)) {
543- loads.push_back (&op);
544- }
545- }
546- OpBuilder builder (forOp);
547- for (auto *loadOp : loads) {
548- auto lGroups = handleIncompatibleSharedEncoding (loadOp);
549- LDBG (" groups with different encoding: " << lGroups.size () << " "
550- << *loadOp);
551- if (lGroups.size () > 1 )
552- splitLoadsForIncompatible (builder, loadOp, lGroups);
553- }
554- }
555-
556519static llvm::MapVector<Operation *, LoadInfo>
557520scheduleLoads (scf::ForOp forOp, tt::CoarseSchedule &schedule,
558521 DenseSet<Operation *> &rootUsers, int numStages) {
@@ -1106,8 +1069,6 @@ static void invalidateBarriers(OpBuilder &builder,
11061069
11071070bool mlir::triton::preProcessLoopAndGetSchedule (
11081071 scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) {
1109- splitLoadsWithIncompatibleEncoding (forOp);
1110-
11111072 // Schedule the loads and root ops (dot ops) in the loop. This will give us
11121073 // a scaffold for the final schedule.
11131074 DenseSet<Operation *> rootUsers;
0 commit comments