
Commit a4f1854

Move tensor layout verifier impl into dialect interface (#5312)
As discussed in #5302, this moves the TTG tensor layout verification into a dialect interface while keeping the op verification inside the trait verifier. The op trait still seemed like the best way to run the verifier on all operations, with the business logic of verifying specific attributes moved to the relevant dialects.

The signature of `DialectVerifyTensorLayoutInterface::verifyTensorLayout` is very specific to the TritonGPU dialect attributes we want to verify. We could probably simplify it to take just an `mlir::Value` and the layout attribute to verify, but I wanted to stay close to the initial implementation to start.

@Jokeren @ThomasRaoux @peterbell10 please let me know whether this is headed in the right direction; I am happy to make changes if needed.

Co-authored-by: Lei Zhang <[email protected]>
1 parent 00cc5d0 · commit a4f1854

3 files changed: 79 additions & 49 deletions

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 10 additions & 0 deletions
@@ -78,6 +78,16 @@ class DialectInferLayoutInterface
                                    Attribute operandEncodingB) const = 0;
 };
 
+class DialectVerifyTensorLayoutInterface
+    : public DialectInterface::Base<DialectVerifyTensorLayoutInterface> {
+public:
+  DialectVerifyTensorLayoutInterface(Dialect *dialect) : Base(dialect) {}
+
+  virtual LogicalResult
+  verifyTensorLayout(Attribute layout, RankedTensorType type, ModuleOp module,
+                     function_ref<InFlightDiagnostic()> emitError) const = 0;
+};
+
 } // namespace triton
 } // namespace mlir
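Note: the commit message suggests this signature could later be simplified to take just an `mlir::Value` plus the layout attribute. A rough sketch of that alternative, which is not part of this commit and uses illustrative names only:

// Hypothetical simplification floated in the commit message, not part of
// this change: implementations would derive the tensor type and enclosing
// module from the value themselves.
class DialectVerifyTensorLayoutInterface
    : public DialectInterface::Base<DialectVerifyTensorLayoutInterface> {
public:
  DialectVerifyTensorLayoutInterface(Dialect *dialect) : Base(dialect) {}

  virtual LogicalResult
  verifyTensorLayout(Value value, Attribute layout,
                     function_ref<InFlightDiagnostic()> emitError) const = 0;
  // An implementation could recover the tensor type with
  // cast<RankedTensorType>(value.getType()) and walk up from the value's
  // defining op or owning block to the enclosing ModuleOp.
};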

lib/Dialect/Triton/IR/Traits.cpp

Lines changed: 6 additions & 49 deletions
@@ -5,11 +5,9 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
-#include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
-namespace ttg = mlir::triton::gpu;
 
 static LogicalResult verifySameEncoding(Type typeA, Type typeB,
                                         bool allowTensorPointerType) {
@@ -118,53 +116,12 @@ LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) {
   if (!layout)
     return success();
 
-  if (isa<ttg::SharedEncodingAttr>(layout))
-    return makeErr() << "Shared layout is not allowed on tensor type.";
-  // TODO(jlebar): Currently this only checks blocked layouts, but other
-  // layouts also have invariants!
-
-  // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
-  if (auto blocked = dyn_cast<ttg::BlockedEncodingAttr>(layout)) {
-    // A different verifier should have checked that the layout itself is
-    // valid, including that threads-per-warp has the same rank as
-    // warps-per-block etc.
-    auto layoutRank = blocked.getThreadsPerWarp().size();
-    if (layoutRank != rankedTy.getRank()) {
-      return makeErr() << layout << ".\nLayout has rank " << layoutRank
-                       << ", but the tensor it's attached to has rank "
-                       << rankedTy.getRank() << ".";
-    }
-
-    int moduleThreadsPerWarp =
-        ttg::TritonGPUDialect::getThreadsPerWarp(module);
-    int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
-    if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
-      return makeErr() << layout << ".\nLayout has a total of "
-                       << layoutThreadsPerWarp
-                       << " threads per warp, but the module specifies "
-                       << moduleThreadsPerWarp << " threads per warp.";
-    }
-
-    int moduleWarpsPerCTA = ttg::TritonGPUDialect::getNumWarps(module);
-    int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
-    if (layoutWarpsPerCTA != moduleWarpsPerCTA) {
-      return makeErr() << layout << ".\nLayout has a total of "
-                       << layoutWarpsPerCTA
-                       << " warps per CTA, but the module specifies "
-                       << moduleWarpsPerCTA << " warps per CTA.";
-    }
-
-    if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
-      int moduleCTAsPerCGA = ttg::TritonGPUDialect::getNumCTAs(module);
-      int64_t layoutCTAsPerCGA =
-          product(blocked.getCTALayout().getCTAsPerCGA());
-      if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutCTAsPerCGA
-                         << " CTAs per CGA, but the module specifies "
-                         << moduleCTAsPerCGA << " CTAs per CGA.";
-      }
-    }
+  Dialect &dialect = layout.getDialect();
+  auto verifyLayoutInterface =
+      dyn_cast<mlir::triton::DialectVerifyTensorLayoutInterface>(&dialect);
+  if (verifyLayoutInterface) {
+    return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, module,
+                                                     makeErr);
   }
 
   return success();
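Note: `OpTrait::impl::verifyTensorLayouts` runs as part of an op trait so that every operation's tensor types are checked, while the attribute-specific business logic now lives behind the dialect interface. The trait wrapper is conventionally a thin shim over this free function, along these lines (a sketch only; the actual trait in Triton's headers may differ in name and placement):

// Conventional MLIR native-trait shim; illustrative, not Triton's exact code.
namespace mlir {
namespace OpTrait {
template <typename ConcreteType>
class VerifyTensorLayoutsTrait
    : public TraitBase<ConcreteType, VerifyTensorLayoutsTrait> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates to the free function, which dispatches to the
    // DialectVerifyTensorLayoutInterface of the layout's owning dialect.
    return impl::verifyTensorLayouts(op);
  }
};
} // namespace OpTrait
} // namespace mlir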

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 63 additions & 0 deletions
@@ -3009,6 +3009,68 @@ struct TritonGPUInferLayoutInterface
   }
 };
 
+struct TritonGPUVerifyTensorLayoutInterface
+    : public triton::DialectVerifyTensorLayoutInterface {
+  using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface;
+
+  LogicalResult verifyTensorLayout(
+      Attribute layout, RankedTensorType rankedTy, ModuleOp module,
+      function_ref<InFlightDiagnostic()> makeErr) const override {
+    if (isa<triton::gpu::SharedEncodingAttr>(layout))
+      return makeErr() << "Shared layout is not allowed on tensor type.";
+    // TODO(jlebar): Currently this only checks blocked layouts, but other
+    // layouts also have invariants!
+
+    // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
+    if (auto blocked = dyn_cast<triton::gpu::BlockedEncodingAttr>(layout)) {
+      // A different verifier should have checked that the layout itself is
+      // valid, including that threads-per-warp has the same rank as
+      // warps-per-block etc.
+      auto layoutRank = blocked.getThreadsPerWarp().size();
+      if (layoutRank != rankedTy.getRank()) {
+        return makeErr() << layout << ".\nLayout has rank " << layoutRank
+                         << ", but the tensor it's attached to has rank "
+                         << rankedTy.getRank() << ".";
+      }
+
+      int moduleThreadsPerWarp =
+          triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
+      int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
+      if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
+        return makeErr() << layout << ".\nLayout has a total of "
+                         << layoutThreadsPerWarp
+                         << " threads per warp, but the module specifies "
+                         << moduleThreadsPerWarp << " threads per warp.";
+      }
+
+      int moduleWarpsPerCTA =
+          triton::gpu::TritonGPUDialect::getNumWarps(module);
+      int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
+      if (layoutWarpsPerCTA != moduleWarpsPerCTA) {
+        return makeErr() << layout << ".\nLayout has a total of "
+                         << layoutWarpsPerCTA
+                         << " warps per CTA, but the module specifies "
+                         << moduleWarpsPerCTA << " warps per CTA.";
+      }
+
+      if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
+        int moduleCTAsPerCGA =
+            triton::gpu::TritonGPUDialect::getNumCTAs(module);
+        int64_t layoutCTAsPerCGA =
+            product(blocked.getCTALayout().getCTAsPerCGA());
+        if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
+          return makeErr() << layout << ".\nLayout has a total of "
+                           << layoutCTAsPerCGA
+                           << " CTAs per CGA, but the module specifies "
+                           << moduleCTAsPerCGA << " CTAs per CGA.";
+        }
+      }
+    }
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Canonicalizer
 //===----------------------------------------------------------------------===//
@@ -3748,6 +3810,7 @@ void TritonGPUDialect::initialize() {
   >();
   addInterfaces<TritonGPUOpAsmInterface>();
   addInterfaces<TritonGPUInferLayoutInterface>();
+  addInterfaces<TritonGPUVerifyTensorLayoutInterface>();
 
   RankedTensorType::attachInterface<TensorModel>(*getContext());
   MemDescType::attachInterface<MemDescModel>(*getContext());
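Note: because the check is now behind a dialect interface, any dialect that defines tensor layout attributes can register its own verifier the same way TritonGPU does above. A minimal sketch, assuming a hypothetical out-of-tree `MyDialect` with a `MyLayoutAttr` (the names and the invariant shown are illustrative only):

#include "triton/Dialect/Triton/IR/Dialect.h" // DialectVerifyTensorLayoutInterface

// Hypothetical out-of-tree dialect opting into the interface.
struct MyDialectVerifyTensorLayoutInterface
    : public mlir::triton::DialectVerifyTensorLayoutInterface {
  using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface;

  mlir::LogicalResult verifyTensorLayout(
      mlir::Attribute layout, mlir::RankedTensorType rankedTy,
      mlir::ModuleOp module,
      llvm::function_ref<mlir::InFlightDiagnostic()> makeErr) const override {
    // Illustrative invariant: MyLayoutAttr (hypothetical) must match the
    // tensor's rank.
    if (auto myLayout = llvm::dyn_cast<MyLayoutAttr>(layout)) {
      if (myLayout.getRank() != rankedTy.getRank())
        return makeErr() << "layout rank does not match tensor rank.";
    }
    return mlir::success();
  }
};

void MyDialect::initialize() {
  // Registration mirrors TritonGPUDialect::initialize above.
  addInterfaces<MyDialectVerifyTensorLayoutInterface>();
}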
