@@ -143,8 +143,10 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
   // Return the order that represents that the batch is in row-major or
   // column-major order for a batch of matrices of shape [*, m, n] with
   // len(shape) == rank.
-  assert(rank >= 2);
   SmallVector<unsigned> order(rank);
+  if (rank < 2) {
+    return order;
+  }
   std::iota(order.rbegin(), order.rend(), 0);
   if (!rowMajor) {
     std::swap(order[0], order[1]);
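For context, here is a minimal standalone sketch of the relaxed getMatrixOrder behavior: ranks below 2 now yield a trivially valid order instead of tripping the removed assert. The helper name and the driver are illustrative and not part of the patch.

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// Mirrors the patched logic above, with std::vector standing in for SmallVector.
std::vector<unsigned> getMatrixOrderSketch(unsigned rank, bool rowMajor) {
  std::vector<unsigned> order(rank); // zero-initialized
  if (rank < 2)
    return order; // rank 0 -> {}, rank 1 -> {0}; previously this asserted
  std::iota(order.rbegin(), order.rend(), 0); // {rank-1, ..., 1, 0}
  if (!rowMajor)
    std::swap(order[0], order[1]); // column-major swaps the two minor dims
  return order;
}

int main() {
  // Expected output: "2 1 0", "1 2 0", "0"
  const std::pair<unsigned, bool> cases[] = {{3, true}, {3, false}, {1, true}};
  for (const auto &[rank, rowMajor] : cases) {
    for (unsigned d : getMatrixOrderSketch(rank, rowMajor))
      std::printf("%u ", d);
    std::printf("\n");
  }
}
```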
@@ -397,6 +399,21 @@ BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
                           "order must all have the same rank.";
   }
+  if (llvm::any_of(sizePerThread,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in sizePerThread must be a power of two.";
+  }
+  if (llvm::any_of(threadsPerWarp,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in threadsPerWarp must be a power of two.";
+  }
+  if (llvm::any_of(warpsPerCTA,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in warpsPerCTA must be a power of two.";
+  }
 
   // Empty CTALayout is allowed, but if it's present its rank must match the
   // BlockedEncodingAttr's rank.
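A quick sanity check on the new predicate (a standalone sketch using llvm::isPowerOf2_64 from llvm/Support/MathExtras.h; the test driver is not part of the patch): an entry of 1 still passes, while 0 and non-powers of two are rejected.

```cpp
#include "llvm/Support/MathExtras.h"
#include <cassert>

int main() {
  // isPowerOf2_64 is true only for values with exactly one set bit.
  assert(llvm::isPowerOf2_64(1));  // 2^0: an entry of 1 remains legal
  assert(llvm::isPowerOf2_64(16)); // 2^4
  assert(!llvm::isPowerOf2_64(0)); // zero fails verification
  assert(!llvm::isPowerOf2_64(3)); // non-powers of two now fail verification
}
```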
@@ -1996,6 +2013,8 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
 SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
   if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
     return mma.getRepOrderForOperand(getOpIdx());
+  } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
+    return to_vector(blocked.getOrder());
   }
   llvm::report_fatal_error(
       "getRepOrder not implemented for DotOperandEncodingAttr");
@@ -2696,60 +2715,56 @@ struct TritonGPUVerifyTensorLayoutInterface
   LogicalResult verifyTensorLayout(
       Attribute layout, RankedTensorType rankedTy, Operation *op,
       function_ref<InFlightDiagnostic()> makeErr) const override {
-    if (isa<triton::gpu::SharedEncodingTrait>(layout))
-      return makeErr() << "Shared layout is not allowed on tensor type.";
-    // TODO(jlebar): Currently this only checks blocked layouts, but other
-    // layouts also have invariants!
-
-    // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
-    if (auto blocked = dyn_cast<BlockedEncodingAttr>(layout)) {
-      ModuleOp module = op->getParentOfType<ModuleOp>();
-
-      // A different verifier should have checked that the layout itself is
-      // valid, including that threads-per-warp has the same rank as
-      // warps-per-block etc.
-      if (blocked.getRank() != rankedTy.getRank()) {
-        return makeErr() << layout << ".\nLayout has rank " << blocked.getRank()
-                         << ", but the tensor it's attached to has rank "
-                         << rankedTy.getRank() << ".";
-      }
-
-      int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
-      int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
-      if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutThreadsPerWarp
-                         << " threads per warp, but the module specifies "
-                         << moduleThreadsPerWarp << " threads per warp.";
-      }
-
-      std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
-      if (!moduleWarpsPerCTA) {
-        return makeErr()
-               << "Could not determine the number of warps per CTA. Operation "
-                  "is not in a context with `ttg.num-warps`.";
-      }
-      int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
-      if (layoutWarpsPerCTA != *moduleWarpsPerCTA) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutWarpsPerCTA
-                         << " warps per CTA, but the context requires "
-                         << *moduleWarpsPerCTA << " warps per CTA.";
-      }
-
-      if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
-        int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
-        int64_t layoutCTAsPerCGA =
-            product(blocked.getCTALayout().getCTAsPerCGA());
-        if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
-          return makeErr() << layout << ".\nLayout has a total of "
-                           << layoutCTAsPerCGA
-                           << " CTAs per CGA, but the module specifies "
-                           << moduleCTAsPerCGA << " CTAs per CGA.";
-        }
-      }
+    auto distr = dyn_cast<triton::gpu::DistributedEncodingTrait>(layout);
+    if (!distr)
+      return makeErr()
+             << "Non-distributed layout is not allowed in tensor type.";
+    auto rank = distr.getRepOrder().size();
+    if (rank != rankedTy.getRank())
+      return makeErr() << "Layout has rank " << rank
+                       << ", but the tensor it's attached to has rank "
+                       << rankedTy.getRank() << ".";
+    if (llvm::any_of(rankedTy.getShape(),
+                     [](int64_t i) { return !llvm::isPowerOf2_64(i); })) {
+      return makeErr() << "Layout has shape " << rankedTy.getShape()
+                       << ", but the tensor it's attached to has shape "
+                       << rankedTy.getShape()
+                       << " which is not a power of two.";
+    }
+    auto ll = toLinearLayout(rankedTy.getShape(), layout);
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+
+    // Number of threads per warp.
+    auto kLane = StringAttr::get(module.getContext(), "lane");
+    int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
+    if (ll.getInDimSize(kLane) != moduleThreadsPerWarp) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
+                       << " threads per warp, but the module specifies "
+                       << moduleThreadsPerWarp << " threads per warp.";
+    }
+
+    // Number of warps per CTA.
+    std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
+    if (!moduleWarpsPerCTA) {
+      return makeErr()
+             << "Could not determine the number of warps per CTA. Operation "
+                "is not in a context with `ttg.num-warps`.";
+    }
+    auto kWarp = StringAttr::get(module.getContext(), "warp");
+    if (ll.getInDimSize(kWarp) != *moduleWarpsPerCTA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp)
+                       << " warps per CTA, but the context requires "
+                       << *moduleWarpsPerCTA << " warps per CTA.";
+    }
+
+    // Number of CTAs per CGA.
+    auto kBlock = StringAttr::get(module.getContext(), "block");
+    int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
+    if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock)
+                       << " CTAs per CGA, but the context requires "
+                       << moduleCTAsPerCGA << " CTAs per CGA.";
     }
-
     return success();
   }
 };
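Taken together, the verifier no longer special-cases BlockedEncodingAttr: any distributed layout is converted to a linear layout, and its "lane", "warp", and "block" input dimensions are compared against the module-level attributes. Below is a library-free paraphrase of that comparison; the std::map stands in for the linear layout's input-dimension sizes, and the helper names are hypothetical, not Triton's API.

```cpp
#include <cstdio>
#include <map>
#include <optional>
#include <string>

// Hypothetical stand-in: inDimSizes mimics what the verifier reads via
// LinearLayout::getInDimSize for the "lane"/"warp"/"block" dimensions.
std::optional<std::string>
checkLayoutSketch(const std::map<std::string, int> &inDimSizes,
                  int moduleThreadsPerWarp, int moduleWarpsPerCTA,
                  int moduleCTAsPerCGA) {
  auto expect = [&](const std::string &dim, int want,
                    const char *what) -> std::optional<std::string> {
    int got = inDimSizes.count(dim) ? inDimSizes.at(dim) : 1;
    if (got != want)
      return "layout has " + std::to_string(got) + " " + what +
             ", but the module requires " + std::to_string(want);
    return std::nullopt;
  };
  if (auto err = expect("lane", moduleThreadsPerWarp, "threads per warp"))
    return err;
  if (auto err = expect("warp", moduleWarpsPerCTA, "warps per CTA"))
    return err;
  if (auto err = expect("block", moduleCTAsPerCGA, "CTAs per CGA"))
    return err;
  return std::nullopt;
}

int main() {
  // A layout with 32 lanes, 4 warps, 1 block in a module compiled for 32x4x1.
  std::map<std::string, int> dims{{"lane", 32}, {"warp", 4}, {"block", 1}};
  if (auto err = checkLayoutSketch(dims, 32, 4, 1))
    std::printf("error: %s\n", err->c_str());
  else
    std::printf("layout is consistent with the module\n");
}
```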