@@ -1441,7 +1441,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
 
   SmallVector<unsigned> ret(rank, 1);
   auto nonZero = [](auto val) { return val != 0; };
-  int nonZeroIdx = 0;
+  int nonZeroIdx = -1;
   for (const auto &basis : bases) {
     auto it = std::find_if(basis.begin(), basis.end(), nonZero);
     // Bases can have one or zero non-zero elements
@@ -1453,6 +1453,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
     } else if (!skipBroadcast) {
       // If we've seen a non-zero basis, we double the size of the previous dim
       // This is just needed to count the CTAsPerCGA
+      assert(nonZeroIdx != -1);
       ret[nonZeroIdx] *= 2;
     }
   }
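
Why the sentinel moved from 0 to -1: a broadcast (all-zero) basis that arrives before any non-zero basis has no "previous dim" to double, and with nonZeroIdx = 0 the loop would silently double dim 0 instead of failing. The following standalone sketch reconstructs the whole loop, including the if-branch elided from the hunk above; it is my reconstruction under that assumption, not the Triton helper itself:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    std::vector<unsigned> basesPerDimSketch(
        const std::vector<std::vector<int>> &bases, unsigned rank,
        bool skipBroadcast = true) {
      std::vector<unsigned> ret(rank, 1);
      auto nonZero = [](int v) { return v != 0; };
      int nonZeroIdx = -1;
      for (const auto &basis : bases) {
        auto it = std::find_if(basis.begin(), basis.end(), nonZero);
        if (it != basis.end()) {
          // Bases can have one or zero non-zero elements.
          nonZeroIdx = it - basis.begin();
          ret[nonZeroIdx] *= 2;
        } else if (!skipBroadcast) {
          // A broadcast basis doubles the previous non-zero dim; the assert
          // catches a broadcast basis arriving before any non-zero one.
          assert(nonZeroIdx != -1 && "broadcast basis before any non-zero one");
          ret[nonZeroIdx] *= 2;
        }
      }
      return ret;
    }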
@@ -1597,14 +1598,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
 SmallVector<unsigned>
 LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
-  // When broadcasting the layout the shape changes, otherwise the shape is
-  // the same as the shape of the tensor
-  // We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep
-  // the invariant that the shape of the LL is that of the tensor
-  // We choose the former for BC
-  auto ll = *toLinearLayout(shape);
-  return basesPerDim(ll, StringAttr::get(getContext(), "register"),
-                     /*skipBroadcast=*/false);
+  // We can relax this assert by calling toLinearLayout rather than
+  // getLinearLayout
+  SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
+  assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
+  auto ll = getLinearLayout();
+  return basesPerDim(ll, StringAttr::get(getContext(), "register"));
 }
 
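
The rewritten getElemsPerThread trades the broadcast-tolerant toLinearLayout(shape) call for the stored layout plus an assert. A pure-C++ analogue of the new contract (all names here are illustrative, not Triton's API):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct LinearEncodingSketch {
      std::vector<int32_t> storedOutDimSizes; // shape the stored layout covers

      unsigned elemsPerThread(const std::vector<int32_t> &shape) const {
        // The stored layout is only meaningful for its own shape; accepting
        // other shapes would mean rebuilding it (the toLinearLayout path the
        // comment above mentions as a possible relaxation).
        assert(shape == storedOutDimSizes && "shape must match stored layout");
        return 1; // stand-in for the basesPerDim("register") computation
      }
    };

    int main() {
      LinearEncodingSketch enc{{32, 16}};
      (void)enc.elemsPerThread({32, 16}); // OK: shapes match
      // enc.elemsPerThread({64, 16});    // would trip the assert
    }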
@@ -2674,8 +2673,8 @@ struct TritonGPUInferLayoutInterface
   // contains elements [a,b,c,d] before the reshape, it contains those same
   // elements after the reshape, they're just "renamed".
   //
-  // Using legacy layouts, a dst encoding that satisfies this property may not
-  // exist. Here are some positive and negative examples.
+  // A dst encoding that satisfies this property does not exist for all inputs.
+  // Here are some positive and negative examples.
   //
   // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
   //   dim 1 is the fastest-changing in the dst, but the src has the opposite
@@ -2689,19 +2688,17 @@ struct TritonGPUInferLayoutInterface
   // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
   //   contain the same elements as before.
   //
-  // With linear layouts, we can always find a dst encoding that satisfies
-  // this property. See inferReshapeOpEncoding.
-  //
   // Users of this function require that it is symmetrical: if
   // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
   // srcEnc.
-  LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
-                                             Attribute srcEnc,
-                                             ArrayRef<int64_t> dstShape,
-                                             Attribute &dstEnc) const {
+  LogicalResult
+  inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
+                                  ArrayRef<int64_t> dstShape, Attribute &dstEnc,
+                                  std::optional<Location> loc) const override {
     auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
     if (!src) {
-      return failure();
+      return emitOptionalError(
+          loc, "Non-reordering reshape only supports BlockedEncoding");
     }
 
     // Nop reshape; we can always infer an encoding.
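
The NOT OK case from the comment above can be checked numerically. A minimal standalone sketch (my helper, not Triton code) linearizes an index under a given physical order; the src's fastest dim and the flattened dst's fastest dim disagree, so the reshape would have to move elements between registers rather than merely rename them:

    #include <cstdio>
    #include <vector>

    // Linear offset of logical index `idx` when dims are traversed with
    // `order` (order[0] = fastest-changing dim).
    int offsetInOrder(const std::vector<int> &idx, const std::vector<int> &shape,
                      const std::vector<int> &order) {
      int off = 0, stride = 1;
      for (int dim : order) {
        off += idx[dim] * stride;
        stride *= shape[dim];
      }
      return off;
    }

    int main() {
      std::vector<int> shape{4, 4};
      // src 4x4 with order=[0,1]: dim 0 fastest, so (1,0) sits at offset 1.
      printf("src offset of (1,0): %d\n", offsetInOrder({1, 0}, shape, {0, 1}));
      // Flattening to 16 is row-major (dim 1 fastest): (1,0) sits at offset 4.
      printf("dst offset of (1,0): %d\n", offsetInOrder({1, 0}, shape, {1, 0}));
      // The two element orders disagree: the NOT OK case above.
    }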
@@ -2734,7 +2731,9 @@ struct TritonGPUInferLayoutInterface
     // to handle CTASplitNum.
     if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
         !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
-      return failure();
+      return emitOptionalError(
+          loc, "Non-reordering reshape does not currently support multi-CTA "
+               "layouts other than the default layout.");
     }
 
     // Cowardly refuse to handle encodings where shape[dim] is not divisible by
@@ -2744,7 +2743,12 @@ struct TritonGPUInferLayoutInterface
       for (int dim = 0; dim < srcShape.size(); dim++) {
         if (srcShape[dim] >= subblock[dim] &&
             srcShape[dim] % subblock[dim] != 0) {
-          return failure();
+          return emitOptionalError(loc,
+                                   "Can't do a non-reordering reshape because "
+                                   "the size of dimension ",
+                                   dim, " (", srcShape[dim], ")",
+                                   " is not divisible by ", name, "[", dim, "]",
+                                   " = ", subblock[dim]);
         }
       }
       return success();
@@ -2769,7 +2773,11 @@ struct TritonGPUInferLayoutInterface
     // physical order, with `a` being the most major.
     for (const auto &[srcDims, dstDims] : decomp) {
       if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
-        return failure();
+        return emitOptionalError(loc,
+                                 "Cannot do a non-reordering reshape given "
+                                 "this src encoding order. Dimensions [",
+                                 join(srcDims),
+                                 "] must be physically consecutive.");
       }
     }
 
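
A standalone sketch of this consecutivity test (the helper and the encoding of invOrder are my assumptions): srcInvOrder maps a logical dim to its physical position, with 0 the fastest-changing; a group of src dims can merge into one dst dim only if those positions, read minor-to-major, are consecutive:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    bool groupIsPhysicallyConsecutive(const std::vector<int> &invOrder,
                                      const std::vector<int> &groupDims) {
      std::vector<int> pos;
      for (int d : groupDims) // groupDims listed most-major first
        pos.push_back(invOrder[d]);
      std::reverse(pos.begin(), pos.end()); // now most-minor first
      for (size_t i = 1; i < pos.size(); ++i)
        if (pos[i] != pos[i - 1] + 1)
          return false;
      return true;
    }

    int main() {
      // order=[1,0] (dim 1 fastest): invOrder[0]=1, invOrder[1]=0. Merging
      // dims {0,1} of a 4x4 into 16 gives positions [0,1] minor-to-major:
      // consecutive, so the merge is representable.
      printf("%d\n", groupIsPhysicallyConsecutive({1, 0}, {0, 1})); // 1
      // order=[0,1] (dim 0 fastest) gives positions [1,0]: not consecutive,
      // matching the NOT OK example earlier in the comment.
      printf("%d\n", groupIsPhysicallyConsecutive({0, 1}, {0, 1})); // 0
    }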
@@ -2816,7 +2824,11 @@ struct TritonGPUInferLayoutInterface
       // Check that more-minor dims all have 1 in shapeRemaining.
       for (int j = i + 1; j < srcDims.size(); j++) {
         if (shapeRemaining[j] != 1) {
-          return failure();
+          return emitOptionalError(
+              loc,
+              "Invalid src encoding for non-reordering reshape. Must use "
+              "up sizePerThread / threadsPerWarp / warpsPerCTA for "
+              "more-minor dimensions before more-major dims can use them.");
         }
       }
 
@@ -2831,7 +2843,13 @@ struct TritonGPUInferLayoutInterface
         // only if we're the most-major dimension of the chunk and in all
         // future chunks, only this most-major dim has a non-1 size.
         if (shapeRemaining[i] == 0 && i != 0) {
-          return failure();
+          return emitOptionalError(
+              loc,
+              "Invalid src encoding for non-reordering reshape. Block "
+              "size in dimension ",
+              dim,
+              " is larger than the shape of that dimension, but this is only "
+              "allowed for the most-major dimension of a reshape chunk");
         }
       }
       return success();
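
These last two checks implement one rule: within a reshape chunk, a block (sizePerThread, threadsPerWarp, or warpsPerCTA) must consume the shape strictly minor-to-major, and may overrun a dimension's shape only in the chunk's most-major dimension. A standalone sketch of that rule (names and structure are mine, not the actual implementation):

    #include <cstdio>
    #include <vector>

    bool blockFitsChunk(const std::vector<long> &shape,   // most-major first
                        const std::vector<long> &block) { // same dim order
      std::vector<long> remaining(shape);
      for (int i = shape.size() - 1; i >= 0; --i) { // walk minor-to-major
        if (block[i] == 1)
          continue;
        // A more-minor dim with leftover shape means this block interleaves.
        for (int j = i + 1; j < (int)shape.size(); ++j)
          if (remaining[j] != 1)
            return false;
        if (block[i] <= remaining[i]) {
          if (remaining[i] % block[i] != 0)
            return false;
          remaining[i] /= block[i];
        } else {
          remaining[i] = 0; // block larger than shape: most-major dim only
          if (i != 0)
            return false;
        }
      }
      return true;
    }

    int main() {
      // The OK example above: 32x4 sizePerThread=[4,4]. Dim 1 is fully used
      // up (4/4 == 1) before dim 0 takes its 4, so reshape to 128 works.
      printf("%d\n", blockFitsChunk({32, 4}, {4, 4})); // 1
      // A constructed counterexample: 2x2 with block [2,1]. Dim 0 takes a
      // block while dim 1 still has leftover shape, so elements interleave.
      printf("%d\n", blockFitsChunk({2, 2}, {2, 1})); // 0
    }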
@@ -2921,65 +2939,6 @@ struct TritonGPUInferLayoutInterface
     return success();
   }
 
-  LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
-                                      Attribute expected, Attribute got,
-                                      Location loc) const override {
-    if (expected == got) {
-      return success();
-    }
-    // Check whether the encodings are structurally the same.
-    auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
-    auto gotLL = triton::gpu::toLinearLayout(shape, got);
-    if (expectedLL != gotLL) {
-      return emitError(loc, "Expected result encoding ")
-             << expected << " but was " << got;
-    }
-    return success();
-  }
-
-  LogicalResult
-  inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
-                         ArrayRef<int64_t> dstShape, Attribute &dstEnc,
-                         std::optional<Location> loc) const override {
-    auto result =
-        inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
-    if (succeeded(result)) {
-      return result;
-    }
-
-    // If the legacy encoding failed use LinearLayouts.
-    // Once LinearLayouts are more widely used, we can remove
-    // inferReshapeOpLegacyEncoding and simply use LLs.
-    auto *ctx = getContext();
-    auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
-    if (!src) {
-      return emitOptionalError(loc,
-                               "src encoding does not support linear layout");
-    }
-
-    if (product(srcShape) != product(dstShape)) {
-      return emitOptionalError(loc, "numel of dst shape does not match "
-                                    "numel of src shape");
-    }
-
-    auto newRank = dstShape.size();
-    SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
-    for (auto [dim, size] :
-         llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
-      newOutDims.emplace_back(dim, size);
-    }
-    auto srcOutDims = llvm::to_vector(src->getOutDimNames());
-    // reshapeOp assumes minor-to-major, so we need to transpose the out dims
-    // before the reshape
-    std::reverse(srcOutDims.begin(), srcOutDims.end());
-    std::reverse(newOutDims.begin(), newOutDims.end());
-    auto dst = src->transposeOuts(srcOutDims)
-                   .reshapeOuts(newOutDims)
-                   .transposeOuts(standardOutDimNames(ctx, newRank));
-    dstEnc = LinearEncodingAttr::get(ctx, dst);
-    return success();
-  }
-
   LogicalResult
   inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                       std::optional<Location> loc) const override {