@@ -1598,11 +1598,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
 SmallVector<unsigned>
 LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
-  // We can relax this assert by calling toLinearLayout rather than
-  // getLinearLayout
-  SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
-  assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
-  auto ll = getLinearLayout();
+  // When broadcasting the layout, the shape changes; otherwise it is the
+  // same as the shape of the tensor.
+  // We can either have BroadcastOp with SameOperandsAndResultEncoding, or
+  // keep the invariant that the shape of the LL is that of the tensor.
+  // We choose the former for backwards compatibility.
+  auto ll = *toLinearLayout(shape);
   return basesPerDim(ll, StringAttr::get(getContext(), "register"));
 }
 
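A hedged usage sketch of the relaxed path (the shapes, `f32Ty`, and `enc` here are hypothetical): under BroadcastOp with SameOperandsAndResultEncoding, the broadcast result reuses the operand's LinearEncodingAttr, so getElemsPerThread may now be queried with a shape larger than the layout's native out-dim sizes, and toLinearLayout(shape) materializes a layout for that shape instead of tripping the old assert.

    // Hypothetical: `enc` was built for a 1x128 tensor; the broadcast result
    // is 64x128 but keeps the same encoding attribute.
    RankedTensorType srcTy = RankedTensorType::get({1, 128}, f32Ty, enc);
    RankedTensorType dstTy = RankedTensorType::get({64, 128}, f32Ty, enc);
    // Previously this asserted, since {64, 128} does not match the layout's
    // stored out-dim sizes; now the layout is broadcast to the queried shape.
    SmallVector<unsigned> elems = enc.getElemsPerThread(dstTy.getShape(), f32Ty);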
@@ -2623,8 +2624,8 @@ struct TritonGPUInferLayoutInterface
   // contains elements [a,b,c,d] before the reshape, it contains those same
   // elements after the reshape, they're just "renamed".
   //
-  // A dst encoding that satisfies this property does not exist for all inputs.
-  // Here are some positive and negative examples.
+  // Using legacy layouts, a dst encoding that satisfies this property may not
+  // exist. Here are some positive and negative examples.
   //
   // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
   //   dim 1 is the fastest-changing in the dst, but the src has the opposite
@@ -2638,17 +2639,19 @@ struct TritonGPUInferLayoutInterface
   // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
   //   contain the same elements as before.
   //
+  // With linear layouts, we can always find a dst encoding that satisfies
+  // this property. See inferReshapeOpEncoding.
+  //
   // Users of this function require that it is symmetrical: if
   // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
   // srcEnc.
-  LogicalResult
-  inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
-                                  ArrayRef<int64_t> dstShape, Attribute &dstEnc,
-                                  std::optional<Location> loc) const override {
+  LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
+                                             Attribute srcEnc,
+                                             ArrayRef<int64_t> dstShape,
+                                             Attribute &dstEnc) const {
     auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
     if (!src) {
-      return emitOptionalError(
-          loc, "Non-reordering reshape only supports BlockedEncoding");
+      return failure();
     }
 
     // Nop reshape; we can always infer an encoding.
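A standalone arithmetic sketch of the NOT OK 4x4 order=[0,1] -> 16 case above (plain C++, independent of the MLIR utilities): with order=[0,1], dim 0 is fastest-changing across a thread's elements, while the row-major reshape makes dim 1 fastest, so the flattened elements land in a different order than the one the registers traverse.

    #include <cstdio>
    int main() {
      // order = [0, 1]: element slot k holds index (i0, i1) = (k % 4, k / 4).
      // The row-major reshape to 16 maps (i0, i1) to i0 * 4 + i1.
      for (int k = 0; k < 16; ++k) {
        int i0 = k % 4, i1 = k / 4;
        std::printf("slot %2d holds flattened element %2d\n", k, i0 * 4 + i1);
      }
      // slot 1 holds element 4, slot 2 holds element 8, ...: the mapping is a
      // transpose rather than a renaming, so no blocked dst encoding works.
    }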
@@ -2681,9 +2684,7 @@ struct TritonGPUInferLayoutInterface
   // to handle CTASplitNum.
   if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
       !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
-    return emitOptionalError(
-        loc, "Non-reordering reshape does not currently support multi-CTA "
-             "layouts other than the default layout.");
+    return failure();
   }
 
 
@@ -2693,12 +2694,7 @@ struct TritonGPUInferLayoutInterface
     for (int dim = 0; dim < srcShape.size(); dim++) {
       if (srcShape[dim] >= subblock[dim] &&
           srcShape[dim] % subblock[dim] != 0) {
-        return emitOptionalError(loc,
-                                 "Can't do a non-reordering reshape because "
-                                 "the size of dimension ",
-                                 dim, " (", srcShape[dim], ")",
-                                 " is not divisible by ", name, "[", dim, "]",
-                                 " = ", subblock[dim]);
+        return failure();
       }
     }
     return success();
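A small numeric illustration of the guard above (the shape and subblock values are made up):

    #include <cstdio>
    int main() {
      int srcShape[2] = {6, 8}, subblock[2] = {4, 2}; // hypothetical values
      for (int dim = 0; dim < 2; ++dim) {
        bool reject = srcShape[dim] >= subblock[dim] &&
                      srcShape[dim] % subblock[dim] != 0;
        std::printf("dim %d: %s\n", dim, reject ? "failure" : "ok");
      }
      // dim 0: failure (6 is not divisible by 4); dim 1: ok. A shape smaller
      // than the subblock is not rejected here; that case is handled by the
      // most-major-dimension rule further down.
    }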
@@ -2723,11 +2719,7 @@ struct TritonGPUInferLayoutInterface
     // physical order, with `a` being the most major.
     for (const auto &[srcDims, dstDims] : decomp) {
       if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
-        return emitOptionalError(loc,
-                                 "Cannot do a non-reordering reshape given "
-                                 "this src encoding order. Dimensions [",
-                                 join(srcDims),
-                                 "] must be physically consecutive.");
+        return failure();
       }
     }
 
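The consecutiveness test can be re-sketched standalone (local stand-ins for gather/reverse/isConsecutive; the orders are hypothetical). order lists dimensions fastest-first, and srcInvOrder maps each logical dim to its physical position, so merging dims [a, b] with a most major requires b to sit immediately below a physically:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    static bool isConsecutive(std::vector<int> v) {
      for (size_t i = 1; i < v.size(); ++i)
        if (v[i] != v[i - 1] + 1)
          return false;
      return true;
    }

    int main() {
      // order = {1, 0} (dim 1 fastest) gives srcInvOrder = {1, 0}.
      std::vector<int> srcInvOrder = {1, 0};
      std::vector<int> srcDims = {0, 1}; // merge dims 0 and 1, dim 0 most major
      std::vector<int> phys;             // gather(srcInvOrder, srcDims)
      for (int d : srcDims)
        phys.push_back(srcInvOrder[d]);
      std::reverse(phys.begin(), phys.end()); // {0, 1}: consecutive, mergeable
      std::printf("mergeable: %d\n", isConsecutive(phys));
      // With order = {0, 1} instead, the reversed gather is {1, 0}: not
      // consecutive, which is exactly the NOT OK 4x4 example above.
    }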
@@ -2774,11 +2766,7 @@ struct TritonGPUInferLayoutInterface
       // Check that more-minor dims all have 1 in shapeRemaining.
       for (int j = i + 1; j < srcDims.size(); j++) {
         if (shapeRemaining[j] != 1) {
-          return emitOptionalError(
-              loc,
-              "Invalid src encoding for non-reordering reshape. Must use "
-              "up sizePerThread / threadsPerWarp / warpsPerCTA for "
-              "more-minor dimensions before more major-dims can use them.");
+          return failure();
         }
       }
 
@@ -2793,13 +2781,7 @@ struct TritonGPUInferLayoutInterface
       // only if we're the most-major dimension of the chunk and in all
       // future chunks, only this most-major dim has a non-1 size.
       if (shapeRemaining[i] == 0 && i != 0) {
-        return emitOptionalError(
-            loc,
-            "Invalid src encoding for non-reordering reshape. Block "
-            "size in dimension ",
-            dim,
-            " is larger than the shape of that dimension, but this is only "
-            "allowed for the most-major dimension of a reshape chunk");
+        return failure();
       }
     }
     return success();
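The last two guards combine into a per-chunk rule; a hedged standalone re-sketch (a hypothetical helper that walks one chunk's dims minor-to-major, not the actual implementation):

    #include <cstdio>
    #include <vector>

    // Can one reshape chunk with remaining shape `rem` (dims major-to-minor)
    // consume `sub` (e.g. sizePerThread)?
    static bool chunkOk(std::vector<int> rem, const std::vector<int> &sub) {
      int n = (int)rem.size();
      for (int i = n - 1; i >= 0; --i) { // minor-to-major
        if (sub[i] == 1)
          continue;
        for (int j = i + 1; j < n; ++j)
          if (rem[j] != 1)
            return false; // more-minor dims must be used up first
        rem[i] = sub[i] <= rem[i] ? rem[i] / sub[i] : 0;
        if (rem[i] == 0 && i != 0)
          return false; // oversubscription only OK in the most-major dim
      }
      return true;
    }

    int main() {
      // A 4x8 chunk: {1, 8} is fine; {2, 4} fails (dim 1 not used up before
      // dim 0 consumes registers); {8, 8} is fine because only the most-major
      // dim oversubscribes.
      std::printf("%d %d %d\n", chunkOk({4, 8}, {1, 8}),
                  chunkOk({4, 8}, {2, 4}), chunkOk({4, 8}, {8, 8}));
    }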
@@ -2889,6 +2871,65 @@ struct TritonGPUInferLayoutInterface
     return success();
   }
 
+  LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
+                                      Attribute expected, Attribute got,
+                                      Location loc) const override {
+    if (expected == got) {
+      return success();
+    }
+    // Check whether the encodings are structurally the same.
+    auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
+    auto gotLL = triton::gpu::toLinearLayout(shape, got);
+    if (expectedLL != gotLL) {
+      return emitError(loc, "Expected result encoding ")
+             << expected << " but was " << got;
+    }
+    return success();
+  }
+
+  LogicalResult
+  inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
+                         ArrayRef<int64_t> dstShape, Attribute &dstEnc,
+                         std::optional<Location> loc) const override {
+    auto result =
+        inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
+    if (succeeded(result)) {
+      return result;
+    }
+
+    // If the legacy encoding failed, use LinearLayouts.
+    // Once LinearLayouts are more widely used, we can remove
+    // inferReshapeOpLegacyEncoding and simply use LLs.
+    auto *ctx = getContext();
+    auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
+    if (!src) {
+      return emitOptionalError(loc,
+                               "src encoding does not support linear layout");
+    }
+
+    if (product(srcShape) != product(dstShape)) {
+      return emitOptionalError(loc, "numel of dst shape does not match "
+                                    "numel of src shape");
+    }
+
+    auto newRank = dstShape.size();
+    SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
+    for (auto [dim, size] :
+         llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
+      newOutDims.emplace_back(dim, size);
+    }
+    auto srcOutDims = llvm::to_vector(src->getOutDimNames());
+    // ReshapeOp assumes minor-to-major, so we need to transpose the out dims
+    // before the reshape.
+    std::reverse(srcOutDims.begin(), srcOutDims.end());
+    std::reverse(newOutDims.begin(), newOutDims.end());
+    auto dst = src->transposeOuts(srcOutDims)
+                   .reshapeOuts(newOutDims)
+                   .transposeOuts(standardOutDimNames(ctx, newRank));
+    dstEnc = LinearEncodingAttr::get(ctx, dst);
+    return success();
+  }
+
   LogicalResult
   inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                       std::optional<Location> loc) const override {
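The transpose-reshape-transpose sequence in inferReshapeOpEncoding above relies on a standard identity: a row-major reshape can be expressed through a minor-first (dim 0 fastest) reshape primitive by reversing the dims on the way in and out. A standalone sketch with plain index arithmetic (not the LinearLayout API; the helper name is local):

    #include <cstdio>
    #include <vector>

    // Linearize `idx` for `shape` with dim 0 fastest-changing (minor-first).
    static int linMinorFirst(const std::vector<int> &idx,
                             const std::vector<int> &shape) {
      int flat = 0, stride = 1;
      for (size_t d = 0; d < shape.size(); ++d) {
        flat += idx[d] * stride;
        stride *= shape[d];
      }
      return flat;
    }

    int main() {
      // Row-major linearization of (i, j) in a 2x3 shape equals the
      // minor-first linearization of the *reversed* index in the *reversed*
      // shape, which is why the out dims are reversed before reshapeOuts and
      // restored afterwards.
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 3; ++j)
          std::printf("(%d,%d): row-major %d == reversed minor-first %d\n", i,
                      j, i * 3 + j, linMinorFirst({j, i}, {3, 2}));
    }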