
Commit 5f77e8c

Authored by Sam Ginzburg and co-authors
[triton][tool] Add support for printing shared memory layouts in the triton-tensor-layout tool (#4839)
The `triton-tensor-layout` CLI [tool](triton-lang/triton#4486) currently only supports printing distributed layouts, although in the past there was support for printing shared memory layouts as well (see the [previous commit here](triton-lang/triton@cc3ebcf#diff-85ef554db39720672d1d07edec05a0a596614927985cd2ed25a1888a3686d713L332)). This PR adds support for printing shared memory layouts again (using linear layouts). I've added several unit tests based on the old tests to ensure correctness.

Example usage:

```
./triton-tensor-layout -l "#triton_gpu.shared<{vec = 4, perPhase = 32, maxPhase = 1, order = [1,0], hasLeadingOffset = false}>" -t "tensor<8x32xf16>"
```

================================================================

The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (fewer than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.

- [ ] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: Sam Ginzburg <[email protected]>
1 parent 819338d · commit 5f77e8c
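For intuition about the layout parameters in the example invocation above, here is a minimal standalone sketch of the swizzle rule a shared encoding describes, assuming the conventional `vec`/`perPhase`/`maxPhase` semantics. The function below is illustrative and not part of this PR; the PR itself derives the mapping from the generic linear-layout machinery instead.

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative sketch (not code from this commit): the mapping a
// #triton_gpu.shared encoding describes for a row-major 2D tile,
// expressed via its vec / perPhase / maxPhase parameters.
int64_t sharedOffset(int64_t row, int64_t col, int64_t numCols,
                     int64_t vec, int64_t perPhase, int64_t maxPhase) {
  int64_t phase = (row / perPhase) % maxPhase;        // XOR pattern for this row
  int64_t swizzledCol = ((col / vec) ^ phase) * vec   // swizzle whole vectors...
                        + (col % vec);                // ...keep order inside one
  return row * numCols + swizzledCol;
}

int main() {
  // Parameters from the example invocation: vec=4, perPhase=32, maxPhase=1
  // on a tensor<8x32xf16>. With maxPhase=1 the phase is always 0, so this
  // particular layout ends up unswizzled.
  for (int64_t col : {0, 4, 31})
    std::printf("(0, %lld) -> offset %lld\n", (long long)col,
                (long long)sharedOffset(/*row=*/0, col, /*numCols=*/32,
                                        /*vec=*/4, /*perPhase=*/32,
                                        /*maxPhase=*/1));
}
```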

File tree

2 files changed: +497 -3 lines changed


lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 138 additions & 3 deletions
```diff
@@ -1,5 +1,6 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 
+#include <cstdint>
 #include <numeric>
 
 #include "mlir/IR/DialectImplementation.h"
```
```diff
@@ -3131,8 +3132,124 @@ static std::string paddedString(int value, int max) {
   return str;
 }
 
-std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
-                                            bool useHWPointOfView) {
+std::string getSharedLayoutStr(RankedTensorType tensorType,
+                               bool useHWPointOfView) {
+  auto layout = tensorType.getEncoding();
+  if (!layout)
+    return "";
+
+  std::optional<LinearLayout> ll =
+      triton::gpu::toLinearLayout(tensorType.getShape(), layout);
+  if (!ll.has_value())
+    llvm::report_fatal_error("Failed to convert layout to linear layout");
+
+  StringAttr kOffset = StringAttr::get(tensorType.getContext(), "offset");
+  StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block");
+  int64_t tensorSize = product(tensorType.getShape());
+  unsigned numBlocks = getNumCTAs(layout);
+  int32_t blockSize = tensorSize / numBlocks;
+
+  // elementMapping is for the non-hw layout, offsetMapping for hw-layout
+  std::vector<std::string> elementMapping(tensorSize);
+  std::vector<std::string> offsetMapping;
+
+  // Shared layouts are a mapping of (block, offset) --> (...)
+
+  // We can just use a single int to index into elementMapping because
+  // the 'swizzle' operation rearranges the indices---and we want to keep it
+  // that way
+  int32_t idx = 0;
+  // Enumerate all the offsets for each block
+  for (int32_t block = 0; block < numBlocks; block++) {
+    for (int32_t offset = 0; offset < blockSize; offset++) {
+      SmallVector<std::pair<StringAttr, int32_t>> inputs = {
+          {kBlock, block},
+          {kOffset, offset},
+      };
+
+      SmallVector<std::pair<StringAttr, int32_t>> outputs = ll->apply(inputs);
+
+      std::string sharedInfo = "(";
+      std::string &value = elementMapping[idx];
+
+      if (!value.empty())
+        value += "|";
+
+      value += "(";
+      // We can build up both strings (for hw/non-hw layouts) concurrently
+      for (int i = 0; i < outputs.size(); i++) {
+        // Based on the formatting from LinearLayout::toString, the format for
+        // the hw layout is slightly different. HW layouts use "," vs ":".
+        if (i > 0) {
+          sharedInfo += ",";
+          value += ":";
+        }
+        auto index = paddedString(outputs[i].second, tensorType.getDimSize(i));
+        sharedInfo += index;
+        value += index;
+      }
+      value += ")";
+      sharedInfo += ")";
+
+      offsetMapping.push_back(sharedInfo);
+
+      idx++;
+    }
+  }
+
+  std::string layoutStr;
+
+  if (!useHWPointOfView) {
+    int rank = tensorType.getRank();
+    bool newLine = true;
+    for (int i = 0; i < tensorSize; i++) {
+      auto indices = delinearizeIndex(i, tensorType.getShape());
+      int numOpenBracket = 0;
+      for (int j = rank - 1; j >= 0; j--) {
+        if (indices[j] % tensorType.getDimSize(j) != 0)
+          break;
+        layoutStr += "[";
+        numOpenBracket++;
+      }
+      if (newLine) {
+        for (int j = 0; j < rank - numOpenBracket; j++)
+          layoutStr += " ";
+        newLine = false;
+      }
+
+      layoutStr += elementMapping[i];
+      auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape());
+      for (int j = rank - 1; j >= 0; j--) {
+        if (nextIndices[j] % tensorType.getDimSize(j) != 0)
+          break;
+        layoutStr += "]";
+      }
+      if (nextIndices.back() % tensorType.getShape().back() == 0) {
+        layoutStr += "\n";
+        newLine = true;
+      } else {
+        layoutStr += ",";
+      }
+    }
+  } else {
+    // For the HW view here, print the (block, offset) --> (r,c) mapping
+    uint32_t idx = 0;
+    for (int32_t block = 0; block < numBlocks; block++) {
+      layoutStr += "Block: " + std::to_string(block) + ":\n";
+      for (int32_t offset = 0; offset < (tensorSize / numBlocks); offset++) {
+        layoutStr += "Offset: " + std::to_string(offset) + " -> ";
+        layoutStr += offsetMapping[idx];
+        layoutStr += "\n";
+        idx++;
+      }
+    }
+  }
+
+  return layoutStr;
+}
+
+std::string getDistributedLayoutStr(RankedTensorType tensorType,
+                                    bool useHWPointOfView) {
   auto layout = tensorType.getEncoding();
   if (!layout)
     return "";
```
```diff
@@ -3199,7 +3316,7 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
   }
   std::string layoutStr;
   if (!useHWPointOfView) {
-    // Printing the threads containning each elements of the tensor.
+    // Printing the threads containing each element of the tensor.
     int rank = tensorType.getRank();
     bool newLine = true;
     for (int i = 0; i < tensorSize; i++) {
```
```diff
@@ -3257,6 +3374,24 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
   return layoutStr;
 }
 
+std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
+                                            bool useHWPointOfView) {
+  auto layout = tensorType.getEncoding();
+
+  // tensorType is needed later on (e.g., getDimSize(j)), so we still have to
+  // pass it as a param
+  if (auto sharedLayout = mlir::dyn_cast<SharedEncodingAttr>(layout)) {
+    return getSharedLayoutStr(tensorType, useHWPointOfView);
+  } else if (auto distributedLayout =
+                 mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
+    return getDistributedLayoutStr(tensorType, useHWPointOfView);
+  }
+
+  // else unimplemented, return error
+  llvm::report_fatal_error("Unimplemented usage of getLayoutStr");
+  return "";
+}
+
 void mlir::triton::gpu::dumpLayout(RankedTensorType tensorType) {
   llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/false);
 }
```
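As a usage sketch, the new dispatch can be driven directly from C++. The wrapper below is hypothetical scaffolding; only `getLayoutStr` and its `useHWPointOfView` flag come from this diff, and the header path is assumed from the file being edited.

```cpp
#include "triton/Dialect/TritonGPU/IR/Dialect.h" // assumed header for getLayoutStr
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper (not part of the PR): print both views of a layout.
void printBothViews(mlir::RankedTensorType tensorType) {
  // Logical view: each tensor element annotated with its shared-memory
  // offset; elements reached by more than one (block, offset) pair are
  // joined with '|'.
  llvm::outs() << mlir::triton::gpu::getLayoutStr(tensorType,
                                                  /*useHWPointOfView=*/false);
  // Hardware view: one "Offset: N -> (...)" line per (block, offset) pair.
  llvm::outs() << mlir::triton::gpu::getLayoutStr(tensorType,
                                                  /*useHWPointOfView=*/true);
}
```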
