Skip to content

Commit a9e6384

Browse files
committed
introduce 'getOrderForDotOperand' interface function
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent dfd526b commit a9e6384

File tree

4 files changed

+35
-6
lines changed

4 files changed

+35
-6
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,10 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
804804
"getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
805805
"Type":$eltTy,
806806
"unsigned":$opIdx)>,
807+
808+
InterfaceMethod<"Return order of dimensions for dot operands.", "SmallVector<unsigned>",
809+
"getOrderForDotOperand", (ins "unsigned":$opIdx,
810+
"unsigned": $rank)>,
807811
];
808812
}
809813

@@ -926,6 +930,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
926930
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
927931
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
928932
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
933+
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
929934

930935
SmallVector<unsigned> getContigPerThread() {
931936
auto rank = getWarpsPerCTA().size();
@@ -1038,6 +1043,7 @@ Row | warp 0 warp 2
10381043
SmallVector<int64_t> getElemsPerInstrForOperands() const;
10391044
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
10401045
Type elemType, int kWidth, int opIdx) const;
1046+
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
10411047
static SmallVector<unsigned> getMNKDimPerInstr();
10421048

10431049
SmallVector<unsigned> getContigPerThread() {
@@ -1249,6 +1255,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12491255
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
12501256
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
12511257
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1258+
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
12521259

12531260
SmallVector<unsigned> getContigPerThread() {
12541261
assert(isVolta() || isAmpere() || isHopper());

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,8 @@ SmallVector<unsigned> getOrder(Attribute layout) {
304304
}
305305
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
306306
auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
307-
if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
308-
SmallVector<unsigned> order(rank);
309-
std::iota(order.rbegin(), order.rend(), 0);
310-
return order;
311-
}
312-
return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
307+
auto mmaParent = dyn_cast<MmaEncodingTrait>(dotLayout.getParent());
308+
return mmaParent.getOrderForDotOperand(dotLayout.getOpIdx(), rank);
313309
}
314310
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
315311
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
@@ -842,6 +838,12 @@ unsigned AMDMfmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
842838
return product<unsigned>(getElemsPerThread(shape, eltTy));
843839
}
844840

841+
// Interface-method implementation for the AMD MFMA encoding: forwards `opIdx`
// and `rank` unchanged to the global-scope `::getOrderForDotOperand` helper and
// returns its result. No MFMA-specific adjustment is applied here.
SmallVector<unsigned>
842+
AMDMfmaEncodingAttr::getOrderForDotOperand(unsigned opIdx,
843+
unsigned rank) const {
844+
return ::getOrderForDotOperand(opIdx, rank);
845+
}
846+
845847
// Wmma encoding
846848

847849
SmallVector<unsigned>
@@ -871,6 +873,12 @@ unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
871873
return product<unsigned>(getElemsPerThread(shape, eltTy));
872874
}
873875

876+
// Interface-method implementation for the AMD WMMA encoding: forwards `opIdx`
// and `rank` unchanged to the global-scope `::getOrderForDotOperand` helper and
// returns its result. No WMMA-specific adjustment is applied here.
SmallVector<unsigned>
877+
AMDWmmaEncodingAttr::getOrderForDotOperand(unsigned opIdx,
878+
unsigned rank) const {
879+
return ::getOrderForDotOperand(opIdx, rank);
880+
}
881+
874882
SmallVector<unsigned>
875883
NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
876884
Type eltTy) const {
@@ -958,6 +966,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
958966
return product<unsigned>(getElemsPerThread(shape, eltTy));
959967
}
960968

969+
// Interface-method implementation for the NVIDIA MMA encoding: forwards `opIdx`
// and `rank` unchanged to the global-scope `::getOrderForDotOperand` helper and
// returns its result. No NVIDIA-specific adjustment is applied here.
SmallVector<unsigned>
970+
NvidiaMmaEncodingAttr::getOrderForDotOperand(unsigned opIdx,
971+
unsigned rank) const {
972+
return ::getOrderForDotOperand(opIdx, rank);
973+
}
974+
961975
//
962976

963977
SmallVector<unsigned>

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ along the row (resp. col) dimension.
8484
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const;
8585
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth,unsigned opIdx) const;
8686
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
87+
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
8788
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
8889
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
8990

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,13 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
297297
static_cast<unsigned>(sizePerThread[1] * repetitions[1])};
298298
};
299299

300+
// Interface-method implementation for the Intel DPAS encoding. Returns a
// vector of `rank` entries holding [rank-1, ..., 1, 0]. `opIdx` is accepted to
// match the interface signature but is not read, so both dot operands get the
// same order.
SmallVector<unsigned>
301+
DpasEncodingAttr::getOrderForDotOperand(unsigned opIdx, unsigned rank) const {
302+
SmallVector<unsigned> order(rank);
303+
// Filling through reverse iterators assigns 0 to the last slot and rank-1 to
// the first, yielding a descending sequence.
std::iota(order.rbegin(), order.rend(), 0);
304+
return order;
305+
}
306+
300307
SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
301308
unsigned threadsPerWarp = getSubGroupSize();
302309
auto shapeC = getDPASInstShapeC();

0 commit comments

Comments
 (0)