11#include " triton/Dialect/Triton/IR/Dialect.h"
22
3+ #include < cstdint>
34#include < numeric>
45
56#include " mlir/IR/DialectImplementation.h"
@@ -250,6 +251,30 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
250251 return order;
251252}
252253
254+ SmallVector<unsigned > getOrderForDotOperand (unsigned opIdx, unsigned rank) {
255+ SmallVector<unsigned > order (rank);
256+ // The 'order' field typically represents a descending sorted array of
257+ // dimensions based on contiguity. For instance, in axisInfo utilities that
258+ // retrieve tensor contiguity, it's assumed that the dimension with the
259+ // highest contiguity corresponds to order[0].
260+ //
261+ // The relation between contiguity and order is only relevant if the layout
262+ // interfaces with HBM, as is the case when we load tensor from HBM to
263+ // registers in the dot layout to bypass LDS. When bypassing LDS, we make the
264+ // following assumptions about tensor layouts:
265+ // - Tensor A (opIdx == 0) is considered to be row-major.
266+ // - Tensor B (opIdx == 1) is considered to be column-major.
267+ //
268+ // Based on these assumptions, we define the following orders:
269+ // - For opIdx == 0, we assume an order of [1, 0].
270+ // - For opIdx == 1, we assume an order of [0, 1].
271+ std::iota (order.rbegin (), order.rend (), 0 );
272+ if (opIdx == 1 ) {
273+ std::swap (order[0 ], order[1 ]);
274+ }
275+ return order;
276+ }
277+
253278SmallVector<unsigned > getOrder (Attribute layout) {
254279 if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
255280 return llvm::to_vector (blockedLayout.getOrder ());
@@ -264,7 +289,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
264289 if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
265290 auto rank = getWarpsPerCTA (dotLayout.getParent ()).size ();
266291 SmallVector<unsigned > order (rank);
267- std::iota (order.rbegin (), order.rend (), 0 );
292+ if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent ())) {
293+ return getOrderForDotOperand (dotLayout.getOpIdx (), rank);
294+ } else {
295+ std::iota (order.rbegin (), order.rend (), 0 );
296+ }
268297 return order;
269298 }
270299 if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -928,6 +957,27 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
928957SmallVector<unsigned >
929958DotOperandEncodingAttr::getElemsPerThread (ArrayRef<int64_t > shape,
930959 Type eltTy) const {
960+
961+ if (auto parent = mlir::dyn_cast<AMDMfmaEncodingAttr>(getParent ())) {
962+ auto rank = shape.size ();
963+ assert (rank == 2 || rank == 3 );
964+
965+ auto idx = getOpIdx ();
966+ assert (idx == 0 || idx == 1 );
967+
968+ SmallVector<unsigned > elemsPerThread (rank);
969+
970+ auto kWidth = getKWidth ();
971+ auto rep = parent.getMFMARepForOperands (shape, kWidth , idx);
972+
973+ if (rank == 3 )
974+ elemsPerThread[0 ] = rep[0 ];
975+ elemsPerThread[rank - 2 ] = (idx == 0 ) ? rep[1 ] : rep[1 ] * kWidth ;
976+ elemsPerThread[rank - 1 ] = (idx == 0 ) ? rep[2 ] * kWidth : rep[2 ];
977+
978+ return elemsPerThread;
979+ }
980+
931981 if (auto mmaParent = mlir::dyn_cast<MmaEncodingTrait>(getParent ())) {
932982 return mmaParent.getElemsPerThreadForOperands (shape, eltTy, getOpIdx ());
933983 }
@@ -3107,8 +3157,124 @@ static std::string paddedString(int value, int max) {
31073157 return str;
31083158}
31093159
3110- std::string mlir::triton::gpu::getLayoutStr (RankedTensorType tensorType,
3111- bool useHWPointOfView) {
// Renders a shared-memory layout as a human-readable string.
//
// Two views are supported:
// - useHWPointOfView == false: an element-centric view, printing for every
//   tensor element the (coordinate) tuple it maps to, nested in brackets that
//   mirror the tensor's shape.
// - useHWPointOfView == true: a hardware-centric view, printing the
//   (block, offset) --> (coordinates) mapping line by line.
//
// Returns "" when the tensor has no encoding; aborts via report_fatal_error
// if the encoding cannot be converted to a linear layout.
std::string getSharedLayoutStr(RankedTensorType tensorType,
                               bool useHWPointOfView) {
  auto layout = tensorType.getEncoding();
  if (!layout)
    return "";

  // The linear-layout form gives us a callable (block, offset) -> coords map.
  std::optional<LinearLayout> ll =
      triton::gpu::toLinearLayout(tensorType.getShape(), layout);
  if (!ll.has_value())
    llvm::report_fatal_error("Failed to convert layout to linear layout");

  StringAttr kOffset = StringAttr::get(tensorType.getContext(), "offset");
  StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block");
  int64_t tensorSize = product(tensorType.getShape());
  unsigned numBlocks = getNumCTAs(layout);
  int32_t blockSize = tensorSize / numBlocks;

  // elementMapping is for the non-hw layout, offsetMapping for hw-layout
  std::vector<std::string> elementMapping(tensorSize);
  std::vector<std::string> offsetMapping;

  // Shared layouts are a mapping of (block, offset) --> (...)

  // We can just use a single int to index into elementMapping because
  // the 'swizzle' operation rearranges the indices---and we want to keep it
  // that way
  int32_t idx = 0;
  // Enumerate all the offsets for each block
  for (int32_t block = 0; block < numBlocks; block++) {
    for (int32_t offset = 0; offset < blockSize; offset++) {
      SmallVector<std::pair<StringAttr, int32_t>> inputs = {
          {kBlock, block},
          {kOffset, offset},
      };

      SmallVector<std::pair<StringAttr, int32_t>> outputs = ll->apply(inputs);

      // sharedInfo accumulates the HW-view string; value accumulates the
      // element-view string in-place inside elementMapping.
      std::string sharedInfo = "(";
      std::string &value = elementMapping[idx];

      // Multiple (block, offset) pairs can map to the same element; separate
      // the alternatives with '|'.
      if (!value.empty())
        value += "|";

      value += "(";
      // We can build up both strings (for hw/non-hw layouts) concurrently
      for (int i = 0; i < outputs.size(); i++) {
        // Based on the formatting from LinearLayout::toString, the format for
        // the hw layout is slightly different. HW layouts use "," vs ":".
        if (i > 0) {
          sharedInfo += ",";
          value += ":";
        }
        // Pad each coordinate to the width of the largest index in that dim
        // so columns line up when printed.
        auto index = paddedString(outputs[i].second, tensorType.getDimSize(i));
        sharedInfo += index;
        value += index;
      }
      value += ")";
      sharedInfo += ")";

      offsetMapping.push_back(sharedInfo);

      idx++;
    }
  }

  std::string layoutStr;

  if (!useHWPointOfView) {
    // Element-centric view: walk the tensor in row-major order, opening a
    // '[' for each dimension whose index rolls over to 0 and closing a ']'
    // for each dimension that finishes on the next step.
    int rank = tensorType.getRank();
    bool newLine = true;
    for (int i = 0; i < tensorSize; i++) {
      auto indices = delinearizeIndex(i, tensorType.getShape());
      int numOpenBracket = 0;
      for (int j = rank - 1; j >= 0; j--) {
        if (indices[j] % tensorType.getDimSize(j) != 0)
          break;
        layoutStr += "[";
        numOpenBracket++;
      }
      if (newLine) {
        // Indent so nested rows align under their opening brackets.
        for (int j = 0; j < rank - numOpenBracket; j++)
          layoutStr += " ";
        newLine = false;
      }

      layoutStr += elementMapping[i];
      // Close brackets for every dimension that wraps at the next element.
      auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape());
      for (int j = rank - 1; j >= 0; j--) {
        if (nextIndices[j] % tensorType.getDimSize(j) != 0)
          break;
        layoutStr += "]";
      }
      // End of the innermost (fastest-varying) dimension => new output line.
      if (nextIndices.back() % tensorType.getShape().back() == 0) {
        layoutStr += "\n";
        newLine = true;
      } else {
        layoutStr += ",";
      }
    }
  } else {
    // For the HW view here, print the (block, offset) --> (r,c) mapping
    uint32_t idx = 0;
    for (int32_t block = 0; block < numBlocks; block++) {
      layoutStr += "Block: " + std::to_string(block) + ":\n";
      for (int32_t offset = 0; offset < (tensorSize / numBlocks); offset++) {
        layoutStr += "Offset: " + std::to_string(offset) + " -> ";
        layoutStr += offsetMapping[idx];
        layoutStr += "\n";
        idx++;
      }
    }
  }

  return layoutStr;
}
3275+
3276+ std::string getDistributedLayoutStr (RankedTensorType tensorType,
3277+ bool useHWPointOfView) {
31123278 auto layout = tensorType.getEncoding ();
31133279 if (!layout)
31143280 return " " ;
@@ -3175,7 +3341,7 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
31753341 }
31763342 std::string layoutStr;
31773343 if (!useHWPointOfView) {
3178- // Printing the threads containning each elements of the tensor.
3344+ // Printing the threads containing each elements of the tensor.
31793345 int rank = tensorType.getRank ();
31803346 bool newLine = true ;
31813347 for (int i = 0 ; i < tensorSize; i++) {
@@ -3233,6 +3399,24 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
32333399 return layoutStr;
32343400}
32353401
3402+ std::string mlir::triton::gpu::getLayoutStr (RankedTensorType tensorType,
3403+ bool useHWPointOfView) {
3404+ auto layout = tensorType.getEncoding ();
3405+
3406+ // tensorType is needed later on (e.g., getDimSize(j)), so we still have to
3407+ // pass it as a param
3408+ if (auto sharedLayout = mlir::dyn_cast<SharedEncodingAttr>(layout)) {
3409+ return getSharedLayoutStr (tensorType, useHWPointOfView);
3410+ } else if (auto distributedLayout =
3411+ mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
3412+ return getDistributedLayoutStr (tensorType, useHWPointOfView);
3413+ }
3414+
3415+ // else unimplemented, return error
3416+ llvm::report_fatal_error (" Unimplemented usage of getLayoutStr" );
3417+ return " " ;
3418+ }
3419+
32363420void mlir::triton::gpu::dumpLayout (RankedTensorType tensorType) {
32373421 llvm::errs () << getLayoutStr (tensorType, /* useHWPointOfView=*/ false );
32383422}
0 commit comments