@@ -531,40 +531,44 @@ static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,
                                                    Attribute srcEnc,
                                                    ArrayRef<int64_t> dstShape,
                                                    Attribute &dstEnc) {
+  // TODO Delete this once SharedLinearEncodingAttr is more widely supported.
   if (auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc)) {
-    // TODO: supporting reshape of CTA layouts is non-trivial.
-    if (getNumCTAs(mmaEncoding) > 1)
-      return failure();
-    int innerDimDst =
-        mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
-    int innerDimSrc =
-        mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
-    // For now disallow reshape of the inner dimension.
-    if (innerDimDst != innerDimSrc)
-      return failure();
     auto *ctx = srcEnc.getContext();
-
-    // CTALayout can be all 1's because we bailed on multi-CTA layouts above.
-    auto CTALayout = CTALayoutAttr::get(
-        ctx,
-        /*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
-        /*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
-        /*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
-    dstEnc = NVMMASharedEncodingAttr::get(
-        ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(),
-        mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(),
-        CTALayout);
-    // Big guns, check linear layouts are equivalent
-    // We disallow reshaping memdesc_subslice in the verifier
-    // so allocShape == shape
-    auto srcLL = toLinearLayout(srcShape, srcEnc);
-    auto dstLL = toLinearLayout(dstShape, dstEnc);
-    if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
-      return failure();
+    if (getNumCTAs(mmaEncoding) == 1) {
+      int innerDimDst =
+          mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
+      int innerDimSrc =
+          mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
+      // We can keep an NVMMAShared encoding only if the innermost dimension is
+      // preserved. Otherwise fall back to the generic shared-linear encoding
+      // logic below.
+      if (innerDimDst == innerDimSrc) {
+        auto CTALayout = CTALayoutAttr::get(
+            ctx,
+            /*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
+            /*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
+            /*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
+        auto candidateEncoding = NVMMASharedEncodingAttr::get(
+            ctx, mmaEncoding.getSwizzlingByteWidth(),
+            mmaEncoding.getTransposed(), mmaEncoding.getElementBitWidth(),
+            mmaEncoding.getFp4Padded(), CTALayout);
+        auto srcLL = toLinearLayout(srcShape, srcEnc);
+        auto dstLL = toLinearLayout(dstShape, candidateEncoding);
+        if (reshapeLayout(ctx, srcLL, dstShape) == dstLL) {
+          dstEnc = candidateEncoding;
+          return success();
+        }
+      }
     }
-    return success();
   }
-  return failure();
+
+  // Generic LL case
+  auto sharedEnc = cast<SharedEncodingTrait>(srcEnc);
+  auto *ctx = srcEnc.getContext();
+  auto srcLL = toLinearLayout(srcShape, srcEnc);
+  auto dstLL = reshapeLayout(ctx, srcLL, dstShape);
+  dstEnc = SharedLinearEncodingAttr::get(ctx, dstLL, sharedEnc.getAlignment());
+  return success();
 }
 
 LogicalResult MemDescReshapeOp::inferReturnTypes(
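Worked example (not part of the patch): the shapes and encoding parameters below are hypothetical and only illustrate which branch of the new inference fires; the only calls assumed are the ones already used in the diff (toLinearLayout, reshapeLayout, NVMMASharedEncodingAttr::get, SharedLinearEncodingAttr::get).

// Hypothetical sketch, not part of this commit. Assume srcEnc is a
// single-CTA, non-transposed NVMMASharedEncodingAttr.
//
// Case 1: srcShape = 2x64x128, dstShape = 128x128. The innermost dimension
// (128) is preserved, so a candidate NVMMAShared encoding is built for
// dstShape and kept only if
//   reshapeLayout(ctx, toLinearLayout(srcShape, srcEnc), dstShape)
// equals toLinearLayout(dstShape, candidateEncoding); otherwise control
// falls through to the generic path.
//
// Case 2: srcShape = 64x128, dstShape = 32x256. The innermost dimension
// changes (128 -> 256), so the NVMMAShared fast path is skipped and the
// result is always the shared-linear fallback:
//   auto srcLL = toLinearLayout(srcShape, srcEnc);
//   auto dstLL = reshapeLayout(ctx, srcLL, dstShape);
//   dstEnc = SharedLinearEncodingAttr::get(ctx, dstLL, sharedEnc.getAlignment());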