 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -237,13 +238,36 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }
 
+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
+                                            bool kMajor) {
+  // kMajor: if true, the matrix is fastest-running on k,
+  //         otherwise it is on m (resp. n)
+  // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
+  // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
+  // batch (if rank == 3) is always the slowest running dimension
+  assert(rank == 2 || rank == 3);
+  assert(opIdx == 0 || opIdx == 1);
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  // If opIdx is 1 and kMajor is true, the order is [0, 1]
+  // (resp. [1, 2, 0] if rank == 3)
+  // Same if opIdx is 0 and kMajor is false
+  if (bool(opIdx) == kMajor) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
       return getWarpOrder(dotLayout.getParent());
     }
   }
   auto order = getOrder(layout);
+  // FIXME: This mmaLayout if should just return
+  // getOrderForDotOperand(0, order.size(), kMajor=false)
+  // as mma has the same order as DotOperand(opIdx=0)
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
       // Hopper MMA instructions force a warp order of [0, 1]. See docs:
@@ -253,40 +277,8 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       order.insert(order.begin(), 0);
     }
   } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    // opIdx=0: [/*dim0*/batch, /*dim1=*/m, /*dim2=*/k] -> order=[1, 2, 0]
-    // opIdx=1: [/*dim0*/batch, /*dim1=*/k, /*dim2=*/n] -> order=[2, 1, 0]
-    std::iota(order.rbegin(), order.rend(), 0);
-    if (dotOpLayout.getOpIdx() == 0) {
-      std::swap(order[0], order[1]);
-    }
-  }
-  return order;
-}
-
-SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
-  assert((rank == 2 || rank == 3) &&
-         "Invalid rank for dot operand order computation");
-  SmallVector<unsigned> order(rank);
-  // The 'order' field typically represents a descending sorted array of
-  // dimensions based on contiguity. For instance, in axisInfo utilities that
-  // retrieve tensor contiguity, it's assumed that the dimension with the
-  // highest contiguity corresponds to order[0].
-  //
-  // The relation between contiguity and order is only relevant if the layout
-  // interfaces with HBM, as is the case when we load tensor from HBM to
-  // registers in the dot layout to bypass LDS. When bypassing LDS, we make
-  // the following assumptions about tensor layouts:
-  // - Tensor A (opIdx == 0) is considered to be row-major.
-  // - Tensor B (opIdx == 1) is considered to be column-major.
-  //
-  // Based on these assumptions, we define the following orders:
-  // - For opIdx == 0, batch=dim0, m=dim1, and k=dim2, we assume an order of [2,
-  //   1, 0] for 3D tensors.
-  // - For opIdx == 1, batch=dim0, k=dim1, and n=dim2, we assume an order of [1,
-  //   2, 0] for 3D tensors.
-  std::iota(order.rbegin(), order.rend(), 0);
-  if (opIdx == 1) {
-    std::swap(order[0], order[1]);
+    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
+                                  /*kMajor*/ false);
   }
   return order;
 }
@@ -303,7 +295,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
+    auto rank = dotLayout.getWarpsPerCTA().size();
     // FIXME: delete if branch for `DpasEncodingAttr` and provide more
     // general solution to make `getOrderForDotOperand` function compatible
     // with Intel layouts.
@@ -314,7 +306,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
       std::iota(order.rbegin(), order.rend(), 0);
       return order;
     }
-    return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
+    return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
     SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
@@ -1069,7 +1061,17 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  // FIXME: delete if branch for `DpasEncodingAttr` and provide more
+  // general solution to make `getOrderForDotOperand` function compatible
+  // with Intel layouts.
+  // More details:
+  // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
+  if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
+    return ::getOrder(*this);
+  } else {
+    return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                                 /*kMajor*/ true);
+  }
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {
@@ -2055,6 +2057,7 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
     ArrayRef<int64_t> shape, int bitwidth, int kWidth, int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
+
   SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3
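For reference, below is a minimal standalone sketch of what the new getOrderForDotOperand computes. It copies the logic from the diff above but substitutes std::vector for llvm::SmallVector (an assumption made only so it compiles without LLVM headers); the main() calls and the expected outputs in the comments are illustrative, not part of the patch.

// Hypothetical standalone sketch; mirrors the patched getOrderForDotOperand.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

std::vector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
                                            bool kMajor) {
  assert(rank == 2 || rank == 3);
  assert(opIdx == 0 || opIdx == 1);
  std::vector<unsigned> order(rank);
  // Descending default: [rank-1, ..., 1, 0], i.e. the last dim is fastest.
  std::iota(order.rbegin(), order.rend(), 0);
  // With that default, opIdx=0 ([m, k]) is already k-major and opIdx=1
  // ([k, n]) is n-major; swap the two fastest dims when the requested
  // majority differs from the default.
  if (bool(opIdx) == kMajor)
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  auto print = [](const std::vector<unsigned> &o) {
    for (unsigned d : o)
      std::cout << d << ' ';
    std::cout << '\n';
  };
  print(getOrderForDotOperand(0, 2, /*kMajor=*/true));  // 1 0    (k fastest)
  print(getOrderForDotOperand(1, 2, /*kMajor=*/true));  // 0 1    (k fastest)
  print(getOrderForDotOperand(0, 3, /*kMajor=*/false)); // 1 2 0  (m fastest)
  print(getOrderForDotOperand(1, 3, /*kMajor=*/true));  // 1 2 0  (k fastest)
}

This matches how the patch wires the helper up: getOrder and DotOperandEncodingAttr::getThreadOrder request kMajor=true for dot operands, while getWarpOrder requests kMajor=false.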