@@ -1630,11 +1630,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
 SmallVector<unsigned>
 LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
-  // We can relax this assert by calling toLinearLayout rather than
-  // getLinearLayout
-  SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
-  assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
-  auto ll = getLinearLayout();
+  // When broadcasting the layout, the shape changes; otherwise it is the same
+  // as the shape of the tensor. We can either give BroadcastOp the
+  // SameOperandsAndResultEncoding trait, or keep the invariant that the shape
+  // of the LL is that of the tensor. We choose the former, for backwards
+  // compatibility.
+  auto ll = *toLinearLayout(shape);
   return basesPerDim(ll, StringAttr::get(getContext(), "register"));
 }
 
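For intuition, a hedged sketch of the behavioral change; the encoding, shape, and helper below are invented for illustration, while toLinearLayout and basesPerDim are the calls used in the hunk above. The old code asserted that the queried shape equaled the layout's own out-dim sizes; the new code first adapts the layout to the queried shape, which is exactly what a broadcasted tensor needs:

// Sketch only: `enc` wraps a linear layout built for a 32x32 tensor and is
// queried at a broadcasted 64x32 shape.
LinearEncodingAttr enc = makeSomeLinearEncoding(); // hypothetical helper
SmallVector<int64_t> broadcastShape = {64, 32};
// Before: assert(shapeVec == getLinearLayout().getOutDimSizes()) would fire.
// After: toLinearLayout(shape) broadcasts the layout to the queried shape.
auto ll = *enc.toLinearLayout(broadcastShape);
auto elems = basesPerDim(ll, StringAttr::get(enc.getContext(), "register"));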
@@ -2658,8 +2659,8 @@ struct TritonGPUInferLayoutInterface
   // contains elements [a,b,c,d] before the reshape, it contains those same
   // elements after the reshape, they're just "renamed".
   //
-  // A dst encoding that satisfies this property does not exist for all inputs.
-  // Here are some positive and negative examples.
+  // Using legacy layouts, a dst encoding that satisfies this property may not
+  // exist. Here are some positive and negative examples.
   //
   // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
   //   dim 1 is the fastest-changing in the dst, but the src has the opposite
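To make the "NOT OK" case concrete (my expansion of the patch's example, with the per-thread register count chosen for illustration): with order=[0,1], dim 0 is fastest-changing, so a thread whose registers cover four consecutive layout elements holds (0,0), (1,0), (2,0), (3,0) of the 4x4 tensor. A row-major reshape to 16 sends element (i,j) to linear index 4*i+j, so those registers would have to hold indices 0, 4, 8, 12, a strided pattern that no 1-D BlockedEncoding can express.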
@@ -2673,17 +2674,19 @@ struct TritonGPUInferLayoutInterface
   // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
   //   contain the same elements as before.
   //
+  // With linear layouts, we can always find a dst encoding that satisfies
+  // this property. See inferReshapeOpEncoding.
+  //
   // Users of this function require that it is symmetrical: if
   // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
   // srcEnc.
-  LogicalResult
-  inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
-                                  ArrayRef<int64_t> dstShape, Attribute &dstEnc,
-                                  std::optional<Location> loc) const override {
+  LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
+                                             Attribute srcEnc,
+                                             ArrayRef<int64_t> dstShape,
+                                             Attribute &dstEnc) const {
     auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
     if (!src) {
-      return emitOptionalError(
-          loc, "Non-reordering reshape only supports BlockedEncoding");
+      return failure();
     }
 
     // Nop reshape; we can always infer an encoding.
@@ -2716,9 +2719,7 @@ struct TritonGPUInferLayoutInterface
     // to handle CTASplitNum.
     if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
         !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
-      return emitOptionalError(
-          loc, "Non-reordering reshape does not currently support multi-CTA "
-               "layouts other than the default layout.");
+      return failure();
     }
 
     // Cowardly refuse to handle encodings where shape[dim] is not divisible by
@@ -2728,12 +2729,7 @@ struct TritonGPUInferLayoutInterface
     for (int dim = 0; dim < srcShape.size(); dim++) {
       if (srcShape[dim] >= subblock[dim] &&
          srcShape[dim] % subblock[dim] != 0) {
-        return emitOptionalError(loc,
-                                 "Can't do a non-reordering reshape because "
-                                 "the size of dimension ",
-                                 dim, " (", srcShape[dim], ")",
-                                 " is not divisible by ", name, "[", dim, "]",
-                                 " = ", subblock[dim]);
+        return failure();
       }
     }
     return success();
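As a concrete instance of this check (numbers invented): with srcShape = [6, 8] and sizePerThread = [4, 1], dimension 0 has 6 >= 4 but 6 % 4 != 0, so the reshape is rejected; after this patch the caller sees a bare failure() where it previously got the descriptive diagnostic.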
@@ -2758,11 +2754,7 @@ struct TritonGPUInferLayoutInterface
     // physical order, with `a` being the most major.
     for (const auto &[srcDims, dstDims] : decomp) {
       if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
-        return emitOptionalError(loc,
-                                 "Cannot do a non-reordering reshape given "
-                                 "this src encoding order. Dimensions [",
-                                 join(srcDims),
-                                 "] must be physically consecutive.");
+        return failure();
       }
     }
 
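The check above is compact, so here is a standalone sketch of what it computes: plain C++ with toy data, my own reconstruction of the same logic rather than the real helpers. It assumes srcInvOrder[d] is the physical position of logical dim d, with 0 the most minor.

#include <algorithm>
#include <cstdio>
#include <vector>

// Reconstruction of:
//   isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))
static bool physicallyConsecutive(const std::vector<int> &srcInvOrder,
                                  const std::vector<int> &srcDims) {
  std::vector<int> pos;
  for (int d : srcDims)                 // gather(srcInvOrder, srcDims)
    pos.push_back(srcInvOrder[d]);
  // srcDims lists the chunk most-major first, so a valid chunk has strictly
  // descending positions; reversing should yield an ascending run by 1.
  std::reverse(pos.begin(), pos.end());
  for (size_t i = 1; i < pos.size(); i++)
    if (pos[i] != pos[i - 1] + 1)       // isConsecutive
      return false;
  return true;
}

int main() {
  // order=[2,1,0]: dim 2 is most minor, so srcInvOrder = {2,1,0}. Merging
  // dims {0,1} is fine: positions {2,1} reversed give {1,2}. Prints 1.
  std::printf("%d\n", physicallyConsecutive({2, 1, 0}, {0, 1}));
  // order=[0,1]: srcInvOrder = {0,1}. Merging {0,1} gives positions {0,1}
  // reversed to {1,0}, not consecutive: the NOT OK case above. Prints 0.
  std::printf("%d\n", physicallyConsecutive({0, 1}, {0, 1}));
}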
@@ -2809,11 +2801,7 @@ struct TritonGPUInferLayoutInterface
       // Check that more-minor dims all have 1 in shapeRemaining.
       for (int j = i + 1; j < srcDims.size(); j++) {
         if (shapeRemaining[j] != 1) {
-          return emitOptionalError(
-              loc,
-              "Invalid src encoding for non-reordering reshape. Must use "
-              "up sizePerThread / threadsPerWarp / warpsPerCTA for "
-              "more-minor dimensions before more major-dims can use them.");
+          return failure();
         }
       }
 
@@ -2828,13 +2816,7 @@ struct TritonGPUInferLayoutInterface
       // only if we're the most-major dimension of the chunk and in all
      // future chunks, only this most-major dim has a non-1 size.
       if (shapeRemaining[i] == 0 && i != 0) {
-        return emitOptionalError(
-            loc,
-            "Invalid src encoding for non-reordering reshape. Block "
-            "size in dimension ",
-            dim,
-            " is larger than the shape that dimension, but this is only "
-            "allowed for the most-major dimension of a reshape chunk");
+        return failure();
       }
     }
     return success();
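A rough illustration of this rule, under my reading and with invented numbers: reshaping 4x4 -> 16 with warpsPerCTA=[8,1] is acceptable, since the excess warps sit on dim 0, the most-major dim of the single reshape chunk. With warpsPerCTA=[1,8], the excess sits on the minor dim 1, shapeRemaining[1] hits 0 with i != 0, and the reshape is rejected.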
@@ -2924,6 +2906,65 @@ struct TritonGPUInferLayoutInterface
     return success();
   }
 
+  LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
+                                      Attribute expected, Attribute got,
+                                      Location loc) const override {
+    if (expected == got) {
+      return success();
+    }
+    // Check whether the encodings are structurally the same.
+    auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
+    auto gotLL = triton::gpu::toLinearLayout(shape, got);
+    if (expectedLL != gotLL) {
+      return emitError(loc, "Expected result encoding ")
+             << expected << " but was " << got;
+    }
+    return success();
+  }
+
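A hedged note on what this hook buys; the names blocked, linear, and shape below are invented for the example, while triton::gpu::toLinearLayout is the call used above. The verifier can now accept a result whose encoding differs as an attribute from the inferred one, as long as both denote the same linear layout, e.g. a BlockedEncodingAttr versus the structurally identical LinearEncodingAttr produced by inferReshapeOpEncoding:

// Sketch: `blocked` and `linear` are different Attributes, so the cheap
// `expected == got` test fails, but their linear layouts coincide, so
// verifyLayoutsAreEqual returns success() rather than a diagnostic.
auto blockedLL = triton::gpu::toLinearLayout(shape, blocked);
auto linearLL = triton::gpu::toLinearLayout(shape, linear);
assert(blockedLL == linearLL); // structural equality is what is compared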
+  LogicalResult
+  inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
+                         ArrayRef<int64_t> dstShape, Attribute &dstEnc,
+                         std::optional<Location> loc) const override {
+    auto result =
+        inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
+    if (succeeded(result)) {
+      return result;
+    }
+
+    // If the legacy encoding failed, use LinearLayouts.
+    // Once LinearLayouts are more widely used, we can remove
+    // inferReshapeOpLegacyEncoding and simply use LLs.
+    auto *ctx = getContext();
+    auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
+    if (!src) {
+      return emitOptionalError(loc,
+                               "src encoding does not support linear layout");
+    }
+
+    if (product(srcShape) != product(dstShape)) {
+      return emitOptionalError(loc, "numel of dst shape does not match "
+                                    "numel of src shape");
+    }
+
+    auto newRank = dstShape.size();
+    SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
+    for (auto [dim, size] :
+         llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
+      newOutDims.emplace_back(dim, size);
+    }
+    auto srcOutDims = llvm::to_vector(src->getOutDimNames());
+    // ReshapeOp assumes minor-to-major, so we need to transpose the out dims
+    // before the reshape.
+    std::reverse(srcOutDims.begin(), srcOutDims.end());
+    std::reverse(newOutDims.begin(), newOutDims.end());
+    auto dst = src->transposeOuts(srcOutDims)
+                   .reshapeOuts(newOutDims)
+                   .transposeOuts(standardOutDimNames(ctx, newRank));
+    dstEnc = LinearEncodingAttr::get(ctx, dst);
+    return success();
+  }
+
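To make the double reversal concrete, a worked example under my reading of the comment above: for a 4x8 -> 32 reshape, ReshapeOp's row-major semantics send element (i, j) to linear index 8*i + j, i.e. dim1 varies fastest. Since reshapeOuts consumes out dims minor-to-major (first-listed dim fastest), the src out dims must be presented as [dim1, dim0] for the flattening to realize 8*i + j; symmetrically, a multi-dim dst's out dims are reversed before the reshape, and the final transposeOuts restores the standard dim0, dim1, ... order that the rest of the code expects.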
   LogicalResult
   inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                       std::optional<Location> loc) const override {