 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
-#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -238,31 +237,8 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }
 
-SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
-                                            bool kMajor) {
-  // kMajor: if true, the matrix is fastest-running on k,
-  //         otherwise it is on m (resp. n)
-  // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
-  // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
-  // batch (if rank == 3) is always the slowest running dimension
-  assert(rank == 2 || rank == 3);
-  assert(opIdx == 0 || opIdx == 1);
-  SmallVector<unsigned> order(rank);
-  std::iota(order.rbegin(), order.rend(), 0);
-  // If opIdx is 1 and kMajor is true, the order is [0, 1]
-  // (resp. [1, 2, 0] if rank == 3)
-  // Same if opIdx is 0 and kMajor is false
-  if (bool(opIdx) == kMajor) {
-    std::swap(order[0], order[1]);
-  }
-  return order;
-}
-
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
   auto order = getOrder(layout);
-  // FIXME: This mmaLayout if should just return
-  // getOrderForDotOperand(0, order.size(), kMajor=false)
-  // as mma has the same order as DotOperand(opIdx=0)
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
       // Hopper MMA instructions force a warp order of [0, 1]. See docs:
@@ -271,9 +247,30 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       order.erase(it);
       order.insert(order.begin(), 0);
     }
-  } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
-                                  /*kMajor*/ false);
+  }
+  return order;
+}
+
+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
+  SmallVector<unsigned> order(rank);
+  // The 'order' field typically represents a descending sorted array of
+  // dimensions based on contiguity. For instance, in axisInfo utilities that
+  // retrieve tensor contiguity, it's assumed that the dimension with the
+  // highest contiguity corresponds to order[0].
+  //
+  // The relation between contiguity and order is only relevant if the layout
+  // interfaces with HBM, as is the case when we load a tensor from HBM to
+  // registers in the dot layout to bypass LDS. When bypassing LDS, we make the
+  // following assumptions about tensor layouts:
+  // - Tensor A (opIdx == 0) is considered to be row-major.
+  // - Tensor B (opIdx == 1) is considered to be column-major.
+  //
+  // Based on these assumptions, we define the following orders:
+  // - For opIdx == 0, we assume an order of [1, 0].
+  // - For opIdx == 1, we assume an order of [0, 1].
+  std::iota(order.rbegin(), order.rend(), 0);
+  if (opIdx == 1) {
+    std::swap(order[0], order[1]);
   }
   return order;
 }
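
For reference, a minimal standalone sketch of what the new getOrderForDotOperand computes, with std::vector standing in for llvm::SmallVector (hypothetical harness, not part of the patch):

#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// Re-statement of the new helper's logic.
static std::vector<unsigned> orderForDotOperand(unsigned opIdx, unsigned rank) {
  std::vector<unsigned> order(rank);
  // Default: descending dimension order [rank - 1, ..., 1, 0], i.e. row-major.
  std::iota(order.rbegin(), order.rend(), 0);
  // Tensor B (opIdx == 1) is assumed column-major, so the two fastest-varying
  // dimensions trade places.
  if (opIdx == 1)
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  for (unsigned opIdx : {0u, 1u})
    for (unsigned rank : {2u, 3u}) {
      auto order = orderForDotOperand(opIdx, rank);
      std::printf("opIdx=%u rank=%u: [", opIdx, rank);
      for (std::size_t i = 0; i < order.size(); ++i)
        std::printf("%u%s", order[i], i + 1 < order.size() ? ", " : "]\n");
    }
}

This prints [1, 0] and [2, 1, 0] for opIdx == 0, and [0, 1] and [1, 2, 0] for opIdx == 1, matching the comment in the new helper.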
@@ -290,12 +287,13 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    auto rank = dotLayout.getWarpsPerCTA().size();
+    auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
+    SmallVector<unsigned> order(rank);
     if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-      return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
+      return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
+    } else {
+      std::iota(order.rbegin(), order.rend(), 0);
     }
-    SmallVector<unsigned> order(rank);
-    std::iota(order.rbegin(), order.rend(), 0);
     return order;
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
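
And a sketch of the control flow this hunk gives getOrder for dot operands, with a plain bool standing in for the isa<AMDMfmaEncodingAttr> parent check (names hypothetical):

#include <numeric>
#include <utility>
#include <vector>

std::vector<unsigned> dotOperandOrder(unsigned opIdx, unsigned rank,
                                      bool parentIsMfma) {
  std::vector<unsigned> order(rank);
  // Non-MFMA parents keep the descending default [rank - 1, ..., 1, 0];
  // an MFMA parent additionally swaps the two fastest dimensions for B.
  std::iota(order.rbegin(), order.rend(), 0);
  if (parentIsMfma && opIdx == 1)
    std::swap(order[0], order[1]);
  return order;
}

E.g. dotOperandOrder(1, 2, /*parentIsMfma=*/true) yields [0, 1], while dotOperandOrder(1, 2, false) keeps [1, 0]. Note that rank is now taken from the parent's warpsPerCTA rather than from the dot-operand attribute itself, which is the other half of this hunk.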
@@ -1061,8 +1059,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
-                               /*kMajor*/ true);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {
@@ -2045,7 +2042,6 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
                                                    int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
-
   SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3