Skip to content

Commit 1f61b2f

Browse files
committed
Address review comments.
1 parent c0b96e8 commit 1f61b2f

File tree

4 files changed

+38
-38
lines changed

4 files changed

+38
-38
lines changed

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,9 @@ DpasEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
153153
SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
154154
size_t rank = shapeC.size();
155155
SmallVector<unsigned> shapePerCTATile(rank);
156-
for (size_t i = 0; i < rank; ++i) {
157-
shapePerCTATile[i] = shapeC[i] * warpsPerCTA[i];
158-
}
156+
llvm::transform(
157+
llvm::zip_equal(shapeC, warpsPerCTA), shapePerCTATile.begin(),
158+
[](auto entry) { return std::get<0>(entry) * std::get<1>(entry); });
159159
return shapePerCTATile;
160160
}
161161

@@ -220,7 +220,9 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
220220
std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
221221
warpsPerCTA[rank - 2])),
222222
std::max<int64_t>(1, shape[rank - 1] / shapePerWarp[rank - 1])};
223-
} else if (opIdx == 1) {
223+
}
224+
225+
if (opIdx == 1) {
224226
auto shapePerWarp = getShapeB();
225227
int64_t numRepBatch =
226228
rank == 3 ? std::max<int64_t>(1, shape[0] /
@@ -230,28 +232,27 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
230232
std::max<int64_t>(1, shape[rank - 2] / shapePerWarp[rank - 2]),
231233
std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
232234
warpsPerCTA[rank - 1]))};
233-
} else {
234-
assert(opIdx == 2 && "Unexpected operand id (valid ids are 0, 1 or 2)");
235-
auto shapePerWarp = getShapeC();
236-
int64_t numRepBatch =
237-
rank == 3 ? std::max<int64_t>(1, shape[0] /
238-
(shapePerWarp[0] * warpsPerCTA[0]))
239-
: 1;
240-
return {numRepBatch,
241-
std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
242-
warpsPerCTA[rank - 2])),
243-
std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
244-
warpsPerCTA[rank - 1]))};
245235
}
246-
return rep;
236+
237+
assert(opIdx == 2 && "Unexpected operand id (valid ids are 0, 1 or 2)");
238+
auto shapePerWarp = getShapeC();
239+
int64_t numRepBatch =
240+
rank == 3
241+
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
242+
: 1;
243+
return {numRepBatch,
244+
std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
245+
warpsPerCTA[rank - 2])),
246+
std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
247+
warpsPerCTA[rank - 1]))};
247248
}
248249

249250
unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
250251
ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
251252
auto shapePerCTA = getShapePerCTA(*this, shape);
252253
auto rep = getDPASRepetitions(shapePerCTA, opIdx);
253254
auto threadsPerWar = getSubGroupSize();
254-
int rank = shape.size();
255+
size_t rank = shape.size();
255256
if (opIdx == 0) {
256257
auto shapeA = getShapeA();
257258
auto totalElem = product<unsigned>(shapeA);
@@ -269,16 +270,12 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
269270

270271
SmallVector<unsigned> DpasEncodingAttr::getWarpOrder() const {
271272
size_t rank = getWarpsPerCTA().size();
272-
SmallVector<unsigned> order(rank);
273-
std::iota(order.rbegin(), order.rend(), 0);
274-
return order;
273+
return llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank)));
275274
}
276275

277276
SmallVector<unsigned> DpasEncodingAttr::getThreadOrder() const {
278277
size_t rank = getWarpsPerCTA().size();
279-
SmallVector<unsigned> order(rank);
280-
std::iota(order.rbegin(), order.rend(), 0);
281-
return order;
278+
return llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank)));
282279
}
283280

284281
SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
@@ -307,18 +304,18 @@ DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
307304
size_t rank = parentShapePerCTATile.size();
308305
if (opIdx == 0) {
309306
auto shapeA = getShapeA();
310-
if (rank == 2)
311-
return {parentShapePerCTATile[0], shapeA[1]};
312-
else
313-
return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2],
314-
shapeA[rank - 1]};
307+
return (rank == 2)
308+
? SmallVector<unsigned>{parentShapePerCTATile[0], shapeA[1]}
309+
: SmallVector<unsigned>{parentShapePerCTATile[0],
310+
parentShapePerCTATile[rank - 2],
311+
shapeA[rank - 1]};
315312
} else if (opIdx == 1) {
316313
auto shapeB = getShapeB();
317-
if (rank == 2)
318-
return {shapeB[0], parentShapePerCTATile[1]};
319-
else
320-
return {parentShapePerCTATile[0], shapeB[rank - 2],
321-
parentShapePerCTATile[rank - 1]};
314+
return (rank == 2)
315+
? SmallVector<unsigned>{shapeB[0], parentShapePerCTATile[1]}
316+
: SmallVector<unsigned>{parentShapePerCTATile[0],
317+
shapeB[rank - 2],
318+
parentShapePerCTATile[rank - 1]};
322319
} else {
323320
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
324321
}

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ struct ConvertLayoutOpConversion
112112
return multiDimOffset;
113113
}
114114
if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(layout)) {
115-
assert(rank == 2 || rank == 3);
115+
assert((rank == 2 || rank == 3) &&
116+
"unexpected rank number for Dpas layout");
116117
auto multiDimBase = ::intel::emitBaseIndexForLayout(
117118
loc, rewriter, targetInfo, layout, type, false);
118119
SmallVector<SmallVector<unsigned>> offsets;

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
283283

284284
auto sharedLayout = cast<SharedEncodingAttr>(descTy.getEncoding());
285285
ArrayRef<unsigned> order = sharedLayout.getOrder();
286-
unsigned rank = order.size();
286+
size_t rank = order.size();
287287

288288
// (a, b) is the coordinate.
289289
auto load = [=, &rewriter, &smemObj, &shapePerWarp, &multiDimWarpId,

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,11 +285,13 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
285285
packedElemRowIndex + opsRowIndex,
286286
repColIndex + repClusterColIndex +
287287
packedElemColIndex + opsColIndex});
288-
else
288+
else {
289+
assert((rank == 2) && "unexpected rank number for Dot layout");
289290
offsets.push_back({repRowIndex + repClusterRowIndex +
290291
packedElemRowIndex + opsRowIndex,
291292
repColIndex + repClusterColIndex +
292293
packedElemColIndex + opsColIndex});
294+
}
293295
}
294296
}
295297

@@ -415,7 +417,7 @@ emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
415417
Value warpId = udiv(threadId, warpSize);
416418
Value laneId = urem(threadId, warpSize);
417419

418-
unsigned rank = type.getShape().size();
420+
size_t rank = type.getShape().size();
419421
auto warpsPerCTA = dpasLayout.getWarpsPerCTA();
420422
ArrayRef<int64_t> shape = type.getShape();
421423

0 commit comments

Comments
 (0)