|
1 | 1 | #include "triton/Dialect/Triton/IR/Dialect.h" |
2 | 2 |
|
| 3 | +#include <cstdint> |
3 | 4 | #include <numeric> |
4 | 5 |
|
5 | 6 | #include "mlir/IR/DialectImplementation.h" |
@@ -3131,8 +3132,124 @@ static std::string paddedString(int value, int max) { |
3131 | 3132 | return str; |
3132 | 3133 | } |
3133 | 3134 |
|
3134 | | -std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, |
3135 | | - bool useHWPointOfView) { |
| 3135 | +std::string getSharedLayoutStr(RankedTensorType tensorType, |
| 3136 | + bool useHWPointOfView) { |
| 3137 | + auto layout = tensorType.getEncoding(); |
| 3138 | + if (!layout) |
| 3139 | + return ""; |
| 3140 | + |
| 3141 | + std::optional<LinearLayout> ll = |
| 3142 | + triton::gpu::toLinearLayout(tensorType.getShape(), layout); |
| 3143 | + if (!ll.has_value()) |
| 3144 | + llvm::report_fatal_error("Failed to convert layout to linear layout"); |
| 3145 | + |
| 3146 | + StringAttr kOffset = StringAttr::get(tensorType.getContext(), "offset"); |
| 3147 | + StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block"); |
| 3148 | + int64_t tensorSize = product(tensorType.getShape()); |
| 3149 | + unsigned numBlocks = getNumCTAs(layout); |
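| | +  // Number of shared-memory offsets enumerated per block (CTA).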
| 3150 | + int32_t blockSize = tensorSize / numBlocks; |
| 3151 | + |
| 3152 | + // elementMapping is for the non-hw layout, offsetMapping for hw-layout |
| 3153 | + std::vector<std::string> elementMapping(tensorSize); |
| 3154 | + std::vector<std::string> offsetMapping; |
| 3155 | + |
| 3156 | + // Shared layouts are a mapping of (block, offset) --> (...) |
| 3157 | + |
| 3158 | + // We can just use a single int to index into elementMapping because |
| 3159 | +  // the 'swizzle' operation rearranges the indices, and we want to keep it
| 3160 | + // that way |
| 3161 | + int32_t idx = 0; |
| 3162 | + // Enumerate all the offsets for each block |
| 3163 | + for (int32_t block = 0; block < numBlocks; block++) { |
| 3164 | + for (int32_t offset = 0; offset < blockSize; offset++) { |
| 3165 | + SmallVector<std::pair<StringAttr, int32_t>> inputs = { |
| 3166 | + {kBlock, block}, |
| 3167 | + {kOffset, offset}, |
| 3168 | + }; |
| 3169 | + |
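| | +      // Map (block, offset) through the linear layout to the logical tensor
| | +      // indices it holds.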
| 3170 | + SmallVector<std::pair<StringAttr, int32_t>> outputs = ll->apply(inputs); |
| 3171 | + |
| 3172 | + std::string sharedInfo = "("; |
| 3173 | + std::string &value = elementMapping[idx]; |
| 3174 | + |
| 3175 | + if (!value.empty()) |
| 3176 | + value += "|"; |
| 3177 | + |
| 3178 | + value += "("; |
| 3179 | + // We can build up both strings (for hw/non-hw layouts) concurrently |
| 3180 | + for (int i = 0; i < outputs.size(); i++) { |
| 3181 | +        // The format follows LinearLayout::toString, except that the HW layout
| 3182 | +        // separates indices with "," while the non-HW layout uses ":".
| 3183 | + if (i > 0) { |
| 3184 | + sharedInfo += ","; |
| 3185 | + value += ":"; |
| 3186 | + } |
| 3187 | + auto index = paddedString(outputs[i].second, tensorType.getDimSize(i)); |
| 3188 | + sharedInfo += index; |
| 3189 | + value += index; |
| 3190 | + } |
| 3191 | + value += ")"; |
| 3192 | + sharedInfo += ")"; |
| 3193 | + |
| 3194 | + offsetMapping.push_back(sharedInfo); |
| 3195 | + |
| 3196 | + idx++; |
| 3197 | + } |
| 3198 | + } |
| 3199 | + |
| 3200 | + std::string layoutStr; |
| 3201 | + |
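| | +  // For the non-HW view, print elementMapping as a bracketed grid in the
| | +  // tensor's shape, one entry per (block, offset) pair.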
| 3202 | + if (!useHWPointOfView) { |
| 3203 | + int rank = tensorType.getRank(); |
| 3204 | + bool newLine = true; |
| 3205 | + for (int i = 0; i < tensorSize; i++) { |
| 3206 | + auto indices = delinearizeIndex(i, tensorType.getShape()); |
| 3207 | + int numOpenBracket = 0; |
| 3208 | + for (int j = rank - 1; j >= 0; j--) { |
| 3209 | + if (indices[j] % tensorType.getDimSize(j) != 0) |
| 3210 | + break; |
| 3211 | + layoutStr += "["; |
| 3212 | + numOpenBracket++; |
| 3213 | + } |
| 3214 | + if (newLine) { |
| 3215 | + for (int j = 0; j < rank - numOpenBracket; j++) |
| 3216 | + layoutStr += " "; |
| 3217 | + newLine = false; |
| 3218 | + } |
| 3219 | + |
| 3220 | + layoutStr += elementMapping[i]; |
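| | +      // Close a bracket for every trailing dimension that wraps at the next
| | +      // linear index, and start a new line once the innermost dimension wraps.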
| 3221 | + auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape()); |
| 3222 | + for (int j = rank - 1; j >= 0; j--) { |
| 3223 | + if (nextIndices[j] % tensorType.getDimSize(j) != 0) |
| 3224 | + break; |
| 3225 | + layoutStr += "]"; |
| 3226 | + } |
| 3227 | + if (nextIndices.back() % tensorType.getShape().back() == 0) { |
| 3228 | + layoutStr += "\n"; |
| 3229 | + newLine = true; |
| 3230 | + } else { |
| 3231 | + layoutStr += ","; |
| 3232 | + } |
| 3233 | + } |
| 3234 | + } else { |
| 3235 | + // For the HW view here, print the (block, offset) --> (r,c) mapping |
| 3236 | + uint32_t idx = 0; |
| 3237 | + for (int32_t block = 0; block < numBlocks; block++) { |
| 3238 | + layoutStr += "Block: " + std::to_string(block) + ":\n"; |
| 3239 | + for (int32_t offset = 0; offset < (tensorSize / numBlocks); offset++) { |
| 3240 | + layoutStr += "Offset: " + std::to_string(offset) + " -> "; |
| 3241 | + layoutStr += offsetMapping[idx]; |
| 3242 | + layoutStr += "\n"; |
| 3243 | + idx++; |
| 3244 | + } |
| 3245 | + } |
| 3246 | + } |
| 3247 | + |
| 3248 | + return layoutStr; |
| 3249 | +} |
| 3250 | + |
| 3251 | +std::string getDistributedLayoutStr(RankedTensorType tensorType, |
| 3252 | + bool useHWPointOfView) { |
3136 | 3253 | auto layout = tensorType.getEncoding(); |
3137 | 3254 | if (!layout) |
3138 | 3255 | return ""; |
@@ -3199,7 +3316,7 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, |
3199 | 3316 | } |
3200 | 3317 | std::string layoutStr; |
3201 | 3318 | if (!useHWPointOfView) { |
3202 | | - // Printing the threads containning each elements of the tensor. |
| 3319 | +    // Printing the threads containing each element of the tensor.
3203 | 3320 | int rank = tensorType.getRank(); |
3204 | 3321 | bool newLine = true; |
3205 | 3322 | for (int i = 0; i < tensorSize; i++) { |
@@ -3257,6 +3374,24 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, |
3257 | 3374 | return layoutStr; |
3258 | 3375 | } |
3259 | 3376 |
|
| 3377 | +std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, |
| 3378 | + bool useHWPointOfView) { |
| 3379 | + auto layout = tensorType.getEncoding(); |
| 3380 | + |
| 3381 | +  // The helper functions still need the full tensorType (e.g., getDimSize(j)),
| 3382 | +  // so we pass it along rather than just the encoding.
| 3383 | + if (auto sharedLayout = mlir::dyn_cast<SharedEncodingAttr>(layout)) { |
| 3384 | + return getSharedLayoutStr(tensorType, useHWPointOfView); |
| 3385 | + } else if (auto distributedLayout = |
| 3386 | + mlir::dyn_cast<DistributedEncodingTrait>(layout)) { |
| 3387 | + return getDistributedLayoutStr(tensorType, useHWPointOfView); |
| 3388 | + } |
| 3389 | + |
| 3390 | +  // Unsupported layout kind; report a fatal error.
| 3391 | + llvm::report_fatal_error("Unimplemented usage of getLayoutStr"); |
| 3392 | + return ""; |
| 3393 | +} |
| 3394 | + |
3260 | 3395 | void mlir::triton::gpu::dumpLayout(RankedTensorType tensorType) { |
3261 | 3396 | llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/false); |
3262 | 3397 | } |
|