Skip to content

Commit f1a893a

Browse files
authored
[CommonCodeClean] Remove getElemsPerThreadForOperands (#2916)
Remove `getElemsPerThreadForOperands`
1 parent c5f5ac1 commit f1a893a

File tree

5 files changed

+18
-48
lines changed

5 files changed

+18
-48
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -793,11 +793,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
793793
"SmallVector<unsigned>",
794794
"getRepOrderForOperand",
795795
(ins "int":$opIdx)>,
796-
797-
InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
798-
"getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
799-
"Type":$eltTy,
800-
"unsigned":$opIdx)>,
801796
];
802797
}
803798

@@ -931,10 +926,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
931926
return contigPerThread;
932927
};
933928

934-
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
935-
llvm_unreachable("getElemsPerThreadForOperands is not supported.");
936-
};
937-
938929
}];
939930

940931
let genVerifyDecl = 1;
@@ -1043,10 +1034,6 @@ Row | warp 0 warp 2
10431034
}
10441035
return contigPerThread;
10451036
};
1046-
1047-
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
1048-
llvm_unreachable("getElemsPerThreadForOperands is not supported.");
1049-
};
10501037
}];
10511038
}
10521039

@@ -1171,10 +1158,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
11711158
return contigPerThread;
11721159
};
11731160

1174-
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
1175-
llvm_unreachable("getElemsPerThreadForOperands is not supported.");
1176-
};
1177-
11781161
}];
11791162

11801163
let hasCustomAssemblyFormat = 1;

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ unsigned getTotalElemsPerThread(Type type) {
7070
if (type.isIntOrIndexOrFloat() || isa<triton::PointerType>(type))
7171
return 1;
7272
auto tensorType = cast<RankedTensorType>(type);
73+
74+
std::optional<LinearLayout> ll = triton::gpu::toLinearLayout(
75+
tensorType.getShape(), tensorType.getEncoding());
76+
if (ll.has_value()) {
77+
MLIRContext *ctx = tensorType.getContext();
78+
auto kRegister = StringAttr::get(ctx, "register");
79+
return ll->getInDimSize(kRegister);
80+
}
81+
// Fall back to the legacy layout interface.
7382
return getTotalElemsPerThread(tensorType.getEncoding(), tensorType.getShape(),
7483
tensorType.getElementType());
7584
}
@@ -1065,9 +1074,6 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
10651074
return regs;
10661075
}
10671076

1068-
if (auto mmaParent = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
1069-
return mmaParent.getElemsPerThreadForOperands(shape, eltTy, getOpIdx());
1070-
}
10711077
llvm_unreachable("getElemsPerThread is not supported for dot operand");
10721078
return SmallVector<unsigned>();
10731079
}

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -736,15 +736,15 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
736736
LinearLayout ret =
737737
LinearLayout(std::move(bases), llvm::to_vector(sliceLL.getOutDimNames()));
738738

739-
// Match a hack in the legacy code that ensures that the number of registers
740-
// matches getTotalElemsPerThread. Yup: We just removed all the zeros, now
741-
// we're (maybe) adding some back. :)
742-
//
743-
// TODO(jlebar): Once getTotalElemsPerThread uses LLs instead of the existing
744-
// legacy code, I think we can remove this.
745-
int expectedNumRegisters =
746-
triton::gpu::getTotalElemsPerThread(RankedTensorType::get(
747-
shape, IntegerType::get(ctx, 32) /*dummy type*/, *this));
739+
// Triton generates a homogeneous kernel that runs on every thread.
740+
// The threads of the parent layout that are distributed along the sliced
741+
// dim are squeezed and redundantly hold the same tensor value, and the
742+
// multiple values of sizePerThreads[dim] of the parent are reduced to a
743+
// single one. We need to fix up the number of registers in case removing
744+
// all the zeros above was too aggressive.
745+
auto sizePerThreads = triton::gpu::getSizePerThread(getParent());
746+
unsigned expectedNumRegisters =
747+
parentLL->getInDimSize(S("register")) / sizePerThreads[getDim()];
748748
if (ret.getInDimSize(S("register")) != expectedNumRegisters) {
749749
int extraZeros = expectedNumRegisters / ret.getInDimSize(S("register"));
750750
// Our use of "dim0" here is arbitrary; because we're adding zeros, any

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,6 @@ along the row (resp. col) dimension.
100100
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
101101
return getSizePerThreadForOperand(kWidth, static_cast<OpIdx>(opIdx));
102102
}
103-
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
104-
return getElemsPerThreadForOperands(shape, eltTy, static_cast<OpIdx>(opIdx));
105-
}
106103
SmallVector<unsigned> getRepOrderForOperand(unsigned opIdx) const {
107104
return getRepOrderForOperand(static_cast<OpIdx>(opIdx));
108105
}

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -347,22 +347,6 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const {
347347
llvm_unreachable("unexpected opIdx");
348348
}
349349

350-
SmallVector<unsigned>
351-
DpasEncodingAttr::getElemsPerThreadForOperands(ArrayRef<int64_t> shape,
352-
Type eltTy, OpIdx opIdx) const {
353-
SmallVector<unsigned> sizePerThread = getSizePerThreadForOperand(0, opIdx);
354-
SmallVector<int64_t> repetitions = getDPASRepetitions(shape, opIdx);
355-
356-
size_t rank = shape.size();
357-
SmallVector<unsigned> elemsPerThread(rank);
358-
if (rank == 3)
359-
elemsPerThread[0] = repetitions[0];
360-
elemsPerThread[rank - 2] = sizePerThread[0] * repetitions[1];
361-
elemsPerThread[rank - 1] = sizePerThread[1] * repetitions[2];
362-
363-
return elemsPerThread;
364-
};
365-
366350
SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
367351
size_t rank = getWarpsPerCTA().size();
368352
assert(rank == 2 || rank == 3);

0 commit comments

Comments
 (0)