2323#include " triton/Tools/LayoutUtils.h"
2424#include " triton/Tools/LinearLayout.h"
2525#include " triton/Tools/StrUtil.h"
26- #include " triton/Tools/Sys/GetEnv.hpp"
2726#include " llvm/ADT/SmallSet.h"
2827#include " llvm/ADT/TypeSwitch.h"
2928#include " llvm/Support/MathExtras.h"
@@ -428,6 +427,15 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
428427 return encoding;
429428}
430429
// Returns true when the register -> last-output-dim part of `ll` is exactly
// the size-2 1D identity layout (LinearLayout::identity1D(2, ...)).
//
// ctx: MLIR context used to build the StringAttr dimension names.
// ll:  layout to inspect; its last output dimension is the one tested.
//
// NOTE(review): presumably this means the trailing size-2 dimension is held
// by a single register bit, which is what a split along that dimension needs
// -- confirm against the LinearLayout documentation.
430+ bool isSplitCompatible (MLIRContext *ctx, const LinearLayout &ll) {
// Index of the last output dimension of `ll`.
431+ auto lastDim = ll.getNumOutDims () - 1 ;
// Names of the register input dimension and of the last output dimension
// (a fixed prefix followed by the index computed above).
432+ auto kReg = StringAttr::get (ctx, " register" );
433+ auto kLastDim = StringAttr::get (ctx, " dim" + std::to_string (lastDim));
// Keep only the register -> last-dim mapping, then call
// removeZeroBasesAlongDim on the register dimension (presumably this drops
// register bases that contribute nothing to the last dim -- TODO confirm).
434+ auto sublayout =
435+ ll.sublayout ({kReg }, {kLastDim }).removeZeroBasesAlongDim (kReg );
// Compatible only if what remains equals the size-2 identity on that pair.
436+ return sublayout == LinearLayout::identity1D (2 , kReg , kLastDim );
437+ }
438+
431439LogicalResult tryJoinOnAxis (MLIRContext *ctx, const LinearLayout &inLl,
432440 LinearLayout &outLl, bool fwdInference, int axis,
433441 std::optional<Location> loc) {
@@ -1331,8 +1339,7 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
13311339 if (parser.parseGreater ().failed ())
13321340 return {};
13331341
1334- unsigned versionMajor = 0 ;
1335- unsigned versionMinor = 0 ;
1342+ unsigned version = 0 ;
13361343 SmallVector<unsigned > warpsPerCTA;
13371344 SmallVector<unsigned > instrShape;
13381345 bool isTransposed;
@@ -1341,12 +1348,8 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
13411348 std::optional<SmallVector<unsigned >> CTAOrder;
13421349
13431350 for (const NamedAttribute &attr : dict) {
1344- if (attr.getName () == " versionMajor" ) {
1345- if (parseUInt (parser, attr, versionMajor, " versionMajor" ).failed ())
1346- return {};
1347- }
1348- if (attr.getName () == " versionMinor" ) {
1349- if (parseUInt (parser, attr, versionMinor, " versionMinor" ).failed ())
1351+ if (attr.getName () == " version" ) {
1352+ if (parseUInt (parser, attr, version, " verison" ).failed ())
13501353 return {};
13511354 }
13521355 if (attr.getName () == " warpsPerCTA" ) {
@@ -1385,14 +1388,13 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
13851388 return {};
13861389
13871390 return parser.getChecked <AMDMfmaEncodingAttr>(
1388- parser.getContext (), versionMajor, versionMinor, warpsPerCTA ,
1389- instrShape[ 0 ], instrShape[ 1 ], isTransposed, *CTALayout);
1391+ parser.getContext (), version, warpsPerCTA, instrShape[ 0 ], instrShape[ 1 ] ,
1392+ isTransposed, *CTALayout);
13901393}
13911394
13921395void AMDMfmaEncodingAttr::print (AsmPrinter &printer) const {
13931396 printer << " <{"
1394- << " versionMajor = " << getVersionMajor () //
1395- << " , versionMinor = " << getVersionMinor () //
1397+ << " version = " << getVersion () //
13961398 << " , warpsPerCTA = [" << getWarpsPerCTA () << " ]" //
13971399 << " , instrShape = [" << ArrayRef{getMDim (), getNDim ()} << " ]" //
13981400 << " , isTransposed = " << getIsTransposed ();
@@ -1401,17 +1403,12 @@ void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
14011403 printer << " }>" ;
14021404}
14031405
1404- LogicalResult
1405- AMDMfmaEncodingAttr::verify (function_ref<mlir::InFlightDiagnostic()> emitError,
1406- unsigned versionMajor, unsigned versionMinor,
1407- llvm::ArrayRef<unsigned int> warpsPerCTA,
1408- unsigned mDim, unsigned nDim, bool isTransposed,
1409- mlir::triton::gpu::CTALayoutAttr) {
1410- if (!(versionMajor >= 0 && versionMajor <= 4 )) {
1411- return emitError () << " major version must be in the [0, 4] range" ;
1412- }
1413- if (versionMinor != 0 ) {
1414- return emitError () << " minor version must be 0" ;
1406+ LogicalResult AMDMfmaEncodingAttr::verify (
1407+ function_ref<mlir::InFlightDiagnostic()> emitError, unsigned version,
1408+ llvm::ArrayRef<unsigned int> warpsPerCTA, unsigned mDim, unsigned nDim,
1409+ bool isTransposed, mlir::triton::gpu::CTALayoutAttr) {
1410+ if (!(version >= 0 && version <= 4 )) {
1411+ return emitError () << " version must be in the [0, 4] range" ;
14151412 }
14161413 if (!((mDim == 32 && nDim == 32 ) || (mDim == 16 && nDim == 16 ))) {
14171414 return emitError ()
@@ -1965,7 +1962,7 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand(
19651962 bool isKContig = sharedOrder[0 ] == kDimIndex ;
19661963 // GFX950 supports LDS transpose load instructions, so we need swizzling even
19671964 // when K dimension is not the contiguous dimension.
1968- bool isGFX950 = getVersionMajor () == 4 ;
1965+ bool isGFX950 = getVersion () == 4 ;
19691966 bool swizzleNonKContig =
19701967 isGFX950 && (elemBitWidth == 8 || elemBitWidth == 16 );
19711968
@@ -2654,7 +2651,19 @@ struct TritonGPUInferLayoutInterface
26542651 inferDefaultJoinOpEncoding (Attribute srcEnc, Attribute &dstEnc,
26552652 ArrayRef<int64_t > shape,
26562653 std::optional<Location> loc) const override {
2657- if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc)) {
2654+ auto ctx = getContext ();
2655+ if (auto enc = mlir::dyn_cast<SliceEncodingAttr>(srcEnc);
2656+ enc && enc.getDim () == shape.size ()) {
2657+ SmallVector<int64_t > joinedShape (shape);
2658+ joinedShape.push_back (2 );
2659+ auto parent = enc.getParent ();
2660+ auto parentLL = toLinearLayout (joinedShape, parent);
2661+
2662+ if (isSplitCompatible (ctx, parentLL)) {
2663+ dstEnc = parent;
2664+ return success ();
2665+ }
2666+ } else if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc)) {
26582667 // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape
26592668 // AxBxCx2. The encoding is the same as the input, but with 2 elems per
26602669 // thread in the new dimension. The new dimension is the fastest running
@@ -2679,8 +2688,6 @@ struct TritonGPUInferLayoutInterface
26792688 return success ();
26802689 }
26812690
2682- auto ctx = getContext ();
2683-
26842691 // Append dim to shape
26852692 auto ll = toLinearLayout (shape, srcEnc);
26862693 SmallVector<int64_t > dstShape (shape.begin (), shape.end ());
@@ -2757,7 +2764,6 @@ struct TritonGPUInferLayoutInterface
27572764 if (!result.succeeded ()) {
27582765 return failure ();
27592766 }
2760-
27612767 // Remove last dim from newLl (which should be 1)
27622768 SmallVector<int64_t > dstShape (shape.begin (), shape.end ());
27632769 dstShape.pop_back ();
0 commit comments