Skip to content

Commit 0b32613

Browse files
whitneywhtsang authored and anmyachev committed
[intel] Small fixes for dot operand properties
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 1f219cd commit 0b32613

File tree

2 files changed

+8 -8 lines changed (8 additions, 8 deletions)

2 files changed

+8 -8 lines changed (8 additions, 8 deletions)

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ along the row (resp. col) dimension.
8282
SmallVector<unsigned> getShapeB() const;
8383
SmallVector<unsigned> getShapeC() const;
8484
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const;
85-
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
85+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth,unsigned opIdx) const;
8686
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
87-
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
88-
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
87+
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
88+
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
8989

9090
bool supportReduction() const {
9191
return true;

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
194194
}
195195
}
196196

197-
unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
197+
unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
198198
ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
199199
auto shapePerCTA = getShapePerCTA(*this, shape);
200200
auto rep = getDPASRepetitions(shapePerCTA, opIdx);
@@ -234,8 +234,8 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
234234
}
235235

236236
SmallVector<unsigned>
237-
DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
238-
int opIdx) const {
237+
DpasEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
238+
int kWidth, int opIdx) const {
239239
auto parentShapePerCTATile = getShapePerCTATile(shape);
240240
auto threadsPerWarp = getThreadsPerWarp();
241241
if (opIdx == 0) {
@@ -250,7 +250,7 @@ DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
250250
}
251251

252252
SmallVector<unsigned>
253-
DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
253+
DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
254254
if (opIdx == 0) {
255255
SmallVector<unsigned> shapeA = getDPASInstShapeA();
256256
unsigned subGroupSize = getSubGroupSize();
@@ -290,7 +290,7 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
290290

291291
SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
292292
ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
293-
SmallVector<unsigned> sizePerThread = getSizePerThreadForOperands(opIdx);
293+
SmallVector<unsigned> sizePerThread = getSizePerThreadForOperand(0, opIdx);
294294
SmallVector<int64_t> repetitions = getDPASRepetitions(shape, opIdx);
295295

296296
return {static_cast<unsigned>(sizePerThread[0] * repetitions[0]),

0 commit comments

Comments
 (0)