 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -234,8 +235,31 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }
 
+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
+                                            bool kMajor) {
+  // kMajor: if true, the matrix is fastest-running on k;
+  //         otherwise it is fastest-running on m (resp. n).
+  // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
+  // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
+  // batch (if rank == 3) is always the slowest-running dimension.
+  assert(rank == 2 || rank == 3);
+  assert(opIdx == 0 || opIdx == 1);
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  // If opIdx is 1 and kMajor is true, the order is [0, 1]
+  // (resp. [1, 2, 0] if rank == 3).
+  // The same holds if opIdx is 0 and kMajor is false.
+  if (bool(opIdx) == kMajor) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
   auto order = getOrder(layout);
+  // FIXME: This mmaLayout branch should just return
+  // getOrderForDotOperand(0, order.size(), kMajor=false),
+  // as mma has the same order as DotOperand(opIdx=0).
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
       // Hopper MMA instructions force a warp order of [0, 1]. See docs:
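To make the new helper's behavior concrete, here is a minimal standalone sketch (not part of the patch): it re-implements the same iota-then-swap logic with std::vector so the expected orders can be checked in isolation. The name orderForDotOperand is made up for illustration.

#include <cassert>
#include <numeric>
#include <vector>

// Mirrors getOrderForDotOperand above: start from [rank-1, ..., 1, 0]
// (batch slowest, last dim fastest) and swap the two fastest-running
// dims whenever bool(opIdx) == kMajor.
static std::vector<unsigned> orderForDotOperand(unsigned opIdx, unsigned rank,
                                                bool kMajor) {
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0);
  if (bool(opIdx) == kMajor)
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  // opIdx=0, dims [m, k], kMajor=true: k (dim 1) is fastest-running.
  assert((orderForDotOperand(0, 2, true) == std::vector<unsigned>{1, 0}));
  // opIdx=1, dims [batch, k, n], kMajor=true: k (dim 1) fastest, batch slowest.
  assert((orderForDotOperand(1, 3, true) == std::vector<unsigned>{1, 2, 0}));
  // opIdx=0, dims [batch, m, k], kMajor=false: m (dim 1) fastest.
  assert((orderForDotOperand(0, 3, false) == std::vector<unsigned>{1, 2, 0}));
  return 0;
}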
@@ -245,40 +269,8 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       order.insert(order.begin(), 0);
     }
   } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    // opIdx=0: [/*dim0*/batch, /*dim1=*/m, /*dim2=*/k] -> order=[1, 2, 0]
-    // opIdx=1: [/*dim0*/batch, /*dim1=*/k, /*dim2=*/n] -> order=[2, 1, 0]
-    std::iota(order.rbegin(), order.rend(), 0);
-    if (dotOpLayout.getOpIdx() == 0) {
-      std::swap(order[0], order[1]);
-    }
-  }
-  return order;
-}
-
-SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
-  assert((rank == 2 || rank == 3) &&
-         "Invalid rank for dot operand order computation");
-  SmallVector<unsigned> order(rank);
-  // The 'order' field typically represents a descending sorted array of
-  // dimensions based on contiguity. For instance, in axisInfo utilities that
-  // retrieve tensor contiguity, it's assumed that the dimension with the
-  // highest contiguity corresponds to order[0].
-  //
-  // The relation between contiguity and order is only relevant if the layout
-  // interfaces with HBM, as is the case when we load a tensor from HBM to
-  // registers in the dot layout to bypass LDS. When bypassing LDS, we make
-  // the following assumptions about tensor layouts:
-  // - Tensor A (opIdx == 0) is considered to be row-major.
-  // - Tensor B (opIdx == 1) is considered to be column-major.
-  //
-  // Based on these assumptions, we define the following orders:
-  // - For opIdx == 0, batch=dim0, m=dim1, and k=dim2, we assume an order of
-  //   [2, 1, 0] for 3D tensors.
-  // - For opIdx == 1, batch=dim0, k=dim1, and n=dim2, we assume an order of
-  //   [1, 2, 0] for 3D tensors.
-  std::iota(order.rbegin(), order.rend(), 0);
-  if (opIdx == 1) {
-    std::swap(order[0], order[1]);
+    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
+                                  /*kMajor*/ false);
   }
   return order;
 }
@@ -295,8 +287,8 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
-    return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
+    auto rank = dotLayout.getWarpsPerCTA().size();
+    return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
     SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
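Reading this hunk together with the getWarpOrder hunk above (my summary of the patch, not text from it): getOrder on a DotOperandEncodingAttr now reports the k-major element order, while getWarpOrder keeps the k-minor warp order. For a hypothetical rank-2 operand `dot`:

// opIdx = 0, dims [m, k]:
//   getOrder(dot)     -> getOrderForDotOperand(0, 2, /*kMajor=*/true)  -> [1, 0]
//   getWarpOrder(dot) -> getOrderForDotOperand(0, 2, /*kMajor=*/false) -> [0, 1]
// opIdx = 1, dims [k, n] (the two results flip):
//   getOrder(dot)     -> [0, 1]   // k fastest
//   getWarpOrder(dot) -> [1, 0]   // n fastest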
@@ -1048,7 +1040,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                               /*kMajor*/ true);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {
@@ -2019,6 +2012,7 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
     ArrayRef<int64_t> shape, int bitwidth, int kWidth, int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
+
   SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3
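As a quick sanity check on the shapePerWarp line in this last hunk, assuming (as the MMAv2 context suggests) that its entries are the per-warp {batch, m, n, k} tile: the k extent 4 * 64 / bitwidth reproduces the standard mma instruction shapes.

#include <cassert>

int main() {
  // k extent of the per-warp MMAv2 tile, i.e. shapePerWarp's last entry.
  auto kTile = [](int bitwidth) { return 4 * 64 / bitwidth; };
  assert(kTile(32) == 8);  // tf32:      m16n8k8
  assert(kTile(16) == 16); // fp16/bf16: m16n8k16
  assert(kTile(8) == 32);  // int8/fp8:  m16n8k32
  return 0;
}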