Skip to content

Commit a7ba67a

Browse files
committed
Revert "introduce 'getOrderForDotOperand' interface function"
This reverts commit be7965c.
1 parent be7965c commit a7ba67a

File tree

4 files changed, with 4 additions and 37 deletions.

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -804,10 +804,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
804804
"getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
805805
"Type":$eltTy,
806806
"unsigned":$opIdx)>,
807-
808-
InterfaceMethod<"Return order of dimensions for dot operands.", "SmallVector<unsigned>",
809-
"getOrderForDotOperand", (ins "unsigned":$opIdx,
810-
"unsigned": $rank)>,
811807
];
812808
}
813809

@@ -930,7 +926,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
930926
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
931927
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
932928
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
933-
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
934929

935930
SmallVector<unsigned> getContigPerThread() {
936931
auto rank = getWarpsPerCTA().size();
@@ -1043,7 +1038,6 @@ Row | warp 0 warp 2
10431038
SmallVector<int64_t> getElemsPerInstrForOperands() const;
10441039
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
10451040
Type elemType, int kWidth, int opIdx) const;
1046-
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
10471041
static SmallVector<unsigned> getMNKDimPerInstr();
10481042

10491043
SmallVector<unsigned> getContigPerThread() {
@@ -1255,7 +1249,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12551249
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
12561250
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
12571251
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1258-
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
12591252

12601253
SmallVector<unsigned> getContigPerThread() {
12611254
assert(isVolta() || isAmpere() || isHopper());

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -304,11 +304,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
304304
}
305305
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
306306
auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
307-
if (auto mmaParent = dyn_cast<MmaEncodingTrait>(dotLayout.getParent())) {
308-
return mmaParent.getOrderForDotOperand(dotLayout.getOpIdx(), rank);
307+
if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
308+
SmallVector<unsigned> order(rank);
309+
std::iota(order.rbegin(), order.rend(), 0);
310+
return order;
309311
}
310-
// This branch had to be left because not all types
311-
// inherit `MmaEncodingTrait` interface, for example `BlockedEncodingAttr`.
312312
return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
313313
}
314314
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -842,12 +842,6 @@ unsigned AMDMfmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
842842
return product<unsigned>(getElemsPerThread(shape, eltTy));
843843
}
844844

845-
SmallVector<unsigned>
846-
AMDMfmaEncodingAttr::getOrderForDotOperand(unsigned opIdx,
847-
unsigned rank) const {
848-
return ::getOrderForDotOperand(opIdx, rank);
849-
}
850-
851845
// Wmma encoding
852846

853847
SmallVector<unsigned>
@@ -877,12 +871,6 @@ unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
877871
return product<unsigned>(getElemsPerThread(shape, eltTy));
878872
}
879873

880-
SmallVector<unsigned>
881-
AMDWmmaEncodingAttr::getOrderForDotOperand(unsigned opIdx,
882-
unsigned rank) const {
883-
return ::getOrderForDotOperand(opIdx, rank);
884-
}
885-
886874
SmallVector<unsigned>
887875
NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
888876
Type eltTy) const {
@@ -970,12 +958,6 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
970958
return product<unsigned>(getElemsPerThread(shape, eltTy));
971959
}
972960

973-
SmallVector<unsigned>
974-
NvidiaMmaEncodingAttr::getOrderForDotOperand(unsigned opIdx,
975-
unsigned rank) const {
976-
return ::getOrderForDotOperand(opIdx, rank);
977-
}
978-
979961
//
980962

981963
SmallVector<unsigned>

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ along the row (resp. col) dimension.
8484
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const;
8585
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth,unsigned opIdx) const;
8686
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
87-
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) const;
8887
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
8988
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
9089

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,13 +380,6 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
380380
return elemsPerThread;
381381
};
382382

383-
SmallVector<unsigned>
384-
DpasEncodingAttr::getOrderForDotOperand(unsigned opIdx, unsigned rank) const {
385-
SmallVector<unsigned> order(rank);
386-
std::iota(order.rbegin(), order.rend(), 0);
387-
return order;
388-
}
389-
390383
SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
391384
size_t rank = getWarpsPerCTA().size();
392385
assert(rank == 2 || rank == 3);

0 commit comments

Comments
 (0)