@@ -3307,20 +3307,17 @@ static std::string paddedString(int value, int max) {
33073307 return str;
33083308}
33093309
3310- std::string getSharedLayoutStr (RankedTensorType type, bool useHWPointOfView) {
3311- if (!type)
3312- return " " ;
3313-
3310+ std::string mlir::triton::gpu::getSharedLayoutStr (LinearLayout &ll,
3311+ bool useHWPointOfView) {
33143312 // This RankedTensorType is a MemDescType (?!)
3315- auto shape = type. getShape ( );
3316- auto layout = type. getEncoding ( );
3317- LinearLayout ll = triton::gpu::toLinearLayout (shape, layout );
3313+ auto outDimNames = llvm::to_vector (ll. getOutDimNames () );
3314+ auto shape = convertType< int64_t >( llvm::to_vector (ll. getOutDimSizes ()) );
3315+ auto *ctx = outDimNames[ 0 ]. getContext ( );
33183316
3319- StringAttr kOffset = StringAttr::get (type.getContext (), " offset" );
3320- StringAttr kBlock = StringAttr::get (type.getContext (), " block" );
3321- int64_t tensorSize = product (type.getShape ());
3322- auto enc = type.getEncoding ();
3323- unsigned numBlocks = getNumCTAs (enc);
3317+ StringAttr kOffset = StringAttr::get (ctx, " offset" );
3318+ StringAttr kBlock = StringAttr::get (ctx, " block" );
3319+ int64_t tensorSize = product (shape);
3320+ unsigned numBlocks = ll.getInDimSize (kBlock );
33243321 int32_t blockSize = tensorSize / numBlocks;
33253322
33263323 // elementMapping is for the non-hw layout, offsetMapping for hw-layout
@@ -3374,7 +3371,7 @@ std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
33743371 std::string layoutStr;
33753372
33763373 if (!useHWPointOfView) {
3377- int rank = type. getRank ();
3374+ int rank = shape. size ();
33783375 bool newLine = true ;
33793376 for (int i = 0 ; i < tensorSize; i++) {
33803377 auto indices = delinearizeIndex (i, shape);
@@ -3422,21 +3419,19 @@ std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
34223419 return layoutStr;
34233420}
34243421
3425- std::string getDistributedLayoutStr (RankedTensorType tensorType,
3426- bool useHWPointOfView) {
3427- auto layout = tensorType.getEncoding ();
3428- if (!layout)
3429- return " " ;
3430-
3431- StringAttr kRegister = StringAttr::get (tensorType.getContext (), " register" );
3432- StringAttr kLane = StringAttr::get (tensorType.getContext (), " lane" );
3433- StringAttr kWarp = StringAttr::get (tensorType.getContext (), " warp" );
3434- StringAttr kBlock = StringAttr::get (tensorType.getContext (), " block" );
3422+ std::string mlir::triton::gpu::getDistributedLayoutStr (LinearLayout &ll,
3423+ bool useHWPointOfView) {
3424+ auto inDimNames = llvm::to_vector (ll.getInDimNames ());
3425+ auto *ctx = inDimNames[0 ].getContext ();
3426+ StringAttr kRegister = StringAttr::get (ctx, " register" );
3427+ StringAttr kLane = StringAttr::get (ctx, " lane" );
3428+ StringAttr kWarp = StringAttr::get (ctx, " warp" );
3429+ StringAttr kBlock = StringAttr::get (ctx, " block" );
34353430
3436- LinearLayout ll = toLinearLayout (tensorType);
3437- int64_t tensorSize = product (tensorType.getShape ());
3431+ int64_t tensorSize = ll.getTotalOutDimSize ();
34383432 std::vector<std::string> elementMapping (tensorSize);
34393433 std::vector<std::string> threadMapping;
3434+ auto shape = convertType<int64_t >(llvm::to_vector (ll.getOutDimSizes ()));
34403435 unsigned threadsPerWarp = ll.getInDimSize (kLane );
34413436 unsigned numWarpsPerCTA = ll.getInDimSize (kWarp );
34423437 unsigned numBlocks = ll.getInDimSize (kBlock );
@@ -3456,7 +3451,7 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
34563451 int stride = 1 ;
34573452 for (int i = outputs.size () - 1 ; i >= 0 ; i--) {
34583453 linearizedIdx += outputs[i].second * stride;
3459- stride *= tensorType. getDimSize (i) ;
3454+ stride *= shape[i] ;
34603455 }
34613456 std::string &value = elementMapping[linearizedIdx];
34623457 if (!value.empty ())
@@ -3476,8 +3471,7 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
34763471 for (int i = 0 ; i < outputs.size (); i++) {
34773472 if (i > 0 )
34783473 threadInfo += " ," ;
3479- threadInfo +=
3480- paddedString (outputs[i].second , tensorType.getDimSize (i));
3474+ threadInfo += paddedString (outputs[i].second , shape[i]);
34813475 }
34823476 threadInfo += " )" ;
34833477 threadMapping.push_back (threadInfo);
@@ -3488,13 +3482,13 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
34883482 std::string layoutStr;
34893483 if (!useHWPointOfView) {
34903484 // Printing the threads containing each elements of the tensor.
3491- int rank = tensorType. getRank ();
3485+ int rank = ll. getNumOutDims ();
34923486 bool newLine = true ;
34933487 for (int i = 0 ; i < tensorSize; i++) {
3494- auto indices = delinearizeIndex (i, tensorType. getShape () );
3488+ auto indices = delinearizeIndex (i, shape );
34953489 int numOpenBracket = 0 ;
34963490 for (int j = rank - 1 ; j >= 0 ; j--) {
3497- if (indices[j] % tensorType. getDimSize (j) != 0 )
3491+ if (indices[j] % shape[j] != 0 )
34983492 break ;
34993493 layoutStr += " [" ;
35003494 numOpenBracket++;
@@ -3506,13 +3500,13 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
35063500 }
35073501
35083502 layoutStr += elementMapping[i];
3509- auto nextIndices = delinearizeIndex (i + 1 , tensorType. getShape () );
3503+ auto nextIndices = delinearizeIndex (i + 1 , shape );
35103504 for (int j = rank - 1 ; j >= 0 ; j--) {
3511- if (nextIndices[j] % tensorType. getDimSize (j) != 0 )
3505+ if (nextIndices[j] % shape[j] != 0 )
35123506 break ;
35133507 layoutStr += " ]" ;
35143508 }
3515- if (nextIndices.back () % tensorType. getShape () .back () == 0 ) {
3509+ if (nextIndices.back () % shape .back () == 0 ) {
35163510 layoutStr += " \n " ;
35173511 newLine = true ;
35183512 } else {
@@ -3578,15 +3572,16 @@ mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o) {
35783572std::string mlir::triton::gpu::getLayoutStr (RankedTensorType tensorType,
35793573 bool useHWPointOfView) {
35803574 auto layout = tensorType.getEncoding ();
3575+ LinearLayout ll = triton::gpu::toLinearLayout (tensorType.getShape (), layout);
35813576
35823577 // The encoding is converted to a LinearLayout once, above, and the helpers
35833578 // below take that LinearLayout instead of the tensor type.
35843579 // TODO: Pass TensorOrMemDesc instead of RankedTensorType in
35853580 // triton-tensor-layout.cpp
35863581 if (mlir::isa<SharedEncodingTrait>(layout)) {
3587- return getSharedLayoutStr (tensorType , useHWPointOfView);
3582+ return getSharedLayoutStr (ll , useHWPointOfView);
35883583 } else if (mlir::isa<DistributedEncodingTrait>(layout)) {
3589- return getDistributedLayoutStr (tensorType , useHWPointOfView);
3584+ return getDistributedLayoutStr (ll , useHWPointOfView);
35903585 }
35913586
35923587 // else unimplemented, return error
0 commit comments