@@ -143,10 +143,8 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
   // Return the order that represents that the batch is in row-major or
   // column-major order for a batch of matrices of shape [*, m, n] with
   // len(shape) == rank.
+  assert(rank >= 2);
   SmallVector<unsigned> order(rank);
-  if (rank < 2) {
-    return order;
-  }
   std::iota(order.rbegin(), order.rend(), 0);
   if (!rowMajor) {
     std::swap(order[0], order[1]);
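For reference, a minimal standalone sketch of the ordering this helper produces (plain std::vector and <numeric> in place of llvm::SmallVector; the name matrixOrderSketch is purely illustrative). For rank == 4 it yields 3 2 1 0 in the row-major case and 2 3 1 0 in the column-major case, and, like the new assert, it rejects rank < 2 outright instead of returning a zero-filled order.

#include <cassert>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// Mirrors getMatrixOrder: dimensions listed from fastest- to slowest-varying,
// with the last two swapped for a column-major batch of matrices.
std::vector<unsigned> matrixOrderSketch(unsigned rank, bool rowMajor) {
  assert(rank >= 2); // rank-0/1 shapes are no longer silently accepted
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0u); // order = [rank-1, ..., 1, 0]
  if (!rowMajor)
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  for (unsigned d : matrixOrderSketch(4, /*rowMajor=*/true))
    std::printf("%u ", d); // 3 2 1 0
  std::printf("\n");
  for (unsigned d : matrixOrderSketch(4, /*rowMajor=*/false))
    std::printf("%u ", d); // 2 3 1 0
  std::printf("\n");
}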
@@ -399,21 +397,6 @@ BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
                           "order must all have the same rank.";
   }
-  if (llvm::any_of(sizePerThread,
-                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
-    return emitError()
-           << "Every element in sizePerThread must be a power of two.";
-  }
-  if (llvm::any_of(threadsPerWarp,
-                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
-    return emitError()
-           << "Every element in threadsPerWarp must be a power of two.";
-  }
-  if (llvm::any_of(warpsPerCTA,
-                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
-    return emitError()
-           << "Every element in warpsPerCTA must be a power of two.";
-  }
 
   // Empty CTALayout is allowed, but if it's present its rank must match the
   // BlockedEncodingAttr's rank.
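The deleted clauses all encode the same requirement: every per-dimension factor of the layout must be a power of two. A minimal standalone sketch of that predicate, using plain C++ in place of llvm::any_of and llvm::isPowerOf2_64 (the helper names here are illustrative, not from the source):

#include <algorithm>
#include <cstdint>
#include <vector>

// Same test as llvm::isPowerOf2_64: nonzero and exactly one bit set.
static bool isPow2(uint64_t x) { return x != 0 && (x & (x - 1)) == 0; }

// The removed verifier clauses rejected a BlockedEncodingAttr if any element
// of sizePerThread, threadsPerWarp, or warpsPerCTA failed this check.
static bool allPow2(const std::vector<unsigned> &dims) {
  return std::all_of(dims.begin(), dims.end(),
                     [](unsigned x) { return isPow2(x); });
}

int main() {
  std::vector<unsigned> threadsPerWarp = {8, 4}; // 8 * 4 = 32 lanes, all pow2
  std::vector<unsigned> sizePerThread = {1, 6};  // 6 is not a power of two
  return (allPow2(threadsPerWarp) && !allPow2(sizePerThread)) ? 0 : 1;
}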
@@ -2013,8 +1996,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
 SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
   if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
     return mma.getRepOrderForOperand(getOpIdx());
-  } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
-    return to_vector(blocked.getOrder());
   }
   llvm::report_fatal_error(
       "getRepOrder not implemented for DotOperandEncodingAttr");
@@ -2715,56 +2696,60 @@ struct TritonGPUVerifyTensorLayoutInterface
   LogicalResult verifyTensorLayout(
       Attribute layout, RankedTensorType rankedTy, Operation *op,
       function_ref<InFlightDiagnostic()> makeErr) const override {
-    auto distr = dyn_cast<triton::gpu::DistributedEncodingTrait>(layout);
-    if (!distr)
-      return makeErr()
-             << "Non-distributed layout is not allowed in tensor type.";
-    auto rank = distr.getRepOrder().size();
-    if (rank != rankedTy.getRank())
-      return makeErr() << "Layout has rank " << rank
-                       << ", but the tensor it's attached to has rank "
-                       << rankedTy.getRank() << ".";
-    if (llvm::any_of(rankedTy.getShape(),
-                     [](int64_t i) { return !llvm::isPowerOf2_64(i); })) {
-      return makeErr() << "Layout has shape " << rankedTy.getShape()
-                       << ", but the tensor it's attached to has shape "
-                       << rankedTy.getShape()
-                       << " which is not a power of two.";
-    }
-    auto ll = toLinearLayout(rankedTy.getShape(), layout);
-    ModuleOp module = op->getParentOfType<ModuleOp>();
-
-    // Number of threads per warp.
-    auto kLane = StringAttr::get(module.getContext(), "lane");
-    int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
-    if (ll.getInDimSize(kLane) != moduleThreadsPerWarp) {
-      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
-                       << " threads per warp, but the module specifies "
-                       << moduleThreadsPerWarp << " threads per warp.";
-    }
-
-    // Number of warps per CTA.
-    std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
-    if (!moduleWarpsPerCTA) {
-      return makeErr()
-             << "Could not determine the number of warps per CTA. Operation "
-                "is not in a context with `ttg.num-warps`.";
-    }
-    auto kWarp = StringAttr::get(module.getContext(), "warp");
-    if (ll.getInDimSize(kWarp) != *moduleWarpsPerCTA) {
-      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp)
-                       << " warps per CTA, but the context requires "
-                       << *moduleWarpsPerCTA << " warps per CTA.";
-    }
-
-    // Number of CTAs per CGA.
-    auto kBlock = StringAttr::get(module.getContext(), "block");
-    int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
-    if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) {
-      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock)
-                       << " CTAs per CGA, but the context requires "
-                       << moduleCTAsPerCGA << " CTAs per CGA.";
+    if (isa<triton::gpu::SharedEncodingTrait>(layout))
+      return makeErr() << "Shared layout is not allowed on tensor type.";
+    // TODO(jlebar): Currently this only checks blocked layouts, but other
+    // layouts also have invariants!
+
+    // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
+    if (auto blocked = dyn_cast<BlockedEncodingAttr>(layout)) {
+      ModuleOp module = op->getParentOfType<ModuleOp>();
+
+      // A different verifier should have checked that the layout itself is
+      // valid, including that threads-per-warp has the same rank as
+      // warps-per-block etc.
+      if (blocked.getRank() != rankedTy.getRank()) {
+        return makeErr() << layout << ".\nLayout has rank " << blocked.getRank()
+                         << ", but the tensor it's attached to has rank "
+                         << rankedTy.getRank() << ".";
+      }
+
+      int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
+      int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
+      if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
+        return makeErr() << layout << ".\nLayout has a total of "
+                         << layoutThreadsPerWarp
+                         << " threads per warp, but the module specifies "
+                         << moduleThreadsPerWarp << " threads per warp.";
+      }
+
+      std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
+      if (!moduleWarpsPerCTA) {
+        return makeErr()
+               << "Could not determine the number of warps per CTA. Operation "
+                  "is not in a context with `ttg.num-warps`.";
+      }
+      int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
+      if (layoutWarpsPerCTA != *moduleWarpsPerCTA) {
+        return makeErr() << layout << ".\nLayout has a total of "
+                         << layoutWarpsPerCTA
+                         << " warps per CTA, but the context requires "
+                         << *moduleWarpsPerCTA << " warps per CTA.";
+      }
+
+      if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
+        int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
+        int64_t layoutCTAsPerCGA =
+            product(blocked.getCTALayout().getCTAsPerCGA());
+        if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
+          return makeErr() << layout << ".\nLayout has a total of "
+                           << layoutCTAsPerCGA
+                           << " CTAs per CGA, but the module specifies "
+                           << moduleCTAsPerCGA << " CTAs per CGA.";
+        }
+      }
     }
+
     return success();
   }
 };
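The replacement verifier reduces to a product check per level of the hierarchy: the blocked layout's per-dimension factors must multiply out to the module-level threads-per-warp, warps-per-CTA, and CTAs-per-CGA counts. A standalone sketch of that arithmetic (plain C++ with a hypothetical BlockedSketch struct standing in for BlockedEncodingAttr and the module attributes):

#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Multiply the per-dimension factors, as product() does in the diff above.
static int64_t product(const std::vector<unsigned> &v) {
  return std::accumulate(v.begin(), v.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Hypothetical stand-in for the relevant BlockedEncodingAttr fields.
struct BlockedSketch {
  std::vector<unsigned> threadsPerWarp;
  std::vector<unsigned> warpsPerCTA;
  std::vector<unsigned> ctasPerCGA; // may be empty (no CTA layout)
};

static bool verifySketch(const BlockedSketch &b, int moduleThreadsPerWarp,
                         int moduleWarpsPerCTA, int moduleCTAsPerCGA) {
  if (product(b.threadsPerWarp) != moduleThreadsPerWarp)
    return false;
  if (product(b.warpsPerCTA) != moduleWarpsPerCTA)
    return false;
  if (!b.ctasPerCGA.empty() && product(b.ctasPerCGA) != moduleCTAsPerCGA)
    return false;
  return true;
}

int main() {
  // 8*4 = 32 threads/warp, 2*2 = 4 warps/CTA, 1*1 = 1 CTA/CGA.
  BlockedSketch ok{{8, 4}, {2, 2}, {1, 1}};
  std::printf("%s\n", verifySketch(ok, 32, 4, 1) ? "valid" : "invalid");
}

With threadsPerWarp = {8, 4}, warpsPerCTA = {2, 2}, and CTAsPerCGA = {1, 1}, the products are 32, 4, and 1, which would match a module whose threads-per-warp, `ttg.num-warps`, and num-ctas settings are 32, 4, and 1 respectively.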