@@ -23,6 +23,7 @@
 #include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
+#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/MathExtras.h"
@@ -153,8 +154,10 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
   // Return the order that represents that the batch is in row-major or
   // column-major order for a batch of matrices of shape [*, m, n] with
   // len(shape) == rank.
-  assert(rank >= 2);
   SmallVector<unsigned> order(rank);
+  if (rank < 2) {
+    return order;
+  }
   std::iota(order.rbegin(), order.rend(), 0);
   if (!rowMajor) {
     std::swap(order[0], order[1]);
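
With the assert removed, getMatrixOrder now degrades gracefully for ranks below 2 instead of aborting. A minimal standalone sketch (plain C++, not the Triton source) of the resulting behavior:

```cpp
// Standalone sketch of the new getMatrixOrder behavior: ranks below 2 return
// a trivial order instead of tripping an assert. Not the Triton implementation.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

std::vector<unsigned> getMatrixOrderSketch(unsigned rank, bool rowMajor) {
  std::vector<unsigned> order(rank);
  if (rank < 2)
    return order; // rank 0 -> {}, rank 1 -> {0}
  std::iota(order.rbegin(), order.rend(), 0); // {rank-1, ..., 1, 0}
  if (!rowMajor)
    std::swap(order[0], order[1]); // column-major: swap the two matrix dims
  return order;
}

int main() {
  for (unsigned d : getMatrixOrderSketch(3, /*rowMajor=*/true))
    std::printf("%u ", d); // prints: 2 1 0
  std::printf("\n");
  for (unsigned d : getMatrixOrderSketch(1, /*rowMajor=*/true))
    std::printf("%u ", d); // prints: 0 (no assert failure)
  std::printf("\n");
  return 0;
}
```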
@@ -396,6 +399,21 @@ BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
                           "order must all have the same rank.";
   }
+  if (llvm::any_of(sizePerThread,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in sizePerThread must be a power of two.";
+  }
+  if (llvm::any_of(threadsPerWarp,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in threadsPerWarp must be a power of two.";
+  }
+  if (llvm::any_of(warpsPerCTA,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in warpsPerCTA must be a power of two.";
+  }

   // Empty CTALayout is allowed, but if it's present its rank must match the
   // BlockedEncodingAttr's rank.
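
These new checks reject any blocked encoding whose sizePerThread, threadsPerWarp, or warpsPerCTA contains a non-power-of-two element. A minimal sketch of the check in isolation, built on the same LLVM helpers the patch uses (llvm::any_of, llvm::isPowerOf2_64); it assumes LLVM headers are available:

```cpp
// Sketch only: the same power-of-two predicate applied to a sample vector.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
#include <cstdio>

int main() {
  llvm::SmallVector<unsigned> sizePerThread = {1, 3}; // 3 is not a power of two
  bool invalid = llvm::any_of(
      sizePerThread, [](unsigned x) { return !llvm::isPowerOf2_64(x); });
  std::printf("rejected: %s\n", invalid ? "yes" : "no"); // prints: rejected: yes
  return 0;
}
```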
@@ -2246,6 +2264,8 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
 SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
   if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
     return mma.getRepOrderForOperand(getOpIdx());
+  } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
+    return to_vector(blocked.getOrder());
   }
   llvm::report_fatal_error(
       "getRepOrder not implemented for DotOperandEncodingAttr");
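
For a blocked parent, the operand encoding now reuses the parent's order; to_vector copies the ArrayRef returned by getOrder() into the SmallVector the interface returns. A minimal sketch of that conversion, assuming an illustrative parent order of [1, 0]:

```cpp
// Sketch of the ArrayRef -> SmallVector copy performed by llvm::to_vector.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdio>

int main() {
  unsigned parentOrder[] = {1, 0};             // e.g. a blocked layout's order
  llvm::ArrayRef<unsigned> order(parentOrder); // what getOrder() yields
  llvm::SmallVector<unsigned> repOrder = llvm::to_vector(order);
  std::printf("%u %u\n", repOrder[0], repOrder[1]); // prints: 1 0
  return 0;
}
```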
@@ -2958,60 +2978,66 @@ struct TritonGPUVerifyTensorLayoutInterface
   LogicalResult verifyTensorLayout(
       Attribute layout, RankedTensorType rankedTy, Operation *op,
       function_ref<InFlightDiagnostic()> makeErr) const override {
-    if (isa<triton::gpu::SharedEncodingTrait>(layout))
-      return makeErr() << "Shared layout is not allowed on tensor type.";
-    // TODO(jlebar): Currently this only checks blocked layouts, but other
-    // layouts also have invariants!
-
-    // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
-    if (auto blocked = dyn_cast<BlockedEncodingAttr>(layout)) {
-      ModuleOp module = op->getParentOfType<ModuleOp>();
-
-      // A different verifier should have checked that the layout itself is
-      // valid, including that threads-per-warp has the same rank as
-      // warps-per-block etc.
-      if (blocked.getRank() != rankedTy.getRank()) {
-        return makeErr() << layout << ".\nLayout has rank " << blocked.getRank()
-                         << ", but the tensor it's attached to has rank "
-                         << rankedTy.getRank() << ".";
-      }
-
-      int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
-      int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
-      if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutThreadsPerWarp
-                         << " threads per warp, but the module specifies "
-                         << moduleThreadsPerWarp << " threads per warp.";
-      }
-
-      std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
-      if (!moduleWarpsPerCTA) {
-        return makeErr()
-               << "Could not determine the number of warps per CTA. Operation "
-                  "is not in a context with `ttg.num-warps`.";
-      }
-      int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
-      if (layoutWarpsPerCTA != *moduleWarpsPerCTA) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutWarpsPerCTA
-                         << " warps per CTA, but the context requires "
-                         << *moduleWarpsPerCTA << " warps per CTA.";
-      }
-
-      if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
-        int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
-        int64_t layoutCTAsPerCGA =
-            product(blocked.getCTALayout().getCTAsPerCGA());
-        if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
-          return makeErr() << layout << ".\nLayout has a total of "
-                           << layoutCTAsPerCGA
-                           << " CTAs per CGA, but the module specifies "
-                           << moduleCTAsPerCGA << " CTAs per CGA.";
-        }
-      }
+    auto distr = dyn_cast<triton::gpu::DistributedEncodingTrait>(layout);
+    if (!distr)
+      return makeErr()
+             << "Non-distributed layout is not allowed in tensor type.";
+    if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH"))
+      return success();
+    auto rank = distr.getRepOrder().size();
+    if (rank != rankedTy.getRank())
+      return makeErr() << "Layout has rank " << rank
+                       << ", but the tensor it's attached to has rank "
+                       << rankedTy.getRank() << ".";
+    if (llvm::any_of(rankedTy.getShape(),
+                     [](int64_t i) { return !llvm::isPowerOf2_64(i); })) {
+      return makeErr() << "Layout has shape " << rankedTy.getShape()
+                       << ", but the tensor it's attached to has shape "
+                       << rankedTy.getShape()
+                       << " which is not a power of two.";
+    }
+    auto ll = toLinearLayout(rankedTy);
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+
+    // Number of threads per warp.
+    auto kLane = StringAttr::get(module.getContext(), "lane");
+    int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
+    // FIXME: ll.getInDimSize(kLane) does not return the correct threads per
+    // warp. https://github.com/intel/intel-xpu-backend-for-triton/issues/4861
+    unsigned layoutThreadsPerWarp = ll.getInDimSize(kLane);
+    if (auto dotOperandLayout =
+            dyn_cast<DotOperandEncodingAttr>(rankedTy.getEncoding()))
+      if (auto dpasLayout =
+              dyn_cast<intel::DpasEncodingAttr>(dotOperandLayout.getParent()))
+        layoutThreadsPerWarp = dpasLayout.getThreadsPerWarp();
+    if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
+                       << " threads per warp, but the module specifies "
+                       << moduleThreadsPerWarp << " threads per warp.";
+    }
+
+    // Number of warps per CTA.
+    std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
+    if (!moduleWarpsPerCTA) {
+      return makeErr()
+             << "Could not determine the number of warps per CTA. Operation "
+                "is not in a context with `ttg.num-warps`.";
+    }
+    auto kWarp = StringAttr::get(module.getContext(), "warp");
+    if (ll.getInDimSize(kWarp) != *moduleWarpsPerCTA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp)
+                       << " warps per CTA, but the context requires "
+                       << *moduleWarpsPerCTA << " warps per CTA.";
+    }
+
+    // Number of CTAs per CGA.
+    auto kBlock = StringAttr::get(module.getContext(), "block");
+    int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
+    if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock)
+                       << " CTAs per CGA, but the context requires "
+                       << moduleCTAsPerCGA << " CTAs per CGA.";
     }
-
     return success();
   }
 };
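
The rewritten verifier applies to any distributed layout: it converts the encoding to a linear layout and compares the sizes of the lane, warp, and block input dimensions against the module-level settings (with an Intel DPAS special case for the threads-per-warp of dot operands, and an early exit under TRITON_INTEL_ADVANCED_PATH). A minimal standalone sketch (plain C++, no MLIR) of the consistency rules being enforced, where product() stands in for LinearLayout::getInDimSize and all names and numbers are illustrative:

```cpp
// Sketch of the invariants the verifier checks: lane/warp/block extents of the
// layout must equal the module's threads-per-warp, num-warps, and num-CTAs.
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

static unsigned product(const std::vector<unsigned> &v) {
  return std::accumulate(v.begin(), v.end(), 1u, std::multiplies<unsigned>());
}

int main() {
  // e.g. a blocked layout with threadsPerWarp = [4, 8], warpsPerCTA = [2, 2].
  std::vector<unsigned> threadsPerWarp = {4, 8};
  std::vector<unsigned> warpsPerCTA = {2, 2};
  std::vector<unsigned> ctasPerCGA = {1, 1};

  unsigned moduleThreadsPerWarp = 32; // module-level threads-per-warp setting
  unsigned moduleNumWarps = 4;        // module-level num-warps setting
  unsigned moduleNumCTAs = 1;         // module-level num-ctas setting

  bool ok = product(threadsPerWarp) == moduleThreadsPerWarp &&
            product(warpsPerCTA) == moduleNumWarps &&
            product(ctasPerCGA) == moduleNumCTAs;
  std::printf("layout %s\n", ok ? "accepted" : "rejected"); // prints: accepted
  return 0;
}
```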