@@ -1927,6 +1927,7 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) {
19271927
19281928 unsigned swizzlingByteWidth;
19291929 bool transposed;
1930+ unsigned elementBitWidth;
19301931 std::optional<SmallVector<unsigned >> CTAsPerCGA;
19311932 std::optional<SmallVector<unsigned >> CTASplitNum;
19321933 std::optional<SmallVector<unsigned >> CTAOrder;
@@ -1938,6 +1939,9 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) {
19381939 } else if (attr.getName () == " transposed" ) {
19391940 if (parseBool (parser, attr, transposed, " transposed" ).failed ())
19401941 return {};
1942+ } else if (attr.getName () == " elementBitWidth" ) {
1943+ if (parseUInt (parser, attr, elementBitWidth, " elementBitWidth" ).failed ())
1944+ return {};
19411945 } else if (attr.getName () == " CTAsPerCGA" ) {
19421946 if (parseIntArrayAttr (parser, attr, CTAsPerCGA.emplace (), " CTAsPerCGA" )
19431947 .failed ())
@@ -1963,13 +1967,15 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) {
19631967 return {};
19641968
19651969 return parser.getChecked <NVMMASharedEncodingAttr>(
1966- parser.getContext (), swizzlingByteWidth, transposed, *CTALayout);
1970+ parser.getContext (), swizzlingByteWidth, transposed, elementBitWidth,
1971+ *CTALayout);
19671972}
19681973
19691974void NVMMASharedEncodingAttr::print (AsmPrinter &printer) const {
19701975 printer << " <{"
19711976 << " swizzlingByteWidth = " << getSwizzlingByteWidth () //
1972- << " , transposed = " << getTransposed ();
1977+ << " , transposed = " << getTransposed () //
1978+ << " , elementBitWidth = " << getElementBitWidth ();
19731979 maybePrintCTALayout (getContext (), printer, getCTALayout (),
19741980 /* rank=*/ 2 );
19751981 printer << " }>" ;
@@ -2611,7 +2617,8 @@ struct TritonGPUInferLayoutInterface
26112617 return failure ();
26122618 }
26132619 resultEncoding = NVMMASharedEncodingAttr::get (
2614- ctx, enc.getSwizzlingByteWidth (), !enc.getTransposed (), *ctaLayout);
2620+ ctx, enc.getSwizzlingByteWidth (), !enc.getTransposed (),
2621+ enc.getElementBitWidth (), *ctaLayout);
26152622 return success ();
26162623 }
26172624
0 commit comments