@@ -488,6 +488,42 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
   return encoding;
 }
 
+LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl,
+                            LinearLayout &outLl, bool fwdInference, int axis,
+                            std::optional<Location> loc) {
+  auto kRegister = StringAttr::get(ctx, "register");
+  auto outDims = llvm::to_vector(inLl.getOutDimNames());
+  if (fwdInference) {
+    auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
+    outLl = split * inLl;
+  } else {
+    // TODO This requires a division algorithm!
+    // Implement manually ll.divideLeft(split)
+    auto contiguousElems =
+        LinearEncodingAttr::get(ctx, inLl).getContigPerThread();
+    if (contiguousElems[axis] > 1) {
+      LinearLayout::BasesT newBases;
+      for (const auto &basesDim : inLl.getBases()) {
+        std::vector<std::vector<int32_t>> newBasesDim;
+        for (auto base : basesDim.second) {
+          if (base[axis] == 1) {
+            continue;
+          }
+          base[axis] /= 2;
+          newBasesDim.push_back(std::move(base));
+        }
+        newBases.insert({basesDim.first, std::move(newBasesDim)});
+      }
+      outLl = LinearLayout(std::move(newBases), std::move(outDims));
+    } else {
+      return emitOptionalError(loc,
+                               "Fp4ToFpOp/SplitOp requires at least 2 elements "
+                               "per thread in the axis/last dimension");
+    }
+  }
+  return success();
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
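
The backward branch of `tryJoinOnAxis` divides the layout by `identity1D(2, kRegister, dim)` by hand: the basis vector whose component along `axis` is exactly 1 is the bit contributed by the size-2 factor, so it is dropped, and every larger component is halved. A minimal standalone sketch of that step, using plain `std::vector` in place of `LinearLayout::BasesT` and modeling a single input dimension's bases (hypothetical mimic, not the helper above):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Bases = std::vector<std::vector<int32_t>>;

// Divide out one size-2 factor along `axis`: drop the basis contributing
// the lowest bit of that dimension, halve all remaining components.
Bases divideAxisByTwo(const Bases &bases, int axis) {
  Bases out;
  for (auto base : bases) {
    if (base[axis] == 1)
      continue;      // the bit owned by the size-2 split: remove it
    base[axis] /= 2; // shift the surviving bits down by one position
    out.push_back(std::move(base));
  }
  return out;
}

int main() {
  // Register bases for 4 contiguous elements along axis 0 (values 1 and 2
  // are the two low bits of that dimension) plus one basis along axis 1.
  Bases regs = {{1, 0}, {2, 0}, {0, 1}};
  Bases halved = divideAxisByTwo(regs, /*axis=*/0);
  assert((halved == Bases{{1, 0}, {0, 1}})); // {2,0} -> {1,0}; {1,0} dropped
}
```

The `contiguousElems[axis] > 1` guard is what makes this sound: it guarantees a basis with component 1 along `axis` actually exists in the register dimension before the division is attempted.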
@@ -1239,28 +1275,39 @@ LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
   return scaledLayout.basesPerDim(kRegister, /*skipBroadcast=*/false);
 }
 
-SmallVector<unsigned> LinearEncodingAttr::getContigPerThread() const {
+SmallVector<unsigned>
+LinearEncodingAttr::getContig(const char *inDim,
+                              SmallVector<unsigned int> lowerContig) const {
   auto ll = getLinearLayout();
-  const auto &regs =
-      ll.getBases().find(StringAttr::get(getContext(), "register"))->second;
+  const auto &bases =
+      ll.getBases().find(StringAttr::get(getContext(), inDim))->second;
   auto order = getOrder();
   auto rank = order.size();
 
-  SmallVector<unsigned> contig(rank, 1);
-  auto regIt = regs.begin();
+  SmallVector<unsigned> contig(lowerContig);
+  auto basisIt = bases.begin();
   for (unsigned dim : order) {
     std::vector<int32_t> basis(rank, 0);
-    basis[dim] = 1;
+    basis[dim] = contig[dim];
 
-    while (regIt != regs.end() && *regIt == basis) {
+    while (basisIt != bases.end() && *basisIt == basis) {
       contig[dim] *= 2;
       basis[dim] *= 2;
-      ++regIt;
+      ++basisIt;
     }
   }
   return contig;
 }
 
+SmallVector<unsigned> LinearEncodingAttr::getContigPerThread() const {
+  SmallVector<unsigned> contig(getOrder().size(), 1);
+  return getContig("register", contig);
+}
+
+SmallVector<unsigned> LinearEncodingAttr::getContigPerWarp() const {
+  return getContig("lane", getContigPerThread());
+}
+
 unsigned
 LinearEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape) const {
   return product(getElemsPerThread(shape));
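
`getContig` generalizes the old `getContigPerThread` walk: it starts from `lowerContig` instead of all-ones, and `basis[dim] = contig[dim]` makes the expected next basis pick up exactly where the lower level's run ended, so lane bases can extend a contiguous run begun by register bases. A standalone mimic with a worked example (plain `std::vector` standing in for `SmallVector`/`StringAttr`; hypothetical, for illustration only):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Keep doubling contig[dim] while the next basis vector is exactly the
// expected power of two along that dimension, in logical `order`.
std::vector<unsigned> getContig(const std::vector<std::vector<int32_t>> &bases,
                                const std::vector<unsigned> &order,
                                std::vector<unsigned> contig) {
  auto rank = order.size();
  auto basisIt = bases.begin();
  for (unsigned dim : order) {
    std::vector<int32_t> basis(rank, 0);
    basis[dim] = contig[dim]; // continue where the lower level's run ended
    while (basisIt != bases.end() && *basisIt == basis) {
      contig[dim] *= 2;
      basis[dim] *= 2;
      ++basisIt;
    }
  }
  return contig;
}

int main() {
  // Lane bases {{2,0},{4,0}} on top of a per-thread contiguity of {2,1}:
  // the lanes extend the run along dim 0, so contiguity per warp is {8,1}.
  std::vector<unsigned> perThread = {2, 1};
  auto perWarp = getContig({{2, 0}, {4, 0}}, /*order=*/{0, 1}, perThread);
  assert((perWarp == std::vector<unsigned>{8, 1}));
}
```

This is why `getContigPerWarp` seeds the walk with `getContigPerThread()`: warp-level contiguity is only as long as the thread-level run it extends.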
@@ -2721,14 +2768,12 @@ struct TritonGPUInferLayoutInterface
     }
 
     auto newRank = dstShape.size();
-    SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
-    for (auto [dim, size] :
-         llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
-      newOutDims.emplace_back(dim, size);
-    }
-    auto srcOutDims = to_vector(src.getOutDimNames());
+
+    auto newOutDims = standardOutDimPairs(ctx, dstShape);
+
     // reshapeOp assumes minor-to-major, so we need to transpose the out dims
     // before the reshape
+    auto srcOutDims = to_vector(src.getOutDimNames());
     std::reverse(srcOutDims.begin(), srcOutDims.end());
     std::reverse(newOutDims.begin(), newOutDims.end());
     auto dst = src.transposeOuts(srcOutDims)
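
The removed zip loop is exactly what the new `standardOutDimPairs` helper factors out. A standalone mimic under that assumption (`std::string` in place of `StringAttr`; the `dimN` names follow the convention of `standardOutDimNames`; hypothetical reimplementation, not the committed one):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Pair each standard out-dim name ("dim0", "dim1", ...) with its size.
std::vector<std::pair<std::string, int64_t>>
standardOutDimPairs(const std::vector<int64_t> &shape) {
  std::vector<std::pair<std::string, int64_t>> dims;
  for (size_t i = 0; i < shape.size(); ++i)
    dims.emplace_back("dim" + std::to_string(i), shape[i]);
  return dims;
}

int main() {
  for (auto &[name, size] : standardOutDimPairs({4, 8, 16}))
    std::cout << name << " = " << size << "\n"; // dim0 = 4, dim1 = 8, dim2 = 16
}
```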
@@ -2740,82 +2785,117 @@ struct TritonGPUInferLayoutInterface
 
   LogicalResult
   inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
+                      ArrayRef<int64_t> shape,
                       std::optional<Location> loc) const override {
-    auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
-    if (!enc) {
-      return emitOptionalError(loc,
-                               "JoinOp can only operate on BlockedEncoding");
+    if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc)) {
+      // JoinOp takes two tensors of shape AxBxC and generates a tensor of
+      // shape AxBxCx2. The encoding is the same as the input, but with 2
+      // elems per thread in the new dimension. The new dimension is
+      // most-minor.
+      auto append = [](ArrayRef<unsigned> vals, int val) {
+        SmallVector<unsigned> ret(vals);
+        ret.push_back(val);
+        return ret;
+      };
+      auto appendMinorDim = [](ArrayRef<unsigned> order) {
+        SmallVector<unsigned> ret(order);
+        ret.insert(ret.begin(), ret.size());
+        return ret;
+      };
+      dstEnc = BlockedEncodingAttr::get(
+          enc.getContext(),                    //
+          append(enc.getSizePerThread(), 2),   //
+          append(enc.getThreadsPerWarp(), 1),  //
+          append(enc.getWarpsPerCTA(), 1),     //
+          appendMinorDim(enc.getOrder()),      //
+          CTALayoutAttr::get(enc.getContext(), //
+                             append(enc.getCTAsPerCGA(), 1),
+                             append(enc.getCTASplitNum(), 1),
+                             appendMinorDim(enc.getCTAOrder())));
+      return success();
     }
 
-    // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape
-    // AxBxCx2. The encoding is the same as the input, but with 2 elems per
-    // thread in the new dimension. The new dimension is most-minor.
-    auto append = [](ArrayRef<unsigned> vals, int val) {
-      SmallVector<unsigned> ret(vals);
-      ret.push_back(val);
-      return ret;
-    };
-    auto appendMinorDim = [](ArrayRef<unsigned> order) {
-      SmallVector<unsigned> ret(order);
-      ret.insert(ret.begin(), ret.size());
-      return ret;
-    };
-    dstEnc = BlockedEncodingAttr::get(
-        enc.getContext(),                    //
-        append(enc.getSizePerThread(), 2),   //
-        append(enc.getThreadsPerWarp(), 1),  //
-        append(enc.getWarpsPerCTA(), 1),     //
-        appendMinorDim(enc.getOrder()),      //
-        CTALayoutAttr::get(enc.getContext(), //
-                           append(enc.getCTAsPerCGA(), 1),
-                           append(enc.getCTASplitNum(), 1),
-                           appendMinorDim(enc.getCTAOrder())));
+    auto ctx = getContext();
+
+    // Append dim to shape
+    auto ll = toLinearLayout(shape, srcEnc);
+    SmallVector<int64_t> dstShape(shape.begin(), shape.end());
+    dstShape.push_back(1);
+    ll = ll.reshapeOuts(standardOutDimPairs(ctx, dstShape));
+
+    // Try join on last dim
+    auto axis = dstShape.size() - 1;
+    auto newLl = LinearLayout::empty();
+    auto result =
+        tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/true, axis, loc);
+
+    assert(result.succeeded());
+    dstEnc = LinearEncodingAttr::get(ctx, newLl);
     return success();
   }
 
   LogicalResult
   inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
+                       ArrayRef<int64_t> shape,
                        std::optional<Location> loc) const override {
     auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
-    if (!enc) {
-      return emitOptionalError(loc,
-                               "SplitOp can only operate on BlockedEncoding");
+    if (enc) {
+      // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of
+      // shape AxBxC. The input must have 2 elements per thread in the last
+      // dimension, which must be most-minor. The result encoding is the same
+      // as the input, but with the last dimension removed.
+      if (enc.getSizePerThread().back() != 2) {
+        return emitOptionalError(
+            loc, "SplitOp requires 2 elements per thread in the "
+                 "last dimension of the input");
+      }
+      if (enc.getThreadsPerWarp().back() != 1 ||
+          enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) {
+        return emitOptionalError(
+            loc, "SplitOp requires threadsPerWarp, warpsPerCTA, "
+                 "and CTAsPerCGA = 1 for the last dimension of the input");
+      }
+      if (enc.getCTALayout().getCTAsPerCGA().back() != 1) {
+        return emitOptionalError(
+            loc,
+            "SplitOp requires the last dimension to be most-minor in CTAOrder");
+      }
+      SmallVector<unsigned> newOrder(enc.getOrder());
+      int splitDim = newOrder.size() - 1;
+      // Remove splitDim from order.
+      newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim),
+                     newOrder.end());
+      dstEnc = BlockedEncodingAttr::get(
+          enc.getContext(), //
+          ArrayRef(enc.getSizePerThread()).drop_back(1),
+          ArrayRef(enc.getThreadsPerWarp()).drop_back(1),
+          ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder),
+          CTALayoutAttr::get(enc.getContext(), //
+                             ArrayRef(enc.getCTAsPerCGA()).drop_back(1),
+                             ArrayRef(enc.getCTASplitNum()).drop_back(1),
+                             ArrayRef(enc.getCTAOrder()).drop_front(1)));
+      return success();
     }
 
-    // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of
-    // shape AxBxC. The input must have 2 elements per thread in the last
-    // dimension, which must be most-minor. The result encoding is the same as
-    // the input, but with the last dimension removed.
-    if (enc.getSizePerThread().back() != 2) {
-      return emitOptionalError(loc,
-                               "SplitOp requires 2 elements per thread in the "
-                               "last dimension of the input");
-    }
-    if (enc.getThreadsPerWarp().back() != 1 ||
-        enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) {
-      return emitOptionalError(
-          loc, "SplitOp requires threadsPerWarp, warpsPerCTA, "
-               "and CTAsPerCGA = 1 for the last dimension of the input");
+    auto axis = shape.size() - 1;
+    assert(shape[axis] == 2 &&
+           "SplitOp input shape should have 2 in the last dim");
+
+    auto ctx = getContext();
+
+    // Split on last dim
+    auto ll = toLinearLayout(shape, srcEnc);
+    auto newLl = LinearLayout::empty();
+    auto result =
+        tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/false, axis, loc);
+    if (!result.succeeded()) {
+      return failure();
     }
-    if (enc.getCTALayout().getCTAsPerCGA().back() != 1) {
-      return emitOptionalError(
-          loc,
-          "SplitOp requires the last dimension to be most-minor in CTAOrder");
-    }
-    SmallVector<unsigned> newOrder(enc.getOrder());
-    int splitDim = newOrder.size() - 1;
-    // Remove splitDim from order.
-    newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim),
-                   newOrder.end());
-    dstEnc = BlockedEncodingAttr::get(
-        enc.getContext(), //
-        ArrayRef(enc.getSizePerThread()).drop_back(1),
-        ArrayRef(enc.getThreadsPerWarp()).drop_back(1),
-        ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder),
-        CTALayoutAttr::get(enc.getContext(), //
-                           ArrayRef(enc.getCTAsPerCGA()).drop_back(1),
-                           ArrayRef(enc.getCTASplitNum()).drop_back(1),
-                           ArrayRef(enc.getCTAOrder()).drop_front(1)));
+
+    // Remove last dim from newLl (which should be 1)
+    SmallVector<int64_t> dstShape(shape.begin(), shape.end());
+    dstShape.pop_back();
+    newLl = newLl.reshapeOuts(standardOutDimPairs(ctx, dstShape));
+    dstEnc = LinearEncodingAttr::get(ctx, newLl);
     return success();
   }
 
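
For non-blocked encodings, both ops now share the linear-layout path: `inferJoinOpEncoding` reshapes AxBxC to AxBxCx1 and lets the forward `tryJoinOnAxis` double the trailing axis, while `inferSplitOpEncoding` runs the backward direction and drops the resulting unit dimension. A trivial standalone check of that shape bookkeeping (plain C++, no MLIR dependencies):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

Shape joinShape(Shape s) {
  s.push_back(1); // reshapeOuts: AxBxC -> AxBxCx1
  s.back() *= 2;  // forward tryJoinOnAxis on the last dim: -> AxBxCx2
  return s;
}

Shape splitShape(Shape s) {
  assert(s.back() == 2 && "SplitOp input shape should have 2 in the last dim");
  s.back() /= 2; // backward tryJoinOnAxis: -> AxBxCx1
  s.pop_back();  // reshapeOuts drops the unit dim: -> AxBxC
  return s;
}

int main() {
  Shape abc = {4, 8, 16};
  assert((joinShape(abc) == Shape{4, 8, 16, 2}));
  assert(splitShape(joinShape(abc)) == abc); // inverses on shapes
}
```

Note the asymmetry in error handling: join always succeeds (the forward multiply cannot fail, hence the `assert`), while split can legitimately fail when the layout lacks 2 contiguous elements per thread on the last axis.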
@@ -2873,37 +2953,10 @@ struct TritonGPUInferLayoutInterface
     }
 
     auto ll = toLinearLayout(shape, inEnc);
-
-    auto kRegister = StringAttr::get(ctx, "register");
-    auto outDims = llvm::to_vector(ll.getOutDimNames());
-    LinearLayout newLl = LinearLayout::empty();
-    if (fwdInference) {
-      auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
-      newLl = split * ll;
-    } else {
-      // TODO This requires a division algorithm!
-      // Implement manually ll.divideLeft(split)
-      auto contiguousElems =
-          LinearEncodingAttr::get(ctx, ll).getContigPerThread();
-      if (contiguousElems[axis] > 1) {
-        LinearLayout::BasesT newBases;
-        for (const auto &basesDim : ll.getBases()) {
-          std::vector<std::vector<int32_t>> newBasesDim;
-          for (auto base : basesDim.second) {
-            if (base[axis] == 1) {
-              continue;
-            }
-            base[axis] /= 2;
-            newBasesDim.push_back(std::move(base));
-          }
-          newBases.insert({basesDim.first, std::move(newBasesDim)});
-        }
-        newLl = LinearLayout(std::move(newBases), std::move(outDims));
-      } else {
-        return emitOptionalError(loc, "Fp4ToFpOp requires at least 2 elements "
-                                      "per thread in the axis dimension");
-      }
-    }
+    auto newLl = LinearLayout::empty();
+    auto result = tryJoinOnAxis(ctx, ll, newLl, fwdInference, axis, loc);
+    if (!result.succeeded())
+      return result;
     outEnc = LinearEncodingAttr::get(ctx, newLl);
     return success();
   }
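
With the old inline block gone, the Fp4ToFp inference and the join/split paths all share one helper. Under the reading that `split * ll` prepends a fresh lowest register bit along `axis` and shifts the existing components of that axis up by one bit (an assumption about `LinearLayout::operator*`, not stated in this diff), the two directions are inverses whenever the layout keeps at least 2 contiguous elements per thread on that axis. A standalone round-trip sketch of that model, again tracking only one input dimension's bases:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Bases = std::vector<std::vector<int32_t>>;

// Model of fwdInference: new most-minor bit along `axis`, old bits shift up.
Bases fwd(Bases bases, int axis) {
  assert(!bases.empty());
  for (auto &b : bases)
    b[axis] *= 2;
  std::vector<int32_t> unit(bases[0].size(), 0);
  unit[axis] = 1;
  bases.insert(bases.begin(), unit);
  return bases;
}

// Model of the backward division: drop the lowest bit, halve the rest.
Bases bwd(const Bases &bases, int axis) {
  Bases out;
  for (auto b : bases) {
    if (b[axis] == 1)
      continue;
    b[axis] /= 2;
    out.push_back(std::move(b));
  }
  return out;
}

int main() {
  Bases regs = {{1, 0}, {0, 1}};
  assert(bwd(fwd(regs, 0), 0) == regs); // round-trip on axis 0
  assert(bwd(fwd(regs, 1), 1) == regs); // round-trip on axis 1
}
```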