@@ -3058,6 +3058,68 @@ struct TritonGPUInferLayoutInterface
   }
 };
 
+struct TritonGPUVerifyTensorLayoutInterface
+    : public triton::DialectVerifyTensorLayoutInterface {
+  using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface;
+
+  LogicalResult verifyTensorLayout(
+      Attribute layout, RankedTensorType rankedTy, ModuleOp module,
+      function_ref<InFlightDiagnostic()> makeErr) const override {
+    if (isa<triton::gpu::SharedEncodingAttr>(layout))
+      return makeErr() << "Shared layout is not allowed on tensor type.";
+    // TODO(jlebar): Currently this only checks blocked layouts, but other
+    // layouts also have invariants!
+
+    // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
+    if (auto blocked = dyn_cast<triton::gpu::BlockedEncodingAttr>(layout)) {
+      // A different verifier should have checked that the layout itself is
+      // valid, including that threads-per-warp has the same rank as
+      // warps-per-block etc.
+      auto layoutRank = blocked.getThreadsPerWarp().size();
+      if (layoutRank != rankedTy.getRank()) {
+        return makeErr() << layout << ".\nLayout has rank " << layoutRank
+                         << ", but the tensor it's attached to has rank "
+                         << rankedTy.getRank() << ".";
+      }
+
+      int moduleThreadsPerWarp =
+          triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
+      int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
+      if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
+        return makeErr() << layout << ".\nLayout has a total of "
+                         << layoutThreadsPerWarp
+                         << " threads per warp, but the module specifies "
+                         << moduleThreadsPerWarp << " threads per warp.";
+      }
+
+      int moduleWarpsPerCTA =
+          triton::gpu::TritonGPUDialect::getNumWarps(module);
+      int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
+      if (layoutWarpsPerCTA != moduleWarpsPerCTA) {
+        return makeErr() << layout << ".\nLayout has a total of "
+                         << layoutWarpsPerCTA
+                         << " warps per CTA, but the module specifies "
+                         << moduleWarpsPerCTA << " warps per CTA.";
+      }
+
+      if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
+        int moduleCTAsPerCGA =
+            triton::gpu::TritonGPUDialect::getNumCTAs(module);
+        int64_t layoutCTAsPerCGA =
+            product(blocked.getCTALayout().getCTAsPerCGA());
+        if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
+          return makeErr() << layout << ".\nLayout has a total of "
+                           << layoutCTAsPerCGA
+                           << " CTAs per CGA, but the module specifies "
+                           << moduleCTAsPerCGA << " CTAs per CGA.";
+        }
+      }
+    }
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Canonicalizer
 //===----------------------------------------------------------------------===//
@@ -3798,6 +3860,7 @@ void TritonGPUDialect::initialize() {
   >();
   addInterfaces<TritonGPUOpAsmInterface>();
   addInterfaces<TritonGPUInferLayoutInterface>();
+  addInterfaces<TritonGPUVerifyTensorLayoutInterface>();
 
   RankedTensorType::attachInterface<TensorModel>(*getContext());
   MemDescType::attachInterface<MemDescModel>(*getContext());