Skip to content

Commit fb0efdf

Browse files
committed
Fix issue in DotOp A layout with DPAS to LL layout conversion.
1 parent 51a0ade commit fb0efdf

File tree

3 files changed

+85
-34
lines changed

3 files changed

+85
-34
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3201,10 +3201,6 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
32013201
if (!layout)
32023202
return "";
32033203

3204-
unsigned threadsPerWarp = getWarpSize(layout);
3205-
unsigned numWarpsPerCTA = getNumWarpsPerCTA(layout);
3206-
unsigned numBlocks = getNumCTAs(layout);
3207-
int numElementsPerThreads = getTotalElemsPerThread(tensorType);
32083204
StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register");
32093205
StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane");
32103206
StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp");
@@ -3217,6 +3213,10 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
32173213
int64_t tensorSize = product(tensorType.getShape());
32183214
std::vector<std::string> elementMapping(tensorSize);
32193215
std::vector<std::string> threadMapping;
3216+
unsigned threadsPerWarp = ll->getInDimSize(kLane);
3217+
unsigned numWarpsPerCTA = ll->getInDimSize(kWarp);
3218+
unsigned numBlocks = ll->getInDimSize(kBlock);
3219+
int numElementsPerThreads = ll->getInDimSize(kRegister);
32203220
for (int blockId = 0; blockId < numBlocks; ++blockId) {
32213221
for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) {
32223222
for (int tid = 0; tid < threadsPerWarp; ++tid) {

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -341,38 +341,44 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
341341

342342
} // anonymous namespace
343343

344+
// clang-format off
344345
// The layout example repeat_count=8, systolic_depth=8,
345346
// execution_size=16 and operands_per_chan=2 for warp size 32.
346347
// For A operand:
347-
// systolic depth = 8
348-
//<----------------------------------------------------->
349-
// opsPerChan=2
350-
//<--------->
351-
// t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 ^
352-
// t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 |
353-
// t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 |
354-
// t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 | repeat count <= 8
355-
// t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 |
356-
// t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 |
357-
// t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 |
358-
// t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 v
348+
// K = 16 (K = systolic depth * opsPerChan)
349+
// <---------------------------------------------------------------------------->
350+
// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
351+
// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
352+
// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
353+
// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
354+
// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (repeat count)
355+
// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
356+
// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
357+
// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 v
359358
// In this case, the LinearLayout bases are:
360-
// Register: {{0,1}, {4,0}}
361-
// Lane: {{0,2}, {0,4}, {0,8}, {1,0}, {2,0}}
359+
// Register: {{2,0}, {4,0}}
360+
// Lane: {{0,1}, {0,2}, {0,4}, {0,8}, {1,0}}
361+
// clang-format on
362362
std::vector<std::vector<int32_t>> DPASRegBasesA(int opsPerChannel,
363363
int repeatCount,
364364
int threadsPerWarp,
365365
int systolicDepth) {
366-
int rowPerWarp = threadsPerWarp / systolicDepth;
367-
int warpRepeats = repeatCount / rowPerWarp;
368366
std::vector<std::vector<int32_t>> regBases;
369367

370-
for (int opc = 1; opc < opsPerChannel; opc *= 2) {
368+
// Pack the values into i16 for scalar bit widths <= 16.
369+
assert((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1) &&
370+
"invalid opsPerChannel number.");
371+
int packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
372+
int packedColNum = (systolicDepth * opsPerChannel) / packedOpsPerLane;
373+
int rowsPerWarp = mlir::ceil<int>(threadsPerWarp, packedColNum);
374+
int warpRepeats = repeatCount / rowsPerWarp;
375+
376+
for (int opc = 1; opc < packedOpsPerLane; opc *= 2) {
371377
regBases.push_back({0, opc});
372378
}
373379

374380
for (int warp = 1; warp < warpRepeats; warp *= 2) {
375-
regBases.push_back({warp * rowPerWarp, 0});
381+
regBases.push_back({warp * rowsPerWarp, 0});
376382
}
377383

378384
return regBases;
@@ -382,11 +388,17 @@ std::vector<std::vector<int32_t>>
382388
DPASLaneBasesA(int opsPerChannel, int threadsPerWarp, int systolicDepth) {
383389
std::vector<std::vector<int32_t>> laneBases;
384390

385-
for (int tid = 1; tid < systolicDepth; tid *= 2) {
386-
laneBases.push_back({0, opsPerChannel * tid});
391+
// Pack the values into i16 for scalar bit widths <= 16.
392+
assert((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1) &&
393+
"invalid opsPerChannel number.");
394+
int packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
395+
int packedColNum = (systolicDepth * opsPerChannel) / packedOpsPerLane;
396+
397+
for (int tid = 1; tid < packedColNum; tid *= 2) {
398+
laneBases.push_back({0, packedOpsPerLane * tid});
387399
}
388-
for (int tid = systolicDepth; tid < threadsPerWarp; tid *= 2) {
389-
laneBases.push_back({tid / systolicDepth, 0});
400+
for (int tid = packedColNum; tid < threadsPerWarp; tid *= 2) {
401+
laneBases.push_back({tid / packedColNum, 0});
390402
}
391403

392404
return laneBases;
@@ -602,8 +614,7 @@ std::optional<LinearLayout>
602614
dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,
603615
ArrayRef<int64_t> shape) {
604616
auto dpasLayout = cast<intel::DpasEncodingAttr>(dotDpasLayout.getParent());
605-
if (dotDpasLayout.getOpIdx() == 0)
606-
return std::nullopt;
617+
607618
return DPAStoLinearLayout(shape, dpasLayout, dotDpasLayout.getOpIdx());
608619
}
609620

third_party/intel/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,47 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_perInst) {
5959
},
6060
{S("dim0"), S("dim1")}));
6161
// Test Operand A (opIdx=0)
62+
EXPECT_EQ(
63+
DPAStoLinearLayout({8, 32}, dpas({1, 1}, 8, 8, 16, 4, {1, 1}, 32), 0),
64+
LinearLayout(
65+
{
66+
{S("register"), {{0, 1}, {2, 0}, {4, 0}}},
67+
{S("lane"), {{0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}}},
68+
{S("warp"), {}},
69+
{S("block"), {}},
70+
},
71+
{S("dim0"), S("dim1")}));
6272
EXPECT_EQ(
6373
DPAStoLinearLayout({8, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32), 0),
6474
LinearLayout(
6575
{
66-
{S("register"), {{0, 1}, {4, 0}}},
67-
{S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
76+
{S("register"), {{2, 0}, {4, 0}}},
77+
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
78+
{S("warp"), {}},
79+
{S("block"), {}},
80+
},
81+
{S("dim0"), S("dim1")}));
82+
EXPECT_EQ(
83+
DPAStoLinearLayout({8, 8}, dpas({1, 1}, 8, 8, 16, 1, {1, 1}, 32), 0),
84+
LinearLayout(
85+
{
86+
{S("register"), {{4, 0}}},
87+
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {1, 0}, {2, 0}}},
6888
{S("warp"), {}},
6989
{S("block"), {}},
7090
},
7191
{S("dim0"), S("dim1")}));
7292
// Test Operand B (opIdx=1)
93+
EXPECT_EQ(
94+
DPAStoLinearLayout({32, 16}, dpas({1, 1}, 8, 8, 16, 4, {1, 1}, 32), 1),
95+
LinearLayout(
96+
{
97+
{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}}},
98+
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}}},
99+
{S("warp"), {}},
100+
{S("block"), {}},
101+
},
102+
{S("dim0"), S("dim1")}));
73103
EXPECT_EQ(
74104
DPAStoLinearLayout({16, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32), 1),
75105
LinearLayout(
@@ -80,6 +110,16 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_perInst) {
80110
{S("block"), {}},
81111
},
82112
{S("dim0"), S("dim1")}));
113+
EXPECT_EQ(
114+
DPAStoLinearLayout({8, 16}, dpas({1, 1}, 8, 8, 16, 1, {1, 1}, 32), 1),
115+
LinearLayout(
116+
{
117+
{S("register"), {{2, 0}, {4, 0}}},
118+
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
119+
{S("warp"), {}},
120+
{S("block"), {}},
121+
},
122+
{S("dim0"), S("dim1")}));
83123
}
84124

85125
TEST_F(DPAStoLinearLayoutTest, DPAS_withRepCluster) {
@@ -98,8 +138,8 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withRepCluster) {
98138
DPAStoLinearLayout({32, 16}, dpas({1, 1}, 8, 8, 16, 2, {4, 2}, 32), 0),
99139
LinearLayout(
100140
{
101-
{S("register"), {{0, 1}, {4, 0}, {8, 0}, {16, 0}}},
102-
{S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
141+
{S("register"), {{2, 0}, {4, 0}, {8, 0}, {16, 0}}},
142+
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
103143
{S("warp"), {}},
104144
{S("block"), {}},
105145
},
@@ -154,8 +194,8 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withWarpOperandA) {
154194
LinearLayout(
155195
{
156196
{S("register"),
157-
{{0, 1}, {4, 0}, {8, 0}, {16, 0}, {0, 16}, {0, 32}}},
158-
{S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
197+
{{2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 16}, {0, 32}}},
198+
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
159199
{S("warp"), {{0, 0}, {32, 0}}},
160200
{S("block"), {}},
161201
},

0 commit comments

Comments
 (0)