@@ -376,28 +376,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
376376 // completed before we can remove the layoutIsOK check:
377377 // 1. Support for AMD's WMMA
378378 std::function<bool (Attribute)> layoutIsOK = [&](Attribute layout) {
379- if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
380- return !useLegacyMMAConversion;
381- }
382379 if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
383- auto parent = dotOperand.getParent ();
384- if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion) {
385- return false ;
386- }
387- if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
388- if (nvidiaMma.isAmpere ()) {
389- return true ;
390- }
391- }
392- if (isa<AMDMfmaEncodingAttr>(parent)) {
393- return true ;
394- }
395- return false ;
380+ layout = dotOperand.getParent ();
396381 }
397- if (isa<BlockedEncodingAttr>(layout)) {
398- return true ;
382+
383+ if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
384+ return !useLegacyMMAConversion;
399385 }
400- if (isa<LinearEncodingAttr>(layout)) {
386+ if (isa<BlockedEncodingAttr, LinearEncodingAttr>(layout)) {
401387 return true ;
402388 }
403389 if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -408,6 +394,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
408394 if (!layoutIsOK (srcTy.getEncoding ()) || !layoutIsOK (dstTy.getEncoding ())) {
409395 return failure ();
410396 }
397+ // FIXME [Dot LL] Remove this once we implement this trick in LLs
398+ if (matchMmaV3AndDotOperandLayout (srcTy, dstTy)) {
399+ return failure ();
400+ }
411401
412402 assert (cvtNeedsSharedMemory (srcTy, dstTy));
413403
@@ -498,34 +488,35 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
498488 // don't need to avoid duplicate writes.
499489 // Input dims: [reg, lane, warp]
500490 // Output dims: [offset, iteration]
501- std::optional<LinearLayout> shmemStoreLayout =
502- chooseStMatrixLayout (ctx, op.getSrc ().getType (), scratchConfig.repShape ,
503- scratchConfig.paddedRepShape , scratchConfig.order ,
504- /* swizzleByteSize=*/ 0 );
505- bool isStMatrix = shmemStoreLayout.has_value ();
506- if (!isStMatrix) {
507- shmemStoreLayout = srcLayout.invertAndCompose (sharedLayout);
508- }
509- assert (shmemStoreLayout.has_value ());
491+ bool isStMatrix = targetInfo.canUseStMatrix (
492+ op.getSrc ().getType (), scratchConfig.repShape ,
493+ scratchConfig.paddedRepShape , scratchConfig.order ,
494+ /* swizzleByteSize=*/ 0 );
495+ LinearLayout shmemStoreLayout =
496+ isStMatrix ? chooseStMatrixLayout (
497+ ctx, op.getSrc ().getType (), scratchConfig.repShape ,
498+ scratchConfig.paddedRepShape , scratchConfig.order ,
499+ /* swizzleByteSize=*/ 0 )
500+ : srcLayout.invertAndCompose (sharedLayout);
510501
511502 const int shmemAllocatedNumElems =
512503 getNumScratchElements (scratchConfig.paddedRepShape );
513- assert (shmemStoreLayout-> getOutDimSize (kOffset ) <= shmemAllocatedNumElems);
504+ assert (shmemStoreLayout. getOutDimSize (kOffset ) <= shmemAllocatedNumElems);
514505
515506 // Layout for the load from shmem to registers.
516507 LinearLayout shmemLoadLayout = dstLayout.invertAndCompose (sharedLayout);
517508
518509 // Check that the `register` fully determines the `iteration`. That is,
519510 // each thread does exactly the same reads and writes to shmem on each
520511 // iteration, just with different input/output registers.
521- assert (shmemStoreLayout-> sublayoutIsZero ({ kLane , kWarp , kBlock },
522- {kIteration }));
512+ assert (
513+ shmemStoreLayout. sublayoutIsZero ({ kLane , kWarp , kBlock }, {kIteration }));
523514 assert (
524515 shmemLoadLayout.sublayoutIsZero ({kLane , kWarp , kBlock }, {kIteration }));
525516
526517 // iteration -> registers
527518 SmallVector<SmallVector<int >> inRegsForIter =
528- collectRegsForIter (ctx, * shmemStoreLayout);
519+ collectRegsForIter (ctx, shmemStoreLayout);
529520 SmallVector<SmallVector<int >> outRegsForIter =
530521 collectRegsForIter (ctx, shmemLoadLayout);
531522
@@ -582,7 +573,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
582573 return vecAddr;
583574 };
584575
585- auto storeBase = applyLinearLayout (loc, rewriter, * shmemStoreLayout,
576+ auto storeBase = applyLinearLayout (loc, rewriter, shmemStoreLayout,
586577 {{kRegister , i32_val (0 )},
587578 {kLane , laneId},
588579 {kWarp , warpId},
@@ -605,11 +596,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
605596
606597 // When using `stmatrix`, we can store `inVec` elements even if they are
607598 // not contiguous
608- auto inVec = isStMatrix ? shmemStoreLayout-> getNumConsecutiveInOut ()
599+ auto inVec = isStMatrix ? shmemStoreLayout. getNumConsecutiveInOut ()
609600 : scratchConfig.inVec ;
610601 for (int j = 0 ; j < inVals.size () / iterations; j += inVec) {
611602 auto inRegSlice = inRegs[j];
612- Value vecAddr = getVecAddr (* shmemStoreLayout, storeBase, inRegSlice);
603+ Value vecAddr = getVecAddr (shmemStoreLayout, storeBase, inRegSlice);
613604 SmallVector<Value> inValsVec;
614605 for (int k = 0 ; k < inVec; k++)
615606 inValsVec.push_back (inVals[inRegSlice + k]);