@@ -140,8 +140,10 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
   // Return the order that represents that the batch is in row-major or
   // column-major order for a batch of matrices of shape [*, m, n] with
   // len(shape) == rank.
-  assert(rank >= 2);
   SmallVector<unsigned> order(rank);
+  if (rank < 2) {
+    return order;
+  }
   std::iota(order.rbegin(), order.rend(), 0);
   if (!rowMajor) {
     std::swap(order[0], order[1]);
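
For reference, a minimal standalone sketch of the relaxed behavior, with std::vector standing in for SmallVector; this getMatrixOrder is a copy for illustration only, not the dialect code:

#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// Stand-in for the patched getMatrixOrder: rank < 2 now yields a
// zero-filled (possibly empty) order instead of tripping an assert.
std::vector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
  std::vector<unsigned> order(rank);
  if (rank < 2)
    return order; // {} for rank 0, {0} for rank 1
  std::iota(order.rbegin(), order.rend(), 0); // {rank-1, ..., 1, 0}
  if (!rowMajor)
    std::swap(order[0], order[1]); // column-major within each [m, n] matrix
  return order;
}

int main() {
  auto print = [](const std::vector<unsigned> &v) {
    for (unsigned x : v)
      std::printf("%u ", x);
    std::printf("\n");
  };
  print(getMatrixOrder(3, /*rowMajor=*/true));  // 2 1 0
  print(getMatrixOrder(3, /*rowMajor=*/false)); // 1 2 0
  print(getMatrixOrder(1, /*rowMajor=*/true));  // 0  (previously asserted)
}
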
@@ -394,6 +396,21 @@ BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
                           "order must all have the same rank.";
   }
+  if (llvm::any_of(sizePerThread,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in sizePerThread must be a power of two.";
+  }
+  if (llvm::any_of(threadsPerWarp,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in threadsPerWarp must be a power of two.";
+  }
+  if (llvm::any_of(warpsPerCTA,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in warpsPerCTA must be a power of two.";
+  }
 
   // Empty CTALayout is allowed, but if it's present its rank must match the
   // BlockedEncodingAttr's rank.
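
The three new checks all apply the same llvm::any_of / llvm::isPowerOf2_64 pattern. A self-contained sketch of the predicate they enforce; isPow2 and allPow2 are illustrative helpers, not Triton or LLVM APIs:

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors llvm::isPowerOf2_64: true only for powers of two greater than zero.
static bool isPow2(uint64_t x) { return x != 0 && (x & (x - 1)) == 0; }

// The new BlockedEncodingAttr verifier requires every element of
// sizePerThread, threadsPerWarp, and warpsPerCTA to satisfy this predicate.
static bool allPow2(const std::vector<unsigned> &dims) {
  for (unsigned d : dims)
    if (!isPow2(d))
      return false;
  return true;
}

int main() {
  std::vector<unsigned> ok{4, 1};  // e.g. sizePerThread = [4, 1]: accepted
  std::vector<unsigned> bad{3, 1}; // 3 is not a power of two: now rejected
  std::printf("%d %d\n", allPow2(ok), allPow2(bad)); // prints "1 0"
}
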
@@ -1963,6 +1980,8 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
 SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
   if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
     return mma.getRepOrderForOperand(getOpIdx());
+  } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
+    return to_vector(blocked.getOrder());
   }
   llvm::report_fatal_error(
       "getRepOrder not implemented for DotOperandEncodingAttr");
@@ -2660,60 +2679,56 @@ struct TritonGPUVerifyTensorLayoutInterface
   LogicalResult verifyTensorLayout(
       Attribute layout, RankedTensorType rankedTy, Operation *op,
       function_ref<InFlightDiagnostic()> makeErr) const override {
-    if (isa<triton::gpu::SharedEncodingTrait>(layout))
-      return makeErr() << "Shared layout is not allowed on tensor type.";
-    // TODO(jlebar): Currently this only checks blocked layouts, but other
-    // layouts also have invariants!
-
-    // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
-    if (auto blocked = dyn_cast<BlockedEncodingAttr>(layout)) {
-      ModuleOp module = op->getParentOfType<ModuleOp>();
-
-      // A different verifier should have checked that the layout itself is
-      // valid, including that threads-per-warp has the same rank as
-      // warps-per-block etc.
-      if (blocked.getRank() != rankedTy.getRank()) {
-        return makeErr() << layout << ".\nLayout has rank " << blocked.getRank()
-                         << ", but the tensor it's attached to has rank "
-                         << rankedTy.getRank() << ".";
-      }
-
-      int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
-      int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
-      if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutThreadsPerWarp
-                         << " threads per warp, but the module specifies "
-                         << moduleThreadsPerWarp << " threads per warp.";
-      }
-
-      std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
-      if (!moduleWarpsPerCTA) {
-        return makeErr()
-               << "Could not determine the number of warps per CTA. Operation "
-                  "is not in a context with `ttg.num-warps`.";
-      }
-      int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
-      if (layoutWarpsPerCTA != *moduleWarpsPerCTA) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutWarpsPerCTA
-                         << " warps per CTA, but the context requires "
-                         << *moduleWarpsPerCTA << " warps per CTA.";
-      }
-
-      if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
-        int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
-        int64_t layoutCTAsPerCGA =
-            product(blocked.getCTALayout().getCTAsPerCGA());
-        if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
-          return makeErr() << layout << ".\nLayout has a total of "
-                           << layoutCTAsPerCGA
-                           << " CTAs per CGA, but the module specifies "
-                           << moduleCTAsPerCGA << " CTAs per CGA.";
-        }
-      }
+    auto distr = dyn_cast<triton::gpu::DistributedEncodingTrait>(layout);
+    if (!distr)
+      return makeErr()
+             << "Non-distributed layout is not allowed in tensor type.";
+    auto rank = distr.getRepOrder().size();
+    if (rank != rankedTy.getRank())
+      return makeErr() << "Layout has rank " << rank
+                       << ", but the tensor it's attached to has rank "
+                       << rankedTy.getRank() << ".";
+    if (llvm::any_of(rankedTy.getShape(),
+                     [](int64_t i) { return !llvm::isPowerOf2_64(i); })) {
+      return makeErr() << "Layout has shape " << rankedTy.getShape()
+                       << ", but the tensor it's attached to has shape "
+                       << rankedTy.getShape()
+                       << " which is not a power of two.";
+    }
+    auto ll = toLinearLayout(rankedTy.getShape(), layout);
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+
+    // Number of threads per warp.
+    auto kLane = StringAttr::get(module.getContext(), "lane");
+    int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
+    if (ll.getInDimSize(kLane) != moduleThreadsPerWarp) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
+                       << " threads per warp, but the module specifies "
+                       << moduleThreadsPerWarp << " threads per warp.";
+    }
+
+    // Number of warps per CTA.
+    std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
+    if (!moduleWarpsPerCTA) {
+      return makeErr()
+             << "Could not determine the number of warps per CTA. Operation "
+                "is not in a context with `ttg.num-warps`.";
+    }
+    auto kWarp = StringAttr::get(module.getContext(), "warp");
+    if (ll.getInDimSize(kWarp) != *moduleWarpsPerCTA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp)
+                       << " warps per CTA, but the context requires "
+                       << *moduleWarpsPerCTA << " warps per CTA.";
+    }
+
+    // Number of CTAs per CGA.
+    auto kBlock = StringAttr::get(module.getContext(), "block");
+    int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
+    if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock)
+                       << " CTAs per CGA, but the context requires "
+                       << moduleCTAsPerCGA << " CTAs per CGA.";
     }
-
     return success();
   }
 };
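
The rewritten verifier no longer special-cases BlockedEncodingAttr: any distributed layout is converted to a linear layout, and the sizes of its "lane", "warp", and "block" input dimensions are compared against the module's settings. A schematic, self-contained sketch of that control flow; FakeLinearLayout, ModuleInfo, and verify are illustrative stand-ins, not the MLIR/Triton types:

#include <cstdio>
#include <map>
#include <string>

// A "linear layout" is summarized here only by the size of each hardware
// input dimension; the real toLinearLayout result carries much more.
struct FakeLinearLayout {
  std::map<std::string, int> inDimSize;
};

struct ModuleInfo {
  int threadsPerWarp; // cf. TritonGPUDialect::getThreadsPerWarp(module)
  int warpsPerCTA;    // cf. maybeLookupNumWarps(op) / `ttg.num-warps`
  int ctasPerCGA;     // cf. TritonGPUDialect::getNumCTAs(module)
};

// Same shape as the checks above: each hardware dimension of the layout
// must match the count the surrounding module declares.
static const char *verify(const FakeLinearLayout &ll, const ModuleInfo &mod) {
  if (ll.inDimSize.at("lane") != mod.threadsPerWarp)
    return "threads-per-warp mismatch";
  if (ll.inDimSize.at("warp") != mod.warpsPerCTA)
    return "warps-per-CTA mismatch";
  if (ll.inDimSize.at("block") != mod.ctasPerCGA)
    return "CTAs-per-CGA mismatch";
  return "ok";
}

int main() {
  FakeLinearLayout ll{{{"lane", 32}, {"warp", 4}, {"block", 1}}};
  ModuleInfo mod{32, 4, 1};
  std::printf("%s\n", verify(ll, mod)); // ok
  mod.warpsPerCTA = 8;
  std::printf("%s\n", verify(ll, mod)); // warps-per-CTA mismatch
}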