@@ -380,24 +380,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
         return !useLegacyMMAConversion;
       }
       if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
-        auto parent = dotOperand.getParent();
-        if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion) {
-          return false;
-        }
-        if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-          if (nvidiaMma.isAmpere()) {
-            return true;
-          }
-        }
-        if (isa<AMDMfmaEncodingAttr>(parent)) {
-          return true;
+        if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(
+                dotOperand.getParent())) {
+          return !useLegacyMMAConversion;
         }
         return false;
       }
-      if (isa<BlockedEncodingAttr>(layout)) {
-        return true;
-      }
-      if (isa<LinearEncodingAttr>(layout)) {
+      if (isa<BlockedEncodingAttr, LinearEncodingAttr>(layout)) {
         return true;
       }
       if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -408,6 +397,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) {
       return failure();
     }
+    // FIXME [Dot LL] Remove this once we implement this trick in LLs
+    if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
+      return failure();
+    }
 
     assert(cvtNeedsSharedMemory(srcTy, dstTy));
 
@@ -498,34 +491,35 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // don't need to avoid duplicate writes.
     // Input dims: [reg, lane, warp]
     // Output dims: [offset, iteration]
-    std::optional<LinearLayout> shmemStoreLayout =
-        chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
-                             scratchConfig.paddedRepShape, scratchConfig.order,
-                             /*swizzleByteSize=*/0);
-    bool isStMatrix = shmemStoreLayout.has_value();
-    if (!isStMatrix) {
-      shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);
-    }
-    assert(shmemStoreLayout.has_value());
+    bool isStMatrix = targetInfo.canUseStMatrix(
+        op.getSrc().getType(), scratchConfig.repShape,
+        scratchConfig.paddedRepShape, scratchConfig.order,
+        /*swizzleByteSize=*/0);
+    LinearLayout shmemStoreLayout =
+        isStMatrix ? chooseStMatrixLayout(
+                         ctx, op.getSrc().getType(), scratchConfig.repShape,
+                         scratchConfig.paddedRepShape, scratchConfig.order,
+                         /*swizzleByteSize=*/0)
+                   : srcLayout.invertAndCompose(sharedLayout);
 
     const int shmemAllocatedNumElems =
         getNumScratchElements(scratchConfig.paddedRepShape);
-    assert(shmemStoreLayout->getOutDimSize(kOffset) <= shmemAllocatedNumElems);
+    assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);
 
     // Layout for the load from shmem to registers.
     LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);
 
     // Check that the `register` fully determines the `iteration`. That is,
     // each thread does exactly the same reads and writes to shmem on each
     // iteration, just with different input/output registers.
-    assert(shmemStoreLayout->sublayoutIsZero({kLane, kWarp, kBlock},
-                                             {kIteration}));
+    assert(
+        shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
     assert(
         shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
 
     // iteration -> registers
     SmallVector<SmallVector<int>> inRegsForIter =
-        collectRegsForIter(ctx, *shmemStoreLayout);
+        collectRegsForIter(ctx, shmemStoreLayout);
     SmallVector<SmallVector<int>> outRegsForIter =
         collectRegsForIter(ctx, shmemLoadLayout);
 
@@ -582,7 +576,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return vecAddr;
     };
 
-    auto storeBase = applyLinearLayout(loc, rewriter, *shmemStoreLayout,
+    auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
                                        {{kRegister, i32_val(0)},
                                         {kLane, laneId},
                                         {kWarp, warpId},
@@ -605,11 +599,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 
     // When using `stmatrix`, we can store `inVec` elements even if they are
     // not contiguous
-    auto inVec = isStMatrix ? shmemStoreLayout->getNumConsecutiveInOut()
+    auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
                             : scratchConfig.inVec;
     for (int j = 0; j < inVals.size() / iterations; j += inVec) {
       auto inRegSlice = inRegs[j];
-      Value vecAddr = getVecAddr(*shmemStoreLayout, storeBase, inRegSlice);
+      Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
       SmallVector<Value> inValsVec;
       for (int k = 0; k < inVec; k++)
         inValsVec.push_back(inVals[inRegSlice + k]);
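
The net effect of the shmemStoreLayout hunks above is a small but recurring C++ refactor: decide up front (via targetInfo.canUseStMatrix) whether the stmatrix path applies, then build exactly one LinearLayout, instead of building an std::optional, probing has_value(), and dereferencing it everywhere downstream. A minimal, self-contained sketch of that pattern follows; ToyLayout, canUseStMatrix, chooseStMatrixLayout, and invertAndCompose here are toy stand-ins for illustration, not Triton's real types or helpers.

    // Sketch: "decide first, then construct one value" instead of optional + has_value().
    #include <iostream>
    #include <string>

    struct ToyLayout {                       // stand-in for LinearLayout
      std::string kind;
      int numConsecutiveInOut() const { return kind == "stmatrix" ? 8 : 1; }
    };

    bool canUseStMatrix(bool hwSupport) { return hwSupport; }   // predicate, queried first
    ToyLayout chooseStMatrixLayout() { return {"stmatrix"}; }   // fast path
    ToyLayout invertAndCompose() { return {"generic"}; }        // fallback path

    int main() {
      bool isStMatrix = canUseStMatrix(/*hwSupport=*/true);
      // No std::optional, no has_value() assert, no '->' later on.
      ToyLayout shmemStoreLayout =
          isStMatrix ? chooseStMatrixLayout() : invertAndCompose();
      std::cout << shmemStoreLayout.kind << " vec="
                << shmemStoreLayout.numConsecutiveInOut() << "\n";
      return 0;
    }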