Skip to content

Commit 06eef9d

Browse files
committed
Fix index out of range
1 parent 19cbbbc commit 06eef9d

File tree

3 files changed

+41
-17
lines changed

3 files changed

+41
-17
lines changed

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "triton/Dialect/Triton/IR/Dialect.h"
22

3+
#include <iostream>
34
#include <numeric>
45

56
#include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
@@ -140,6 +141,8 @@ SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
140141
unsigned elemsPerThread = elemsNum / threadsPerWarp;
141142
auto repCluster = getRepCluster();
142143
// The Value is shard to lanes to threads per DPAS instruction.
144+
if (rank == 3)
145+
res[0] = repCluster[0];
143146
res[rank - 2] = elemsPerThread * repCluster[rank - 2];
144147
res[rank - 1] = repCluster[rank - 1];
145148
return res;
@@ -164,16 +167,25 @@ DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
164167
size_t rank = shape.size();
165168
assert((rank == 2 || rank == 3) && "Unexpected rank of mma layout");
166169

167-
SmallVector<unsigned> elemsPerThread(rank);
170+
SmallVector<unsigned> elemsPerThread(rank, 1);
168171
auto shapePerCTATile = getShapePerCTATile(shape);
169172
unsigned tilesRow =
170173
ceil<unsigned>(shape[rank - 2], shapePerCTATile[rank - 2]);
171174
unsigned tilesCol =
172175
ceil<unsigned>(shape[rank - 1], shapePerCTATile[rank - 1]);
173176
auto sizePerThread = getSizePerThread();
177+
if (rank == 3)
178+
elemsPerThread[0] =
179+
sizePerThread[0] * ceil<unsigned>(shape[0], shapePerCTATile[0]);
174180
elemsPerThread[rank - 2] = sizePerThread[rank - 2] * tilesRow;
175181
elemsPerThread[rank - 1] = sizePerThread[rank - 1] * tilesCol;
176182

183+
// if (rank == 3)
184+
// std::cout << "elemsPerThread: " << elemsPerThread[0] << ", " <<
185+
// elemsPerThread[1] << ", " << elemsPerThread[2] << std::endl;
186+
// else
187+
// std::cout << "elemsPerThread: " << elemsPerThread[0] << ", " <<
188+
// elemsPerThread[1] << std::endl;
177189
return elemsPerThread;
178190
}
179191

@@ -382,14 +394,14 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
382394
SmallVector<unsigned> contigPerThread(rank, 1);
383395

384396
unsigned threadsPerWarp = getSubGroupSize();
385-
auto shapeC = getDPASInstShapeC();
397+
auto instShapeC = getDPASInstShapeC();
386398
// The software vectorization vectorized the value as C array: int a[N] -> int
387399
// a[N][threadsPerWarp]
388-
if (threadsPerWarp > shapeC[1]) {
400+
if (threadsPerWarp > instShapeC[1]) {
389401
return contigPerThread;
390-
} else if (threadsPerWarp == shapeC[1]) {
402+
} else if (threadsPerWarp == instShapeC[1]) {
391403
auto repCluster = getRepCluster();
392-
contigPerThread[rank - 2] = shapeC[0] * repCluster[rank - 2];
404+
contigPerThread[rank - 2] = instShapeC[0] * repCluster[rank - 2];
393405
return contigPerThread;
394406
} else {
395407
// threadsPerWarp < shapeC[1]

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "PatternTritonGPUOpToLLVM.h"
22
#include "TargetInfo.h"
33
#include "Utility.h"
4+
#include <iostream>
45

56
#include "intel/include/Analysis/Utility.h"
67
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
@@ -115,12 +116,14 @@ struct ConvertLayoutOpConversion
115116
loc, rewriter, targetInfo, layout, type, false);
116117
SmallVector<SmallVector<unsigned>> offsets;
117118
::emitOffsetForDpasLayoutPerCTA(
118-
dpasLayout, offsets, multiDimCTAInRepId[0] * shapePerCTATile[0],
119-
multiDimCTAInRepId[1] * shapePerCTATile[1]);
119+
dpasLayout, offsets,
120+
multiDimCTAInRepId[rank - 2] * shapePerCTATile[rank - 2],
121+
multiDimCTAInRepId[rank - 1] * shapePerCTATile[rank - 1]);
120122

121123
SmallVector<Value> multiDimOffset(rank);
122124
if (rank == 3)
123-
multiDimOffset[0] = multiDimBase[0];
125+
multiDimOffset[0] = add(multiDimBase[0], i32_val(multiDimCTAInRepId[0] *
126+
shapePerCTATile[0]));
124127
multiDimOffset[rank - 2] =
125128
add(multiDimBase[rank - 2], i32_val(offsets[elemId][rank - 2]));
126129
multiDimOffset[rank - 1] =

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
223223
unsigned executionSize = dpasLayout.getExecutionSize();
224224
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
225225

226+
unsigned rank = shape.size();
226227
unsigned numRowsPerPackedValue = 0u, numColsPerPackedValue = 0u;
227228
unsigned numColsPerLaneForPackedValue = 0u, numOpsPerPackedValue = 0u;
228229
switch (opIdx) {
@@ -232,7 +233,7 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
232233
SmallVector<unsigned> shapeA = dpasLayout.getShapeA();
233234
// Unlike the operand B, to pack the value to i16 for scalar bit width <=16.
234235
numOpsPerPackedValue = opsPerChannel == 4 ? 2 : 1;
235-
unsigned packedColNum = shapeA[1] / numOpsPerPackedValue;
236+
unsigned packedColNum = shapeA[rank - 1] / numOpsPerPackedValue;
236237
// Each value name represent multiple rows if warpSize > packedColNum
237238
numRowsPerPackedValue = mlir::ceil(warpSize, packedColNum);
238239
numColsPerPackedValue = std::min(warpSize, packedColNum);
@@ -256,9 +257,9 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
256257
int64_t numRepK = numReps[opIdx ? 1 : 2];
257258

258259
ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
259-
unsigned repClusterSize = repCluster[opIdx];
260+
unsigned repClusterSize = repCluster[opIdx ? rank - 1 : rank - 2];
260261

261-
for (unsigned dimOuter = 0; dimOuter < numRepOuter; ++dimOuter)
262+
for (unsigned repOuter = 0; repOuter < numRepOuter; ++repOuter)
262263
for (unsigned k = 0; k < numRepK; ++k)
263264
for (unsigned rep = 0; rep < repClusterSize; ++rep) {
264265
for (unsigned elemId = 0; elemId < numElemPerInstPerThread; ++elemId) {
@@ -268,9 +269,9 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
268269
(opIdx == 0) ? elemId % numOpsPerPackedValue : 0;
269270
unsigned packedElemId = elemId / numOpsPerPackedValue;
270271
unsigned repRowIndex =
271-
shapePerCTATile[0] * (opIdx == 0 ? dimOuter : k);
272+
shapePerCTATile[rank - 2] * (opIdx == 0 ? repOuter : k);
272273
unsigned repColIndex =
273-
shapePerCTATile[1] * (opIdx == 0 ? k : dimOuter);
274+
shapePerCTATile[rank - 1] * (opIdx == 0 ? k : repOuter);
274275
unsigned repClusterRowIndex = opIdx == 0 ? rep * instShape[0] : 0;
275276
unsigned repClusterColIndex = opIdx == 0 ? 0 : rep * instShape[1];
276277
unsigned packedElemRowIndex =
@@ -279,10 +280,17 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
279280
unsigned packedElemColIndex =
280281
(packedElemId % numColsPerLaneForPackedValue) *
281282
numColsPerPackedValue;
282-
offsets.push_back({repRowIndex + repClusterRowIndex +
283-
packedElemRowIndex + opsRowIndex,
284-
repColIndex + repClusterColIndex +
285-
packedElemColIndex + opsColIndex});
283+
if (rank == 3)
284+
offsets.push_back({0,
285+
repRowIndex + repClusterRowIndex +
286+
packedElemRowIndex + opsRowIndex,
287+
repColIndex + repClusterColIndex +
288+
packedElemColIndex + opsColIndex});
289+
else
290+
offsets.push_back({repRowIndex + repClusterRowIndex +
291+
packedElemRowIndex + opsRowIndex,
292+
repColIndex + repClusterColIndex +
293+
packedElemColIndex + opsColIndex});
286294
}
287295
}
288296

@@ -560,6 +568,7 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
560568

561569
inline SmallVector<SmallVector<unsigned>>
562570
emitOffsetForLayout(Attribute layout, RankedTensorType type) {
571+
std::cout << "~! emitOffsetForLayout\n";
563572
if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(layout))
564573
return emitOffsetForDpasLayout(dpasLayout, type);
565574
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout))

0 commit comments

Comments
 (0)