Commit 7dfe5d8

Merge OpenAI Triton commit 22188cc (#5136)
This PR changes the Triton base from acd72ac to 22188cc (Sep 11). Pass rate: 98.8%
2 parents 0c2c9b3 + 50a63bc commit 7dfe5d8

14 files changed: +408 −303 lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 4 additions & 1 deletion
@@ -1085,7 +1085,10 @@ void AxisInfoAnalysis::visitForOpInductionVar(
   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);
   AxisInfo::DimVectorT knownConstancy(1, 1);
-  knownDivisibility[0] = gcd(lb.getDivisibility(0), step.getDivisibility(0));
+  auto lbDivisibility = lb.getDivisibility();
+  auto stepDivisibility = step.getDivisibility();
+  if (!lbDivisibility.empty() && !stepDivisibility.empty())
+    knownDivisibility[0] = gcd(lbDivisibility[0], stepDivisibility[0]);
   auto inductionVar =
       AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
   (void)argLattices[0]->join(inductionVar);
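
The guard added here skips the divisibility update when either operand reports an empty divisibility vector; otherwise the induction variable inherits the gcd of the lower bound's and the step's divisibilities. A quick standalone check of that invariant, as a plain Python sketch with made-up divisibility values (not Triton code):

# If lb is divisible by d_lb and step by d_step, every value lb + i*step
# is divisible by gcd(d_lb, d_step).
import math

d_lb, d_step = 16, 24
g = math.gcd(d_lb, d_step)  # 8
for lb in range(0, 128, d_lb):
    for step in range(d_step, 96, d_step):
        assert all((lb + i * step) % g == 0 for i in range(10))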

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 20 additions & 9 deletions
@@ -1151,6 +1151,12 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

 LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
                                         TensorMemoryEncodingAttr encoding) {
+  // [Zeros in TMEM LinearLayouts]
+  // If there is a zero in bases rows=32,64 this means that there is
+  // broadcasting, i.e. the same tensor element is duplicated in different
+  // addressable blocks If the zero is in any other row/col (i.e. within a given
+  // warp-addressable tmem space) it means it is not defined
+
   // We model packed layouts as having the rows/cols dimensions of bitwidth=16
   // This means that a layout with unpacked=True is the same as one with
   // unpacked=False
@@ -1186,25 +1192,26 @@ LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
   auto blockM = encoding.getBlockM();
   auto blockN = std::min<int32_t>(encoding.getBlockN(), shape[1]);
   assert(blockM == 64 || blockM == 128);
-  LinearLayout tile;
+  LinearLayout tile =
+      LinearLayout::zeros1D(encoding.getColStride(), kCol, dims[1]);
   if (blockM == 64) {
-    tile = LinearLayout::identity1D(16, kRow, dims[0]) *
-           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    tile *= LinearLayout::identity1D(16, kRow, dims[0]) *
+            LinearLayout::identity1D(blockN, kCol, dims[1]);
     auto bases = tile.getBases();
     if (shape[0] > blockM) {
       bases[kRow].push_back({64, 0});
     } else if (shape[1] > blockN) {
       bases[kRow].push_back({0, blockN});
     } else {
-      // Empty. This is modelled as broadcasting, same as for TMA(fp4)
+      // Empty, meaning the element is not defined
       bases[kRow].push_back({0, 0});
     }
     bases[kRow].push_back({16, 0});
     bases[kRow].push_back({32, 0});
     tile = LinearLayout(bases, dims);
   } else {
-    tile = LinearLayout::identity1D(blockM, kRow, dims[0]) *
-           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    tile *= LinearLayout::identity1D(blockM, kRow, dims[0]) *
+            LinearLayout::identity1D(blockN, kCol, dims[1]);
   }
   auto repsM = shape[0] / tile.getOutDimSize(dims[0]);
   auto repsN = shape[1] / tile.getOutDimSize(dims[1]);
@@ -1223,14 +1230,18 @@ tensorMemoryScalesToLinearLayout(ArrayRef<int64_t> shape,
   auto kRow = S("row");
   auto kCol = S("col");
   auto dims = standardOutDimNames(ctx, 2);
-  // nb. this can be done with
-  // ensureLayoutNotSmallerThan/ensureLayoutNotLargerThan but it's a bit less
-  // clear IMO
+  // See [Zeros in TMEM LinearLayouts]
   // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
   // We choose repOrder = [0, 1]
   auto tile =
       LinearLayout::identity1D(std::min<int>(32, shape[0]), kRow, dims[0]) *
+      // If shape[0] < 32, we have some rows undefined
+      LinearLayout::zeros1D(32 / std::min<int>(32, shape[0]), kRow, dims[0]) *
+      // Broadcasting
+      LinearLayout::zeros1D(4, kRow, dims[0]) *
       LinearLayout::identity1D(std::min<int>(4, shape[1]), kCol, dims[1]) *
+      // If shape[1] < 4, we have some cols undefined
+      LinearLayout::zeros1D(4 / std::min<int>(4, shape[1]), kCol, dims[1]) *
       // reps
       LinearLayout::identity1D(std::max<int>(1, shape[0] / 32), kCol, dims[0]) *
       LinearLayout::identity1D(std::max<int>(1, shape[1] / 4), kCol, dims[1]);
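
The new comment hinges on what a zero basis means in a linear layout: two different input indices map to the same output element. A toy XOR-of-bases model in Python (an illustrative sketch, not the actual LinearLayout class) makes the collision visible:

# Toy model: an index maps to the XOR of the bases selected by its set bits.
def apply_layout(bases, idx):
    out = 0
    for bit, basis in enumerate(bases):
        if (idx >> bit) & 1:
            out ^= basis
    return out

bases = [1, 2, 0, 4]  # the zero basis at bit 2 collapses two indices
assert apply_layout(bases, 0b0011) == apply_layout(bases, 0b0111)  # broadcast: same element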

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 2 additions & 1 deletion
@@ -35,7 +35,8 @@ struct AutomaticWarpSpecialization
 void AutomaticWarpSpecialization::runOnOperation() {
   OpPassManager pm;
   pm.addPass(createTritonGPUPartitionScheduling());
-  pm.addPass(createNVWSInsertAref());
+  // TODO: re-enable once the regression is fixed.
+  // pm.addPass(createNVWSInsertAref());
   pm.addPass(createTritonGPULoadMMASpecialization({numStages}));
   pm.addPass(createTritonGPURewritePartitionDependencies());
   // `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.

python/triton_kernels/tests/test_matmul.py

Lines changed: 29 additions & 2 deletions
@@ -159,6 +159,9 @@ class Case:
     split_k: int = 1
     hbm_swizzling: bool = False
    epilogue_subtile: Union[int, None] = None
+    x_transpose: bool = False
+    w_transpose: bool = False
+    y_transpose: bool = False


 @pytest.mark.parametrize(
@@ -252,6 +255,13 @@ class Case:
         Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
         Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),
         Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),
+    ] + [
+        Case(320, 400, 400, mode, dtype, dtype, x_transpose=x_transpose, w_transpose=w_transpose, y_transpose=y_transpose)
+        for mode in ("batched", "ragged")
+        for dtype in ("float16", "float8_e5m2")
+        for x_transpose in (False, True)
+        for w_transpose in (False, True)
+        for y_transpose in (False, True)
     ]
 ],
 )
@@ -268,6 +278,7 @@ class Case:
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
             n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope, fresh_knobs):
     # TODO: remove when Triton FP8 supports proper RTNE
     if is_cuda():
@@ -373,6 +384,17 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
                             has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

+    if x_transpose:
+        x_tri = x_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd)
+    if w_transpose:
+        w_tri = w_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd)
+    if y_transpose:
+        n_rows = m if gindx is None else gindx.dst_indx.shape[0]
+        yT_shape = (n_expts_tot, n, n_rows) if mode == "batched" else (n, n_rows)
+        y_tri_in = torch.empty(yT_shape, dtype=act_dtype, device=device).transpose(-1, -2)
+    else:
+        y_tri_in = None
+
     if w_tri.shape[0] == 1 and mode != "batched":
         # Test the case when weight has dim 2, i.e., shape (K, N).
         w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
@@ -423,9 +445,14 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas

     # triton
     try:
-        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue)
+        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt,
+                           gammas=gs1_ref, epilogue=epilogue, y=y_tri_in)
     except (opt_flags.InapplicableConstraint, NotImplementedError):
         pytest.xfail("inapplicable opt_flags constraint")
+    if y_tri_in is not None:
+        assert tri_y.data_ptr() == y_tri_in.data_ptr()
+        assert tri_y.shape == y_tri_in.shape
+        assert tri_y.stride() == y_tri_in.stride()
     # If split_k > 1, then the intermediate tensor is fp32.
     sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1
     sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
@@ -537,7 +564,7 @@ def test_set_idle_sms():
     num_idle_sms = 24
     matmul_ogs_set_idle_sms(num_idle_sms)
     flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
-                           1, 1024, 1024, 1024, None, True, False, 1)
+                           1, 1024, 1024, 1024, None, True, False, 1, False)
     assert flags.idle_sms == num_idle_sms
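
The detach().transpose(-1, -2).contiguous().transpose(-1, -2) pattern used in the new test cases yields a tensor with the original logical shape but column-major storage. A minimal standalone sketch of the effect (plain PyTorch, hypothetical sizes):

import torch

x = torch.randn(4, 8)
x_t = x.transpose(-1, -2).contiguous().transpose(-1, -2)
assert x_t.shape == x.shape      # same logical shape
assert x_t.stride(-1) != 1       # but the last dim is no longer unit-stride
assert torch.equal(x_t, x)       # and the values are unchanged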

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 14 additions & 5 deletions
@@ -177,6 +177,8 @@ def apply_allocation(allocation: MatmulAllocation, output):
     if output is None:
         output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
     else:
+        if output.ndim == 2:
+            output = output[None, :, :]
         assert output.shape == allocation.output[0]
     ret["output"] = output[None, :, :]
     ret["scratchpad"] = {
@@ -350,6 +352,7 @@ def matmul_ogs(x, w, bias,
         x_scale = Tensor(x_scale)
     if not isinstance(x, Tensor):
         x = Tensor(x, dtype=x.dtype)
+    x_transpose = x.stride(-1) != 1
     # determine shapes
     has_gather = gather_indx is not None
     has_scatter = scatter_indx is not None
@@ -362,14 +365,20 @@ def matmul_ogs(x, w, bias,
         assert x.shape[0] == w.shape[0]
     # compute optimization flags
     out_dtype = precision_config.out_dtype or x.dtype
-    can_use_tma = x.numel() > 0 and x.storage.is_tma_compliant() and \
-                  w.numel() > 0 and w.storage.is_tma_compliant() and \
-                  (w_scale is None or w_scale.storage.is_tma_compliant())
+    can_use_tma = (
+        x.numel() > 0 and x.storage.is_tma_compliant() and
+        w.numel() > 0 and w.storage.is_tma_compliant() and
+        (w_scale is None or w_scale.storage.is_tma_compliant()) and
+        (not is_ragged or x.stride(-1) == 1) and
+        # Currently we don't support tma if y is column major; may revisit later if this becomes an issue.
+        (y is None or y.stride(-1) == 1)
+    )
     # hopper w/ mxfp4 doesn't support TMA
     can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
     can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
-                               batch_size, M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
+                               batch_size, M, N, K, routing_data, can_use_tma, can_use_fused_scatter,
+                               epilogue.effective_itemsize, x_transpose,
     )
     if not can_use_fused_scatter and opt_flags.fused_scatter:
         raise InapplicableConstraint("Fused scatter is not supported")
@@ -469,7 +478,7 @@ def matmul_ogs(x, w, bias,
         y_tensor_or_tma, y_storage.data, *out_matmul.stride(),
         *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
         *out_matmul_scale_strides[-4:],
-        x_tensor_or_tma, x_storage.data, *x_strides,
+        x_tensor_or_tma, x_storage.data, *x_strides, x_transpose,
         flex.lhs_data.scale,
         None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
         w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose,
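
The apply_allocation change lets callers pass a 2-D output buffer, which is viewed as a batch of one before the shape check. A minimal sketch of that normalization with hypothetical shapes:

import torch

expected_shape = (1, 16, 32)          # hypothetical allocation.output[0]
output = torch.empty(16, 32)          # caller passed a 2-D buffer
if output.ndim == 2:
    output = output[None, :, :]       # view as (1, M, N); no copy
assert output.shape == expected_shape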

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def _matmul_ogs(
              Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
              YExpectedScale, YActualScale, YChecksumScale,
              stride_y_mx_k, stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
-             X, XPtr, stride_x_z, stride_x_m, stride_x_k,
+             X, XPtr, stride_x_z, stride_x_m, stride_x_k, X_TRANSPOSE: tl.constexpr,
              XScale,
              XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
              W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 8 additions & 4 deletions
@@ -82,7 +82,7 @@ def _p_matmul_ogs(
              Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
              YExpectedScale, YActualScale, YChecksumScale,
              stride_y_mx_k, stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
-             X, XPtr, stride_x_z, stride_x_m, stride_x_k,
+             X, XPtr, stride_x_z, stride_x_m, stride_x_k, X_TRANSPOSE: tl.constexpr,
              XScale,
              XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
              W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
@@ -282,13 +282,17 @@ def _p_matmul_ogs(
         if EVEN_K:
             mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
         else:
-            mask_k_scale = offs_k_scale < tl.cdiv(K, MX_PACK_DIVISOR)
+            mask_k_scale = off_k_mx + tl.arange(0, MX_SCALE_BLOCK_K) < tl.cdiv(K, MX_PACK_DIVISOR)

         if USE_GATHER_TMA:
             x = X.gather(offs_x_m, off_k)
         elif X_TMA_MODE == "dense":
-            x = X.load([start_z, start_m + off_m, off_k])
-            x = x.reshape(BLOCK_M, BLOCK_K)
+            if X_TRANSPOSE:
+                x = X.load([start_z, off_k, start_m + off_m])
+                x = x.reshape(BLOCK_K, BLOCK_M).T
+            else:
+                x = X.load([start_z, start_m + off_m, off_k])
+                x = x.reshape(BLOCK_M, BLOCK_K)
         elif X_TMA_MODE == "ragged":
             x = load_ragged(X, start_m, eM, [start_z, off_m, off_k], ragged_dim=1)
             x = x.reshape(BLOCK_M, BLOCK_K)
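
For a column-major X, the dense-TMA path loads a (BLOCK_K, BLOCK_M) tile at swapped offsets and transposes it back in-kernel. A rough NumPy stand-in for that index bookkeeping (the real code issues a TMA descriptor load, which is not modeled here):

import numpy as np

BLOCK_M, BLOCK_K = 4, 8
x = np.arange(BLOCK_M * BLOCK_K).reshape(BLOCK_M, BLOCK_K)  # logical (M, K) tile
x_col_major = np.asfortranarray(x)                          # same tile in column-major storage
k_major_view = x_col_major.T                                # contiguous (BLOCK_K, BLOCK_M): what the swapped-offset load sees
assert np.array_equal(k_major_view.T, x)                    # transposing back recovers the (M, K) tile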

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 6 additions & 1 deletion
@@ -44,6 +44,7 @@ def make_default_opt_flags_intel(
     can_use_fused_scatter,
     enforce_bitwise_invariance,
     epilogue_effective_itemsize,
+    x_transpose,
     constraints,
 ):
     constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
@@ -123,6 +124,7 @@ def make_default_opt_flags_amd(
     can_use_fused_scatter,
     enforce_bitwise_invariance,
     epilogue_effective_itemsize,
+    x_transpose,
     constraints,
 ):
     constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
@@ -222,6 +224,7 @@ def make_default_opt_flags_nvidia(
     can_use_fused_scatter,
     enforce_bitwise_invariance,
     epilogue_effective_itemsize,
+    x_transpose,
     constraints,
 ):
     constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
@@ -286,6 +289,7 @@ def make_default_opt_flags_nvidia(
             out_dtype,
             lhs_dtype,
             rhs_dtype,
+            x_transpose,
         )

     if constraints.get("epilogue_subtile", None) is not None:
@@ -365,6 +369,7 @@ def make_opt_flags(
     can_use_persistent_tma,
     can_use_fused_scatter,
     epilogue_effective_itemsize,
+    x_transpose,
 ):
     if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
         raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
@@ -376,7 +381,7 @@ def make_opt_flags(
         return _opt_flags
     args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, batch_size, m, n, k,
             routing_data, can_use_persistent_tma, can_use_fused_scatter,
-            enforce_bitwise_invariance, epilogue_effective_itemsize,
+            enforce_bitwise_invariance, epilogue_effective_itemsize, x_transpose,
             _opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
     if backend == "xpu":

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py

Lines changed: 3 additions & 0 deletions
@@ -72,6 +72,7 @@ def compute_num_stages(
     out_dtype,
     lhs_dtype,
     rhs_dtype,
+    x_transpose,
     epilogue_subtile,
     epilogue_effective_itemsize,
 ):
@@ -103,6 +104,8 @@ def compute_num_stages(
         # pipelined layout conversion before store of the accumulator
         # note: layout conversion has some padding
         smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
+    if x_transpose:
+        smem_capacity -= block_m * block_k * lhs_dtype.itemsize
     if precision_config.weight_scale is not None:
         # mx scales
         stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
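
The new branch carves a one-time chunk out of the shared-memory budget, rather than growing the per-stage size, when X is column-major, presumably to stage the transposed X tile. A back-of-the-envelope sketch with hypothetical numbers:

block_m, block_k = 128, 64
lhs_itemsize = 2                                   # e.g. fp16
extra = block_m * block_k * lhs_itemsize           # 16384 bytes reserved once, not per stage
smem_capacity = 232448                             # hypothetical budget
smem_capacity -= extra
print(smem_capacity)                               # 216064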
