 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -237,8 +238,31 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }

+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
+                                            bool kMajor) {
+  // kMajor: if true, the matrix is fastest-running on k,
+  //         otherwise it is on m (resp. n)
+  // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
+  // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
+  // batch (if rank == 3) is always the slowest running dimension
+  assert(rank == 2 || rank == 3);
+  assert(opIdx == 0 || opIdx == 1);
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  // If opIdx is 1 and kMajor is true, the order is [0, 1]
+  // (resp. [1, 2, 0] if rank == 3)
+  // Same if opIdx is 0 and kMajor is false
+  if (bool(opIdx) == kMajor) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
   auto order = getOrder(layout);
+  // FIXME: This mmaLayout if should just return
+  // getOrderForDotOperand(0, order.size(), kMajor=false)
+  // as mma has the same order as DotOperand(opIdx=0)
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
       // Hopper MMA instructions force a warp order of [0, 1]. See docs:
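Aside (reviewer note, not part of the patch): a minimal standalone sketch of the new helper's ordering logic, to make the four (opIdx, kMajor) cases concrete. The name orderForDotOperand is a local stand-in for the helper above, and only the C++ standard library is used.

    // Stand-alone sketch of the helper above; names are illustrative only.
    #include <cassert>
    #include <cstdio>
    #include <numeric>
    #include <utility>
    #include <vector>

    std::vector<unsigned> orderForDotOperand(unsigned opIdx, unsigned rank,
                                             bool kMajor) {
      assert((rank == 2 || rank == 3) && (opIdx == 0 || opIdx == 1));
      std::vector<unsigned> order(rank);
      // Fill back-to-front: order starts as [rank-1, ..., 1, 0], i.e. the
      // innermost (fastest-running) dimension comes first.
      std::iota(order.rbegin(), order.rend(), 0);
      // Swapping the two fastest dimensions flips between k-major and
      // m/n-major.
      if (bool(opIdx) == kMajor)
        std::swap(order[0], order[1]);
      return order;
    }

    int main() {
      for (unsigned opIdx : {0u, 1u})
        for (bool kMajor : {false, true}) {
          auto o = orderForDotOperand(opIdx, /*rank=*/2, kMajor);
          // Prints: opIdx=0 kMajor=0 -> [0, 1]; opIdx=0 kMajor=1 -> [1, 0]
          //         opIdx=1 kMajor=0 -> [1, 0]; opIdx=1 kMajor=1 -> [0, 1]
          std::printf("opIdx=%u kMajor=%d -> [%u, %u]\n", opIdx, kMajor,
                      o[0], o[1]);
        }
    }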
@@ -247,30 +271,9 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       order.erase(it);
       order.insert(order.begin(), 0);
     }
-  }
-  return order;
-}
-
-SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
-  SmallVector<unsigned> order(rank);
-  // The 'order' field typically represents a descending sorted array of
-  // dimensions based on contiguity. For instance, in axisInfo utilities that
-  // retrieve tensor contiguity, it's assumed that the dimension with the
-  // highest contiguity corresponds to order[0].
-  //
-  // The relation between contiguity and order is only relevant if the layout
-  // interfaces with HBM, as is the case when we load tensor from HBM to
-  // registers in the dot layout to bypass LDS. When bypassing LDS, we make the
-  // following assumptions about tensor layouts:
-  // - Tensor A (opIdx == 0) is considered to be row-major.
-  // - Tensor B (opIdx == 1) is considered to be column-major.
-  //
-  // Based on these assumptions, we define the following orders:
-  // - For opIdx == 0, we assume an order of [1, 0].
-  // - For opIdx == 1, we assume an order of [0, 1].
-  std::iota(order.rbegin(), order.rend(), 0);
-  if (opIdx == 1) {
-    std::swap(order[0], order[1]);
+  } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
+    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
+                                  /*kMajor*/ false);
   }
   return order;
 }
@@ -287,13 +290,12 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
-    SmallVector<unsigned> order(rank);
+    auto rank = dotLayout.getWarpsPerCTA().size();
     if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-      return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
-    } else {
-      std::iota(order.rbegin(), order.rend(), 0);
+      return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
     }
+    SmallVector<unsigned> order(rank);
+    std::iota(order.rbegin(), order.rend(), 0);
     return order;
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -1059,7 +1061,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                               /*kMajor*/ true);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {
@@ -2042,6 +2045,7 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
                                                         int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
+
   SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3
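Net effect, as I read the diff: for dot operands, getWarpOrder now asks for kMajor=false, while getThreadOrder (and getOrder when the parent is AMDMfmaEncodingAttr) asks for kMajor=true. Continuing the illustrative sketch from above for a rank-2 opIdx=0 operand:

    // Continuing the sketch: the two flag choices made by this commit.
    auto warpOrder = orderForDotOperand(/*opIdx=*/0, /*rank=*/2,
                                        /*kMajor=*/false); // -> [0, 1]
    auto threadOrder = orderForDotOperand(/*opIdx=*/0, /*rank=*/2,
                                          /*kMajor=*/true); // -> [1, 0]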