Skip to content

Commit 134494d

Browse files
Merge OpenAI Triton commit 336cc1d (#4560)
This PR changes the Triton base from 34758e4 to 336cc1d (Jun 18). Pass rate: 97.12%
2 parents 50fc4c3 + 89234eb commit 134494d

File tree

4 files changed

+22
-17
lines changed

4 files changed

+22
-17
lines changed

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
309309
auto totalStoreCvt = srcLayout.invertAndCompose(smem);
310310
auto totalLoadCvt = dstLayout.invertAndCompose(smem);
311311

312-
// FIXME(Lezcano): The legacy path also creates PRMT, so we should revisit
313-
314312
// The permutation exists by construction of the reps dimension in
315313
// optimalSwizzling
316314
auto permStore =

lib/Tools/GenericSwizzling.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,9 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
311311
// Bits in a bank segment: 32 banks x 32 bits
312312
constexpr int32_t bankBits = 32 * 32;
313313
// Bases needed to cover a whole bank segment
314-
const int32_t lenBbasis =
315-
llvm::Log2_32(bankBits / ((1 << vbasis.size()) * bitwidth));
314+
const int32_t lenBbasis = std::min<int32_t>(
315+
llvm::Log2_32(bankBits / ((1 << vbasis.size()) * bitwidth)),
316+
dim - vbasis.size());
316317
// Bases to cover all the tensor
317318
const int32_t lenSbasis = dim - lenBbasis - vbasis.size();
318319

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,17 @@ def _p_matmul_ogs(
384384
block_shape=[BLOCK_M, OUT_BLOCK_N],
385385
)
386386

387+
# bias + scale
388+
offs_y_n = off_n1 + tl.arange(0, BLOCK_N)
389+
mask_n = offs_y_n < N
390+
if B is not None:
391+
BPtrs = B + expt_id1 * stride_b_e + offs_y_n
392+
if pid_k1 == 0:
393+
bias = tl.load(BPtrs, mask=mask_n, other=0)
394+
else:
395+
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
396+
else:
397+
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
387398
if Betas is not None:
388399
betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0)
389400
else:
@@ -399,15 +410,21 @@ def _p_matmul_ogs(
399410
w_scale = load_scale(WScale)
400411

401412
accs = (acc,)
413+
biases = (bias,)
402414

403415
if SUBTILE_FACTOR >= 2:
404416
acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
405417
accs = (acc0, acc1)
418+
bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
419+
biases = (bias0, bias1)
406420

407421
if SUBTILE_FACTOR >= 4:
408422
acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
409423
acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
410424
accs = (acc00, acc01, acc10, acc11)
425+
bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
426+
bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
427+
biases = (bias00, bias01, bias10, bias11)
411428

412429
tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
413430
tl.static_assert(len(accs) == SUBTILE_FACTOR)
@@ -419,18 +436,7 @@ def _p_matmul_ogs(
419436
if SWAP_XW:
420437
acc_tile = acc_tile.T
421438

422-
if B is not None:
423-
offs_y_n = off_n1 + EPILOGUE_BLOCK_N * a_i + tl.arange(0, EPILOGUE_BLOCK_N)
424-
mask_n = offs_y_n < N
425-
BPtrs = B + expt_id1 * stride_b_e + offs_y_n
426-
if pid_k1 == 0:
427-
bias = tl.load(BPtrs, mask=mask_n, other=0)
428-
else:
429-
bias = tl.full([EPILOGUE_BLOCK_N], 0, dtype=tl.float32)
430-
else:
431-
bias = tl.full([EPILOGUE_BLOCK_N], 0, dtype=tl.float32)
432-
433-
acc_tile = acc_tile + bias[None, :] * betas[:, None]
439+
acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
434440
if out_alpha is not None:
435441
acc_tile *= out_alpha
436442

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1293,7 +1293,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
12931293
// CHECK-LABEL: linear_layout_with_multiple_iterations
12941294
tt.func @linear_layout_with_multiple_iterations(%src: tensor<8x4xbf16, #linear>) {
12951295
%cvt = ttg.convert_layout %src : tensor<8x4xbf16, #linear> -> tensor<8x4xbf16, #linear1>
1296-
// CHECK-COUNT-2: llvm.store {{.*}} : vector<2xi16>
1296+
// CHECK-COUNT-1: llvm.store {{.*}} : vector<4xi16>
12971297
// CHECK: nvvm.barrier0
12981298
// CHECK-COUNT: llvm.load{{.*}}->vector<2xi16>
12991299
tt.return

0 commit comments

Comments (0)