Commit 7568a4d

[BACKEND] Support vectorisation and arbitrary bitwidth in stmatrix (#6899)
As per title. I'll add transpose + generic support for ldmatrix in a different PR.
1 parent f810652 commit 7568a4d

File tree

3 files changed: 53 additions & 20 deletions

.gitignore
test/Conversion/tritongpu_to_llvm_hopper.mlir
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ llvm-project-*/
 python/build/
 python/dist/
 python/triton*.egg-info/
+python/triton_kernels/triton*.egg-info/
 
 python/triton/_C/*.pyd
 python/triton/_C/*.so

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 34 additions & 2 deletions

@@ -301,7 +301,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 // CHECK-LABEL: linear_to_swizzled_st_matrix_local_store
 module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<64x32xf16, #linear>) {
-    // CHECK-COUNT-2: nvgpu.stmatrix
+    // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
     // CHECK: llvm.return
     %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
     ttg.local_store %a, %b : tensor<64x32xf16, #linear> -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
@@ -323,7 +323,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 // CHECK-LABEL: linear_to_swizzled_st_matrix_local_store
 module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<32x32xf16, #linear>) {
-    // CHECK-COUNT-2: nvgpu.stmatrix
+    // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
     // CHECK: llvm.return
     %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
     ttg.local_store %a, %b : tensor<32x32xf16, #linear> -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
@@ -333,6 +333,38 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [8, 0]], lane = [[0, 4], [0, 8], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = []}>
+#smem = #ttg.shared_memory
+// CHECK-LABEL: linear_to_swizzled_st_matrix_x2_local_store_fp8
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func @linear_to_swizzled_st_matrix_x2_local_store_fp8(%a: tensor<64x16xf8E4M3FNUZ, #linear>) {
+    // CHECK-COUNT-1: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}} :
+    // CHECK: llvm.return
+    %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x16xf8E4M3FNUZ, #shared, #smem, mutable>
+    ttg.local_store %a, %b : tensor<64x16xf8E4M3FNUZ, #linear> -> !ttg.memdesc<64x16xf8E4M3FNUZ, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#linear = #ttg.linear<{register = [[8, 0], [0, 4], [0, 8]], lane = [[0, 1], [0, 2], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = []}>
+#smem = #ttg.shared_memory
+// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store_fp32
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func @linear_to_swizzled_st_matrix_local_store_fp32(%a: tensor<64x16xf32, #linear>) {
+    // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+    // CHECK: llvm.return
+    %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x16xf32, #shared, #smem, mutable>
+    ttg.local_store %a, %b : tensor<64x16xf32, #linear> -> !ttg.memdesc<64x16xf32, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func @fp8_const(%arg0: tensor<1024xi1, #blocked>, %arg1: tensor<1024xf8E4M3FNUZ, #blocked>) attributes {noinline = false} {
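
Note on the new CHECK lines: the operand count encodes the stmatrix variant, since the op takes one shared-memory address plus one i32 register per 8x8 matrix stored. Three operands therefore match the .x2 form and five the .x4 form. As a rough cross-check, here is a minimal standalone C++ sketch (not part of the commit; regsPerThread is derived from the register bases of each #linear layout above, and nVecs is read off the CHECK operand counts) that reproduces the expected op counts:

#include <cstdio>

// Standalone sketch: reproduces the arithmetic in MemoryOpToLLVM.cpp that
// determines how many stmatrix ops each test emits.
int main() {
  struct Case {
    const char *name;
    int bitwidth, regsPerThread, nVecs;
  };
  Case cases[] = {{"fp8 64x16 (.x2)", 8, 8, 2},
                  {"fp32 64x16 (.x4)", 32, 8, 4}};
  for (const auto &c : cases) {
    int elemsPerVec = 32 / c.bitwidth; // elements packed into one i32
    int step = c.nVecs * elemsPerVec;  // tensor elements per stmatrix op
    std::printf("%s: %d op(s)\n", c.name, c.regsPerThread / step);
  }
  return 0;
}

This prints 1 op for the fp8 case and 2 for the fp32 case, matching the CHECK-COUNT-1 and CHECK-COUNT-2 directives above.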

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 18 additions & 18 deletions

@@ -159,14 +159,9 @@ LogicalResult lowerDistributedToSharedStmatrix(
   auto kBlock = S("block");
   auto kOffset = S("offset");
   auto smemPtrTy = ptr_ty(ctx, 3);
-
-  // Just stmatrix for now
-  // 1) NYI in the stmatrix lowering
-  // Pack everything into uint32_t to support bitwidths other than 16
   auto bitwidth = tensorTy.getElementTypeBitWidth();
-  if (bitwidth != 16)
+  if (bitwidth > 32)
     return failure();
-
   // Inter block stmatrix is not supported
   if (cvt.hasInDim(kBlock))
     return failure();
@@ -198,13 +193,9 @@ LogicalResult lowerDistributedToSharedStmatrix(
   auto reps = zerosLike(tile) * quot;
   assert(reps.getOutDimSize(kOffset) == cvt.getOutDimSize(kOffset));
 
-  // Choose the 4 elements indexed by the next to bases as the vectorisation
-  // factor
+  // Choose up to 4 packs of 32-bit elements indexed by the next to bases
+  // as the vectorisation factor
   auto vec = std::min(2, quot.getInDimSizeLog2(kReg));
-  // 2) NYI stmatrix.x1 and stmatrix.x2
-  if (vec != 2) {
-    return failure();
-  }
 
   // FIXME(Lezcano): Should we bail if any of the other 3 lane bases is zero?
 
@@ -237,17 +228,26 @@ LogicalResult lowerDistributedToSharedStmatrix(
           .second;
 
   // Elements per op
-  auto step = (1 << vec) * (32 / bitwidth);
+  auto nVecs = 1 << vec;
+  auto elemsPerVec = 32 / bitwidth;
+  auto step = nVecs * elemsPerVec;
   for (int i = 0; i < srcVals.size(); i += step) {
     auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second;
     Value offset = b.xor_(regBase, b.i32_val(regIdx));
     auto vecAddr = b.gep(smemPtrTy, llvmElemTy, smemBase, offset,
                          LLVM::GEPNoWrapFlags::inbounds);
-    SmallVector<Value> inValsVec;
-    for (int j = 0; j < step; j++)
-      inValsVec.push_back(srcVals[i + j]);
-    Value valsVec = packLLVector(loc, inValsVec, rewriter);
-    targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
+    // Pack into vector of i32
+    SmallVector<Value> inputs;
+    Type packedTy = vec_ty(llvmElemTy, 32 / bitwidth);
+    for (int j = 0; j < nVecs; j++) {
+      Value input = b.undef(packedTy);
+      for (int k = 0; k < elemsPerVec; k++) {
+        input = b.insert_element(
+            packedTy, input, srcVals[i + j * elemsPerVec + k], b.i32_val(k));
+      }
+      inputs.push_back(b.bitcast(input, i32_ty));
+    }
+    rewriter.create<triton::nvgpu::StoreMatrixOp>(loc, vecAddr, inputs);
   }
   return success();
 }
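
For intuition, the insert_element + bitcast loop above amounts to plain bit packing on a little-endian target such as NVPTX: element 0 of the <elemsPerVec x elemTy> vector lands in the low bits of the resulting i32. A minimal scalar sketch of the 8-bit case (packFp8x4 is a hypothetical helper, not Triton code):

#include <cstdint>

// Scalar equivalent of the insert_element + bitcast sequence for
// bitwidth == 8: four fp8 payload bytes become one 32-bit stmatrix
// register, with element 0 in the least-significant byte.
uint32_t packFp8x4(uint8_t e0, uint8_t e1, uint8_t e2, uint8_t e3) {
  return uint32_t(e0) | (uint32_t(e1) << 8) | (uint32_t(e2) << 16) |
         (uint32_t(e3) << 24);
}

With bitwidth == 32 the packed vector has a single element, so the bitcast is a no-op and each value already occupies a full stmatrix register, which is why fp32 needs no special casing in the loop above.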
