Skip to content

Commit fbf9fc1

Browse files
committed
Fix issue in DotOp A layout with DPAS to LL layout conversion.
1 parent 51a0ade commit fbf9fc1

File tree

2 files changed

+39
-28
lines changed

2 files changed

+39
-28
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3201,10 +3201,6 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
32013201
if (!layout)
32023202
return "";
32033203

3204-
unsigned threadsPerWarp = getWarpSize(layout);
3205-
unsigned numWarpsPerCTA = getNumWarpsPerCTA(layout);
3206-
unsigned numBlocks = getNumCTAs(layout);
3207-
int numElementsPerThreads = getTotalElemsPerThread(tensorType);
32083204
StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register");
32093205
StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane");
32103206
StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp");
@@ -3217,6 +3213,10 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
32173213
int64_t tensorSize = product(tensorType.getShape());
32183214
std::vector<std::string> elementMapping(tensorSize);
32193215
std::vector<std::string> threadMapping;
3216+
unsigned threadsPerWarp = ll->getInDimSize(kLane);
3217+
unsigned numWarpsPerCTA = ll->getInDimSize(kWarp);
3218+
unsigned numBlocks = ll->getInDimSize(kBlock);
3219+
int numElementsPerThreads = ll->getInDimSize(kRegister);
32203220
for (int blockId = 0; blockId < numBlocks; ++blockId) {
32213221
for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) {
32223222
for (int tid = 0; tid < threadsPerWarp; ++tid) {

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -341,38 +341,44 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
341341

342342
} // anonymous namespace
343343

344+
// clang-format off
344345
// The layout example repeat_count=8, systolic_depth=8,
345346
// execution_size=16 and operands_per_chan=2 for warp size 32.
346347
// For A operand:
347-
// systolic depth = 8
348-
//<----------------------------------------------------->
349-
// opsPerChan=2
350-
//<--------->
351-
// t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 ^
352-
// t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 |
353-
// t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 |
354-
// t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 | repeat count <= 8
355-
// t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 |
356-
// t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 |
357-
// t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 |
358-
// t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 v
348+
// K = 16 (K = systolic depth * opsPerChan)
349+
// <---------------------------------------------------------------------------->
350+
// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
351+
// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
352+
// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
353+
// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
354+
// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (repeat count)
355+
// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
356+
// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
357+
// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 v
359358
// In this case, the LinearLayout bases are:
360-
// Register: {{0,1}, {4,0}}
361-
// Lane: {{0,2}, {0,4}, {0,8}, {1,0}, {2,0}}
359+
// Register: {{2,0}, {4,0}}
360+
// Lane: {{0,1}, {0,2}, {0,4}, {0,8}, {1,0}}
361+
// clang-format on
362362
std::vector<std::vector<int32_t>> DPASRegBasesA(int opsPerChannel,
363363
int repeatCount,
364364
int threadsPerWarp,
365365
int systolicDepth) {
366-
int rowPerWarp = threadsPerWarp / systolicDepth;
367-
int warpRepeats = repeatCount / rowPerWarp;
368366
std::vector<std::vector<int32_t>> regBases;
369367

370-
for (int opc = 1; opc < opsPerChannel; opc *= 2) {
368+
// pack the value to i16 for scalar bit width <=16.
369+
assert((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1) &&
370+
"invalid opsPerChannel number.");
371+
int packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
372+
int packedColNum = (systolicDepth * opsPerChannel) / packedOpsPerLane;
373+
int rowsPerWarp = mlir::ceil<int>(threadsPerWarp, packedColNum);
374+
int warpRepeats = repeatCount / rowsPerWarp;
375+
376+
for (int opc = 1; opc < packedOpsPerLane; opc *= 2) {
371377
regBases.push_back({0, opc});
372378
}
373379

374380
for (int warp = 1; warp < warpRepeats; warp *= 2) {
375-
regBases.push_back({warp * rowPerWarp, 0});
381+
regBases.push_back({warp * rowsPerWarp, 0});
376382
}
377383

378384
return regBases;
@@ -382,11 +388,17 @@ std::vector<std::vector<int32_t>>
382388
DPASLaneBasesA(int opsPerChannel, int threadsPerWarp, int systolicDepth) {
383389
std::vector<std::vector<int32_t>> laneBases;
384390

385-
for (int tid = 1; tid < systolicDepth; tid *= 2) {
386-
laneBases.push_back({0, opsPerChannel * tid});
391+
// pack the value to i16 for scalar bit width <=16.
392+
assert((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1) &&
393+
"invalid opsPerChannel number.");
394+
int packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
395+
int packedColNum = (systolicDepth * opsPerChannel) / packedOpsPerLane;
396+
397+
for (int tid = 1; tid < packedColNum; tid *= 2) {
398+
laneBases.push_back({0, packedOpsPerLane * tid});
387399
}
388-
for (int tid = systolicDepth; tid < threadsPerWarp; tid *= 2) {
389-
laneBases.push_back({tid / systolicDepth, 0});
400+
for (int tid = packedColNum; tid < threadsPerWarp; tid *= 2) {
401+
laneBases.push_back({tid / packedColNum, 0});
390402
}
391403

392404
return laneBases;
@@ -602,8 +614,7 @@ std::optional<LinearLayout>
602614
dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,
603615
ArrayRef<int64_t> shape) {
604616
auto dpasLayout = cast<intel::DpasEncodingAttr>(dotDpasLayout.getParent());
605-
if (dotDpasLayout.getOpIdx() == 0)
606-
return std::nullopt;
617+
607618
return DPAStoLinearLayout(shape, dpasLayout, dotDpasLayout.getOpIdx());
608619
}
609620

0 commit comments

Comments
 (0)