@@ -427,15 +427,6 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
427427 return encoding;
428428}
429429
430- bool isSplitCompatible (MLIRContext *ctx, const LinearLayout &ll) {
431- auto lastDim = ll.getNumOutDims () - 1 ;
432- auto kReg = StringAttr::get (ctx, " register" );
433- auto kLastDim = StringAttr::get (ctx, " dim" + std::to_string (lastDim));
434- auto sublayout =
435- ll.sublayout ({kReg }, {kLastDim }).removeZeroBasesAlongDim (kReg );
436- return sublayout == LinearLayout::identity1D (2 , kReg , kLastDim );
437- }
438-
439430LogicalResult tryJoinOnAxis (MLIRContext *ctx, const LinearLayout &inLl,
440431 LinearLayout &outLl, bool fwdInference, int axis,
441432 std::optional<Location> loc) {
@@ -2056,6 +2047,42 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getMNKDimPerInstr() {
20562047 return {16 , 16 , 16 };
20572048}
20582049
2050+ SwizzledSharedEncodingAttr AMDWmmaEncodingAttr::composeSharedLayoutForOperand (
2051+ CTALayoutAttr ctaLayout, int operandIdx, ArrayRef<int64_t > operandShape,
2052+ ArrayRef<unsigned > sharedOrder, unsigned kWidth , unsigned elemBitWidth,
2053+ bool needTrans) const {
2054+ int kDimIndex = operandIdx == 0 ? 1 : 0 ;
2055+ bool isKContig = sharedOrder[0 ] == kDimIndex ;
2056+
2057+ if (!isKContig) {
2058+ // Do not swizzle. In this case accesses will go in different banks even
2059+ // without swizzling.
2060+ return SwizzledSharedEncodingAttr::get (getContext (), 1 , 1 , 1 , sharedOrder,
2061+ ctaLayout);
2062+ }
2063+
2064+ // max vectorization size for ds_load is 128 bits
2065+ int vectorSize = std::min (kWidth * elemBitWidth, 128u ) / elemBitWidth;
2066+
2067+ const int numBanks = 32 ;
2068+ const int bankBitWidth = 32 ;
2069+
2070+ // Number of inner dimension rows per one pattern repeat
2071+ int innerDimLength = operandShape[sharedOrder[0 ]];
2072+ int elemsPerOneBanksRow = (numBanks * bankBitWidth) / elemBitWidth;
2073+
2074+ int perPhase = std::max (1 , elemsPerOneBanksRow / innerDimLength);
2075+ // for both RDNA3 and RDNA4, the M/N dimension of wmma is 16
2076+ // This represents the max number of rows that can be accessed
2077+ // at the same time
2078+ int mDim = getMNKDimPerInstr ()[0 ];
2079+ int maxPhase =
2080+ std::max (std::min (mDim / perPhase, innerDimLength / vectorSize), 1 );
2081+
2082+ return SwizzledSharedEncodingAttr::get (getContext (), vectorSize, perPhase,
2083+ maxPhase, sharedOrder, ctaLayout);
2084+ }
2085+
20592086// ===----------------------------------------------------------------------===//
20602087// Mma encoding
20612088// ===----------------------------------------------------------------------===//
@@ -2659,7 +2686,9 @@ struct TritonGPUInferLayoutInterface
26592686 auto parent = enc.getParent ();
26602687 auto parentLL = toLinearLayout (joinedShape, parent);
26612688
2662- if (isSplitCompatible (ctx, parentLL)) {
2689+ Attribute splitEnc;
2690+ auto result = inferSplitOpEncoding (parent, splitEnc, joinedShape, loc);
2691+ if (succeeded (result) && areLayoutsEquivalent (shape, splitEnc, srcEnc)) {
26632692 dstEnc = parent;
26642693 return success ();
26652694 }
@@ -2709,28 +2738,16 @@ struct TritonGPUInferLayoutInterface
27092738 inferSplitOpEncoding (Attribute srcEnc, Attribute &dstEnc,
27102739 ArrayRef<int64_t > shape,
27112740 std::optional<Location> loc) const override {
2741+ // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of
2742+ // shape AxBxC. The input must have 2 elements per thread in the last
2743+ // dimension, which must be the fastest running dimension. The result
2744+ // encoding is the same as the input, but with the last dimension removed.
27122745 auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
2713- if (enc) {
2714- // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of
2715- // shape AxBxC. The input must have 2 elements per thread in the last
2716- // dimension, which must be the fastest running dimension. The result
2717- // encoding is the same as the input, but with the last dimension removed.
2718- if (enc.getSizePerThread ().back () != 2 ) {
2719- return emitOptionalError (
2720- loc, " SplitOp requires 2 elements per thread in the "
2721- " last dimension of the input" );
2722- }
2723- if (enc.getThreadsPerWarp ().back () != 1 ||
2724- enc.getWarpsPerCTA ().back () != 1 || enc.getCTAsPerCGA ().back () != 1 ) {
2725- return emitOptionalError (
2726- loc, " SplitOp requires threadsPerWarp, warpsPerCTA, "
2727- " and CTAsPerCGA = 1 for the last dimension of the input" );
2728- }
2729- if (enc.getCTALayout ().getCTAsPerCGA ().back () != 1 ) {
2730- return emitOptionalError (
2731- loc,
2732- " SplitOp requires the last dimension to be most-minor in CTAOrder" );
2733- }
2746+ bool isSimpleSplit = (enc && (enc.getSizePerThread ().back () == 2 ) &&
2747+ (enc.getThreadsPerWarp ().back () == 1 ) &&
2748+ (enc.getWarpsPerCTA ().back () == 1 ) &&
2749+ (enc.getCTAsPerCGA ().back () == 1 ));
2750+ if (isSimpleSplit) {
27342751 SmallVector<unsigned > newOrder (enc.getOrder ());
27352752 int splitDim = newOrder.size () - 1 ;
27362753 // Remove splitDim from order.
0 commit comments