Skip to content

Commit 4b7942a

Browse files
whitneywhtsang authored and anmyachev committed
[intel] Small fixes for dot operand properties
Signed-off-by: Whitney Tsang <[email protected]>
1 parent e1460b1 commit 4b7942a

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -82,10 +82,10 @@ along the row (resp. col) dimension.
8282
SmallVector<unsigned> getShapeB() const;
8383
SmallVector<unsigned> getShapeC() const;
8484
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const;
85-
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
85+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth,unsigned opIdx) const;
8686
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
87-
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
88-
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
87+
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
88+
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
8989

9090
bool supportReduction() const {
9191
return true;

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -247,7 +247,7 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
247247
warpsPerCTA[rank - 1]))};
248248
}
249249

250-
unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
250+
unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
251251
ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
252252
auto shapePerCTA = getShapePerCTA(*this, shape);
253253
auto rep = getDPASRepetitions(shapePerCTA, opIdx);
@@ -298,8 +298,8 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
298298
}
299299

300300
SmallVector<unsigned>
301-
DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
302-
int opIdx) const {
301+
DpasEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
302+
int kWidth, int opIdx) const {
303303
auto parentShapePerCTATile = getShapePerCTATile(shape);
304304
size_t rank = parentShapePerCTATile.size();
305305
assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
@@ -325,7 +325,7 @@ DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
325325
}
326326

327327
SmallVector<unsigned>
328-
DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
328+
DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
329329
ArrayRef<unsigned> repCluster = getRepCluster();
330330
size_t rank = repCluster.size();
331331
assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
@@ -367,7 +367,7 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
367367

368368
SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
369369
ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
370-
SmallVector<unsigned> sizePerThread = getSizePerThreadForOperands(opIdx);
370+
SmallVector<unsigned> sizePerThread = getSizePerThreadForOperand(0, opIdx);
371371
SmallVector<int64_t> repetitions = getDPASRepetitions(shape, opIdx);
372372

373373
size_t rank = shape.size();

0 commit comments

Comments
 (0)