@@ -494,6 +494,46 @@ LogicalResult MemDescReshapeOp::verify() {
494
494
return success ();
495
495
}
496
496
497
+ static LogicalResult inferMemDescReshapeOpEncoding (ArrayRef<int64_t > srcShape,
498
+ Attribute srcEnc,
499
+ ArrayRef<int64_t > dstShape,
500
+ Attribute &dstEnc) {
501
+ if (auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc)) {
502
+ // TODO: supporting reshape of CTA layouts is non-trivial.
503
+ if (getNumCTAs (mmaEncoding) > 1 )
504
+ return failure ();
505
+ int innerDimDst =
506
+ mmaEncoding.getTransposed () ? dstShape.front () : dstShape.back ();
507
+ int innerDimSrc =
508
+ mmaEncoding.getTransposed () ? srcShape.front () : srcShape.back ();
509
+ // For now disallow reshape of the inner dimension.
510
+ if (innerDimDst != innerDimSrc)
511
+ return failure ();
512
+ auto *ctx = srcEnc.getContext ();
513
+
514
+ // CTALayout can be all 1's because we bailed on multi-CTA layouts above.
515
+ auto CTALayout = CTALayoutAttr::get (
516
+ ctx,
517
+ /* CTAsPerCGA=*/ SmallVector<unsigned >(dstShape.size (), 1 ),
518
+ /* CTASplitNum=*/ SmallVector<unsigned >(dstShape.size (), 1 ),
519
+ /* CTAOrder=*/ llvm::to_vector (llvm::seq<unsigned >(dstShape.size ())));
520
+ dstEnc = NVMMASharedEncodingAttr::get (
521
+ ctx, mmaEncoding.getSwizzlingByteWidth (), mmaEncoding.getTransposed (),
522
+ mmaEncoding.getElementBitWidth (), mmaEncoding.getFp4Padded (),
523
+ CTALayout);
524
+ // Big guns, check linear layouts are equivalent
525
+ // We disallow reshaping memdesc_subviews in the verifier
526
+ // We disallow reshaping memdesc_subviews in the verifier
527
+ auto srcLL = toLinearLayout (srcShape, srcEnc, srcShape);
528
+ auto dstLL = toLinearLayout (dstShape, dstEnc, dstShape);
529
+ if (reshapeLayout (ctx, srcLL, dstShape) != dstLL) {
530
+ return failure ();
531
+ }
532
+ return success ();
533
+ }
534
+ return failure ();
535
+ }
536
+
497
537
LogicalResult MemDescReshapeOp::inferReturnTypes (
498
538
MLIRContext *context, std::optional<Location> loc, MemDescType srcTy,
499
539
ArrayRef<int64_t > dstShape, MemDescType &inferredReturnType) {
@@ -503,9 +543,8 @@ LogicalResult MemDescReshapeOp::inferReturnTypes(
503
543
504
544
Attribute dstEncoding;
505
545
if (Attribute srcEnc = srcTy.getEncoding ()) {
506
- auto *inferLayout = cast<DialectInferLayoutInterface>(&srcEnc.getDialect ());
507
- if (failed (inferLayout->inferReshapeOpEncoding (srcTy.getShape (), srcEnc,
508
- dstShape, dstEncoding, loc)))
546
+ if (failed (inferMemDescReshapeOpEncoding (srcTy.getShape (), srcEnc, dstShape,
547
+ dstEncoding)))
509
548
return failure ();
510
549
}
511
550
0 commit comments