Skip to content

Commit 4d53f8e

Browse files
anmyachev authored and whitneywhtsang committed
Introduce 'getOrderForDotOperand' interface function
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 4033a23 commit 4d53f8e

File tree

4 files changed

+37
-4
lines changed

4 files changed

+37
-4
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,10 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
804804
"getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
805805
"Type":$eltTy,
806806
"unsigned":$opIdx)>,
807+
808+
InterfaceMethod<"Return order of dimensions for dot operands.", "SmallVector<unsigned>",
809+
"getOrderForDotOperand", (ins "unsigned":$opIdx,
810+
"unsigned": $rank)>,
807811
];
808812
}
809813

@@ -926,6 +930,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
926930
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
927931
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
928932
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
933+
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
929934

930935
SmallVector<unsigned> getContigPerThread() {
931936
auto rank = getWarpsPerCTA().size();
@@ -1038,6 +1043,7 @@ Row | warp 0 warp 2
10381043
SmallVector<int64_t> getElemsPerInstrForOperands() const;
10391044
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
10401045
Type elemType, int kWidth, int opIdx) const;
1046+
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
10411047
static SmallVector<unsigned> getMNKDimPerInstr();
10421048

10431049
SmallVector<unsigned> getContigPerThread() {
@@ -1249,6 +1255,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12491255
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
12501256
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
12511257
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1258+
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
12521259

12531260
SmallVector<unsigned> getContigPerThread() {
12541261
assert(isVolta() || isAmpere() || isHopper());

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
299299
}
300300
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
301301
auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
302-
if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
303-
SmallVector<unsigned> order(rank);
304-
std::iota(order.rbegin(), order.rend(), 0);
305-
return order;
302+
if (auto mmaParent = dyn_cast<MmaEncodingTrait>(dotLayout.getParent())) {
303+
return mmaParent.getOrderForDotOperand(dotLayout.getOpIdx(), rank);
306304
}
305+
// This fallback is kept because not every parent encoding implements
306+
// the `MmaEncodingTrait` interface (for example, `BlockedEncodingAttr`).
307307
return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
308308
}
309309
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -837,6 +837,12 @@ unsigned AMDMfmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
837837
return product<unsigned>(getElemsPerThread(shape, eltTy));
838838
}
839839

840+
// Returns the order of dimensions for dot operand `opIdx`.
// This encoding adds no special handling: both arguments are forwarded
// unchanged to the global-scope ::getOrderForDotOperand helper.
SmallVector<unsigned>
841+
AMDMfmaEncodingAttr::getOrderForDotOperand(unsigned opIdx,
842+
unsigned rank) const {
843+
return ::getOrderForDotOperand(opIdx, rank);
844+
}
845+
840846
// Wmma encoding
841847

842848
SmallVector<unsigned>
@@ -866,6 +872,12 @@ unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
866872
return product<unsigned>(getElemsPerThread(shape, eltTy));
867873
}
868874

875+
// Returns the order of dimensions for dot operand `opIdx`.
// This encoding adds no special handling: both arguments are forwarded
// unchanged to the global-scope ::getOrderForDotOperand helper.
SmallVector<unsigned>
876+
AMDWmmaEncodingAttr::getOrderForDotOperand(unsigned opIdx,
877+
unsigned rank) const {
878+
return ::getOrderForDotOperand(opIdx, rank);
879+
}
880+
869881
SmallVector<unsigned>
870882
NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
871883
Type eltTy) const {
@@ -953,6 +965,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
953965
return product<unsigned>(getElemsPerThread(shape, eltTy));
954966
}
955967

968+
// Returns the order of dimensions for dot operand `opIdx`.
// This encoding adds no special handling: both arguments are forwarded
// unchanged to the global-scope ::getOrderForDotOperand helper.
SmallVector<unsigned>
969+
NvidiaMmaEncodingAttr::getOrderForDotOperand(unsigned opIdx,
970+
unsigned rank) const {
971+
return ::getOrderForDotOperand(opIdx, rank);
972+
}
973+
956974
//
957975

958976
SmallVector<unsigned>

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ along the row (resp. col) dimension.
8484
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const;
8585
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth,unsigned opIdx) const;
8686
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
87+
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
8788
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
8889
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
8990

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,13 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
297297
static_cast<unsigned>(sizePerThread[1] * repetitions[1])};
298298
};
299299

300+
// Returns the order of dimensions for a dot operand of this encoding.
// The result has `rank` entries and is [rank-1, ..., 1, 0]. `opIdx` is not
// consulted, so the same order is returned for every operand index.
SmallVector<unsigned>
301+
DpasEncodingAttr::getOrderForDotOperand(unsigned opIdx, unsigned rank) const {
302+
SmallVector<unsigned> order(rank);
303+
// Filling through reverse iterators writes 0 into the last slot and
// rank-1 into the first, producing a descending sequence.
std::iota(order.rbegin(), order.rend(), 0);
304+
return order;
305+
}
306+
300307
SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
301308
unsigned threadsPerWarp = getSubGroupSize();
302309
auto shapeC = getDPASInstShapeC();

0 commit comments

Comments
 (0)