Commit 7f06338

[BACKEND] Fix a special case where elements along the k dimension are repeated within each thread (#5121)
This PR includes the following changes:

- Adds comprehensive tests for mixed-precision dot products, including configurations such as f8xf16, i8xf16, f8xf32, and i8xf32.
- Fixes mmav2 when the k dimension contains duplicated elements. For example, with a 16x16 fp16 triton tensor (opIdx=0, kWidth=4), a 16x32 tile is used, so the first 16 elements along the k dimension are repeated in the last 16 elements. During the mmav2 computation, only the first half is required.
1 parent bb71ced commit 7f06338

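As a rough illustration of the duplicated-k case described above (a hypothetical sketch written for this summary; the shapes are chosen for exposition, not taken from the patch): with only 16 real elements along k but a 32-wide tile, the loaded row simply repeats, and only the first copy is needed by the MMA.

# Hypothetical sketch only: a 16-wide k slice replicated to fill a 32-wide tile.
k_extent, tile_k = 16, 32
row = list(range(k_extent))            # the 16 real elements along k
tile_row = row * (tile_k // k_extent)  # [0..15, 0..15]: the second half repeats the first
assert tile_row[:k_extent] == tile_row[k_extent:]
# Only tile_row[:k_extent] needs to feed the mmav2 computation.
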
File tree

4 files changed, +132 -36 lines changed

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 9 additions & 4 deletions
@@ -138,12 +138,17 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
 
   // FIXME [Dot LL]
   // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedDotOpLayout(Attribute layout) {
+  static bool isSupportedDotOpLayout(RankedTensorType type) {
+    auto layout = type.getEncoding();
+    auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
     if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+      auto kWidth = dot.getKWidth();
       // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
       // - kWidth == 8
+      // - kWidth == 4, bitwidth = 32
       if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
-        bool legacyLoweringIsBuggy = dot.getKWidth() >= 8;
+        bool legacyLoweringIsBuggy =
+            kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
         return legacyLoweringIsBuggy && mma.isAmpere();
       }
       if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
@@ -162,7 +167,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     if (isa<SharedEncodingAttr>(srcLayout) &&
         (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
              dstLayout) ||
-         isSupportedDotOpLayout(dstLayout))) {
+         isSupportedDotOpLayout(dstTy))) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }
@@ -202,7 +207,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstShape = dstTy.getShape();
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
     auto dstLayout = dstTy.getEncoding();
-    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
+    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");
 
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(

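For readers skimming the diff, here is a small Python sketch (an illustration written for this summary, not code from the patch) of the condition the updated isSupportedDotOpLayout encodes for NvidiaMmaEncodingAttr parents: the legacy SharedToDotOperandMMAv2OrV3 lowering is treated as buggy when kWidth >= 8, or now also when kWidth == 4 with 32-bit elements, and only on Ampere, in which case the op is routed to lowerSharedToDistributed instead.

# Hypothetical mirror of the NvidiaMmaEncodingAttr branch; names are illustrative only.
def is_supported_dot_op_layout_for_mma(k_width: int, bitwidth: int, is_ampere: bool) -> bool:
    legacy_lowering_is_buggy = k_width >= 8 or (k_width == 4 and bitwidth == 32)
    return legacy_lowering_is_buggy and is_ampere

assert is_supported_dot_op_layout_for_mma(4, 32, True)        # new case: kWidth=4, 32-bit elements
assert not is_supported_dot_op_layout_for_mma(4, 16, True)    # fp16 operands keep the legacy path
assert not is_supported_dot_op_layout_for_mma(8, 16, False)   # non-Ampere parents are unaffected
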
python/test/regression/test_cast_matmul.py

Lines changed: 29 additions & 13 deletions
@@ -13,6 +13,11 @@
 import triton.language as tl
 
 input_dtypes = ["float16", "float32", "float64"]
+if triton.runtime.driver.active.get_current_target().backend == "cuda":
+    input_dtypes += ["int8", "float8_e5m2"]
+    cc = torch.cuda.get_device_capability(0)
+    if cc >= (8, 9):
+        input_dtypes += ["float8_e4m3fn"]
 out_dtypes = ["float16", "float32"]
 
 
@@ -63,37 +68,48 @@ def matmul_kernel(A, B, C, M, N, K, #
     tl.store(C, acc, mask=mask)
 
 
-@pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype",
-                         [(M, K, N, w, x, o)  #
-                          for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)]  #
+@pytest.mark.parametrize("M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype",
+                         [(M, K, N, BLOCK_K, w, x, o)  #
+                          for BLOCK_K in [16, 32]  #
+                          for (M, K, N) in [(128, 128, 128), (768, 768, 1024)]  #
                           for w in input_dtypes
                           for x in input_dtypes  #
                           for o in out_dtypes])
-def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
+def test_cast_matmul(M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype):
     if x_dtype == w_dtype:
         pytest.skip("skip the same input dtype")
     device = torch.cuda.current_device()
-    x_dtype = getattr(torch, x_dtype)
-    w_dtype = getattr(torch, w_dtype)
-    a = torch.randn((M, K), device=device, dtype=x_dtype)
-    b = torch.randn((K, N), device=device, dtype=w_dtype)
+    x_dtype: torch.dtype = getattr(torch, x_dtype)
+    w_dtype: torch.dtype = getattr(torch, w_dtype)
+
+    def init_tensor(dtype, shape):
+        if dtype == torch.int8:
+            return torch.randint(0, 3, shape, device=device, dtype=dtype)
+        elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
+            return torch.randn(shape, device=device, dtype=torch.float16).to(dtype)
+        else:
+            return torch.randn(shape, device=device, dtype=dtype)
+
+    a = init_tensor(w_dtype, (M, K))
+    b = init_tensor(x_dtype, (K, N))
+
     torch_dtype = getattr(torch, out_dtype)
     triton_dtype = getattr(tl, out_dtype)  # <- here force dot_out_dtype
     out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype))
     out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)
 
     # launch kernel
-    BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
-    grid = ((triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), 1)
+    block_m, block_n, block_k = 16, 16, BLOCK_K
+    grid = ((triton.cdiv(M, block_m) * triton.cdiv(N, block_n)), 1)
 
     matmul_kernel[grid](
         a, b, out_triton, M, N, K,  #
         a.stride(0), a.stride(1),  #
         b.stride(0), b.stride(1),  #
         out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype,  #
         GROUP_M=8,  #
-        BLOCK_M=BLOCK_M,  #
-        BLOCK_N=BLOCK_N,  #
-        BLOCK_K=BLOCK_K)
+        BLOCK_M=block_m,  #
+        BLOCK_N=block_n,  #
+        BLOCK_K=block_k)
 
     torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01)

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp

Lines changed: 6 additions & 9 deletions
@@ -226,11 +226,6 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value lane,
                                                          Value cSwizzleOffset) {
   Value warpB = multiDimWarpId[0];
   Value warpOff = kOrder == 2 ? multiDimWarpId[1] : multiDimWarpId[2];
-  int cTileShape = tileShape[order[0]];
-  int sTileShape = tileShape[order[1]];
-  if (!needTrans) {
-    std::swap(cTileShape, sTileShape);
-  }
 
   SmallVector<Value> offs(numPtrs);
 
@@ -239,7 +234,6 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value lane,
   int laneHeight = 8;
   int quadWidth = laneWidth * kWidth;
   int quadHeight = laneHeight;
-  int numQuadI = 2;
 
   // outer index base
   Value iBase = udiv(lane, i32_val(laneWidth));
@@ -544,12 +538,15 @@ Value composeValuesToDotOperandLayoutStruct(
   // unpacked into individual elements.
   // `kIters` specifies the number of contiguous int32 elements each thread
   // should load.
-  auto kIters = isHopper ? 1 : kWidth / (32 / bitwidth);
+  // `kSize` specifies the total number of int32 elements each thread should
+  // load.
+  int kIters = isHopper ? 1 : kWidth / (32 / bitwidth);
+  int kSize = repK >= kIters ? repK * 2 : kIters;
 
   std::vector<Value> elems;
   auto unpackVec = [&](int b, int m, int k) {
-    for (auto kIter = 0; kIter < kIters; ++kIter) {
-      auto val = vals.at({b, m, k + kIter});
+    for (int kIter = 0; kIter < kIters; ++kIter) {
+      auto val = vals.at({b, m, (k + kIter) % kSize});
       auto vec = bitcast(val, vecTy);
       for (auto i = 0; i < numElemsPerVec; ++i) {
         elems.push_back(extract_element(eltTy, vec, i32_val(i)));

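A small illustration of the new modular indexing in unpackVec (hypothetical values chosen for this summary, assuming kWidth=4 with 32-bit elements so kIters=4, and repK=1 so kSize falls back to kIters): when k plus the iteration offset would step past the kSize register slots a thread actually holds, the index wraps back to the beginning, which is safe because, per the commit description, those leading slots hold exactly the duplicated data.

# Hypothetical illustration of the (k + kIter) % kSize wrap; values are assumptions.
kWidth, bitwidth, repK = 4, 32, 1
kIters = kWidth // (32 // bitwidth)              # 4 contiguous int32 elements per unpack
kSize = repK * 2 if repK >= kIters else kIters   # 4 valid slots in this scenario

k = 4  # a base index that would address the repeated second half
old_indices = [k + kIter for kIter in range(kIters)]             # [4, 5, 6, 7]: past the valid slots
new_indices = [(k + kIter) % kSize for kIter in range(kIters)]   # [0, 1, 2, 3]: wraps to the duplicates
print(old_indices, new_indices)
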
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp

Lines changed: 88 additions & 10 deletions
@@ -90,6 +90,7 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
   // we split the MMA into 4 sub-MMAs, each with a stride 4 x 32-bit along the
   // K dimension.
   llvm::SmallVector<unsigned> si;
+  auto kIters = kWidth / (32 / bitwidth);
 
   if (dot.getOpIdx() == 0) {
     // Original register layout:
@@ -106,11 +107,63 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
     // 2nd MMA: [[2, 3], [10, 11], [18, 19], [26, 27]]
     // 3rd MMA: [[4, 5], [12, 13], [20, 21], [28, 29]]
     // 4th MMA: [[6, 7], [14, 15], [22, 23], [30, 31]]
-    for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
-      for (size_t tile = 0; tile < 4; ++tile)
-        for (size_t e = 0; e < numElemsPerVec; ++e) {
-          si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
-        }
+    if (kIters <= repK) {
+      for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
+        for (size_t tile = 0; tile < 4; ++tile)
+          for (size_t e = 0; e < numElemsPerVec; ++e) {
+            si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
+          }
+    } else {
+      // Suppose kWidth=4 and type=fp32, so numElemsPerVec=1.
+      // Each tile of the dot operand layout has a size of 16x32.
+      // However, if the triton tensor size is 16x16, elements along the k
+      // dimension are duplicated. Within each tile, each register
+      // contains 2x8 elements arranged as follows:
+      //
+      //   tile0/0              tile0/1
+      //   |<--kWidth=4-->|     |<--kWidth-->|
+      //   |<-mmaWidth=2->|
+      //   [0, 1, 2, 3]         [0, 1, 2, 3]
+      //   [4, 5, 6, 7]         [4, 5, 6, 7]
+      //
+      // tile0/1 replicates the elements in tile0/0 along the k dimension.
+      // For a tensor size of 32x32, the next tile on the m dimension is as
+      // follows:
+      //
+      //   tile1/0              tile1/1
+      //   |<--kWidth-->|       |<--kWidth-->|
+      //   [8, 9, 10, 11],      [8, 9, 10, 11]
+      //   [12, 13, 14, 15],    [12, 13, 14, 15]
+      //
+      // Within a single tile, we can perform two MMAs, and the
+      // resulting register layout for each MMA is as follows:
+      //
+      //   1st MMA: [0, 4, 1, 5]
+      //   2nd MMA: [2, 6, 3, 7]
+      //   3rd MMA: [8, 12, 9, 13]
+      //   4th MMA: [10, 14, 11, 15]
+      //
+      // Additionally, we should reorder the elements by moving the duplicated
+      // elements to the end. In the example above, we convert the order from
+      // tile0/0, tile0/1, tile1/0, tile1/1 to tile0/0, tile1/0, tile0/1,
+      // tile1/1, so that only the first two tiles will be used in the
+      // computation.
+      size_t elemsPerTile = 2 * 2 * kWidth;
+      size_t elemsPerMma = 2 * 2 * numElemsPerVec;
+      size_t mmaWidth = kWidth / numElemsPerVec / 2;
+      size_t repMma = elemsPerTile / (mmaWidth * elemsPerMma);
+      for (size_t rep = 0; rep < repMma; ++rep)
+        for (size_t tile = 0; tile < elems.size() / elemsPerTile; ++tile)
+          for (size_t mmaKWidth = 0; mmaKWidth < mmaWidth; ++mmaKWidth)
+            for (size_t kTile = 0; kTile < 2; ++kTile)
+              for (size_t mTile = 0; mTile < 2; ++mTile)
+                for (size_t e = 0; e < numElemsPerVec; ++e) {
+                  si.push_back(rep * mmaWidth * elemsPerMma +
+                               mmaKWidth * 2 * numElemsPerVec +
+                               tile * elemsPerTile + mTile * kWidth +
+                               kTile * numElemsPerVec + e);
+                }
+    }
   } else {
     // Original register layout:
     //
@@ -122,11 +175,36 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
     // 2nd MMA: [[2, 3], [10, 11]]
     // 3rd MMA: [[4, 5], [12, 13]]
     // 4th MMA: [[6, 7], [14, 15]]
-    for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
-      for (size_t tile = 0; tile < 2; ++tile)
-        for (size_t e = 0; e < numElemsPerVec; ++e) {
-          si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
-        }
+    if (kIters <= repK) {
+      for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep)
+        for (size_t tile = 0; tile < 2; ++tile)
+          for (size_t e = 0; e < numElemsPerVec; ++e) {
+            si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
+          }
+    } else {
+      // Suppose kWidth=4 and type=fp32.
+      // Original register layout:
+      //
+      //   tile0/0           tile0/1
+      //   [0, 1, 2, 3]^T,   [0, 1, 2, 3]^T
+      //
+      // Similar to the opIdx=0 situation, we should reorder the elements by
+      // moving the duplicated elements to the end.
+      size_t elemsPerTile = 2 * kWidth;
+      size_t elemsPerMma = 2 * numElemsPerVec;
+      size_t mmaWidth = kWidth / numElemsPerVec / 2;
+      size_t repMma = elemsPerTile / (mmaWidth * elemsPerMma);
+      for (size_t rep = 0; rep < repMma; ++rep)
+        for (size_t tile = 0; tile < elems.size() / elemsPerTile; ++tile)
+          for (size_t mmaKWidth = 0; mmaKWidth < mmaWidth; ++mmaKWidth)
+            for (size_t kTile = 0; kTile < 2; ++kTile)
+              for (size_t e = 0; e < numElemsPerVec; ++e) {
+                si.push_back(rep * mmaWidth * elemsPerMma +
+                             mmaKWidth * 2 * numElemsPerVec +
+                             tile * elemsPerTile + kTile * numElemsPerVec +
+                             e);
+              }
+    }
   }
 
   auto step = si.size();

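To sanity-check the new opIdx=0 ordering, the following standalone Python sketch (written for this summary; the parameter values are assumptions matching the in-code comment: kWidth=4, fp32 so numElemsPerVec=1, and a single 16-register tile whose second k half duplicates the first) reproduces the index arithmetic of the new else branch. It prints exactly the MMA groupings listed in the comment, with the duplicated k half left in the trailing groups so that, as the commit message notes, only the first half needs to feed the actual MMAs.

# Hypothetical re-implementation of the opIdx=0 else-branch index permutation.
kWidth = 4
numElemsPerVec = 1   # fp32: one element per 32-bit register
num_elems = 16       # len(elems): one tile; slots 8..15 duplicate slots 0..7 along k

elemsPerTile = 2 * 2 * kWidth                      # 16
elemsPerMma = 2 * 2 * numElemsPerVec               # 4
mmaWidth = kWidth // numElemsPerVec // 2           # 2
repMma = elemsPerTile // (mmaWidth * elemsPerMma)  # 2

si = []
for rep in range(repMma):
    for tile in range(num_elems // elemsPerTile):
        for mmaKWidth in range(mmaWidth):
            for kTile in range(2):
                for mTile in range(2):
                    for e in range(numElemsPerVec):
                        si.append(rep * mmaWidth * elemsPerMma +
                                  mmaKWidth * 2 * numElemsPerVec +
                                  tile * elemsPerTile + mTile * kWidth +
                                  kTile * numElemsPerVec + e)

print(si)
# [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15]
# 1st MMA: [0, 4, 1, 5]    2nd MMA: [2, 6, 3, 7]
# 3rd MMA: [8, 12, 9, 13]  4th MMA: [10, 14, 11, 15]  <- the duplicated k half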