Commit 6ddacab

Merge OpenAI Triton commit 33b2823 (#5091)
This PR changes the Triton base from a6d11f7 to 33b2823 (Sep 6). Pass rate: 98.11%

---------

Signed-off-by: Anatoly Myachev <[email protected]>
2 parents 719aab3 + 57a1a9b commit 6ddacab

File tree

34 files changed: +1265 −297 lines changed


include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 8 additions & 0 deletions
@@ -419,6 +419,14 @@ def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
   let arguments = (ins RankedTensorOf<[I8]>:$src, I32Attr:$axis);
   let results = (outs TT_FloatTensor:$result);
 
+  let extraClassDeclaration = [{
+    static LogicalResult verifyFp4ToFp(
+        mlir::Operation *op,
+        RankedTensorType srcTy,
+        RankedTensorType resTy,
+        unsigned axis);
+  }];
+
   let assemblyFormat = [{
     $src attr-dict `:` type($src) `->` type($result)
   }];

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 24 additions & 16 deletions
@@ -378,36 +378,44 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 LogicalResult Fp4ToFpOp::verify() {
   auto srcTy = cast<RankedTensorType>(getSrc().getType());
   auto resTy = cast<RankedTensorType>(getResult().getType());
+  auto axis = getAxis();
+
+  auto elemType = resTy.getElementType();
+  if (!(elemType.isBF16() || elemType.isF16()))
+    return emitError() << "only bf16 or f16 is supported for now, got "
+                       << elemType;
+
+  return verifyFp4ToFp(*this, srcTy, resTy, axis);
+}
+
+LogicalResult Fp4ToFpOp::verifyFp4ToFp(mlir::Operation *op,
+                                       RankedTensorType srcTy,
+                                       RankedTensorType resTy, unsigned axis) {
   auto rank = srcTy.getRank();
 
   if (rank != resTy.getRank())
-    return emitError() << "source rank " << rank << " != result rank "
-                       << resTy.getRank();
+    return op->emitError() << "source rank " << rank << " != result rank "
+                           << resTy.getRank();
 
   auto srcShape = srcTy.getShape();
   auto resShape = resTy.getShape();
-  auto axis = getAxis();
 
   if (!(0 <= axis && axis < rank))
-    return emitError() << "axis " << axis << " out of range for rank " << rank;
-
-  auto elemType = resTy.getElementType();
-  if (!(elemType.isBF16() || elemType.isF16()))
-    return emitError() << "only bf16 or f16 is supported for now, got "
-                       << elemType;
+    return op->emitError() << "axis " << axis << " out of range for rank "
+                           << rank;
 
   for (int i = 0; i < rank; ++i) {
     if (i == axis) {
       if (resShape[i] != srcShape[i] * 2)
-        return emitError() << "axis " << axis
-                           << " dimension must be 2x source dimension (src="
-                           << srcShape[i] << ", dst=" << resShape[i] << ")";
+        return op->emitError()
+               << "axis " << axis
+               << " dimension must be 2x source dimension (src=" << srcShape[i]
+               << ", dst=" << resShape[i] << ")";
     } else {
       if (resShape[i] != srcShape[i])
-        return emitError() << "dimension " << i
-                           << " mismatch (src=" << srcShape[i]
-                           << ", dst=" << resShape[i] << ", axis=" << axis
-                           << ")";
+        return op->emitError()
+               << "dimension " << i << " mismatch (src=" << srcShape[i]
+               << ", dst=" << resShape[i] << ", axis=" << axis << ")";
     }
   }
   return success();
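
Note: the refactor above only moves the existing shape checks into a reusable static verifyFp4ToFp helper, leaving the bf16/f16 element-type check in Fp4ToFpOp::verify(); the constraint itself is unchanged. As a rough illustration (a hypothetical Python sketch, not part of the diff), the result shape must double the source extent along `axis` — two fp4 values unpack from each i8 byte — and match the source on every other dimension:

def fp4_to_fp_result_shape(src_shape, axis):
    # Unpacking doubles the extent along `axis`; all other dims must match.
    assert 0 <= axis < len(src_shape)
    return tuple(d * 2 if i == axis else d for i, d in enumerate(src_shape))

assert fp4_to_fp_result_shape((128, 32), axis=1) == (128, 64)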

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 10 additions & 0 deletions
@@ -257,6 +257,16 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
     arriveBarrier.getPredMutable().assign(mask);
     return op;
   }
+  if (auto commit = dyn_cast<ttng::TCGen5CommitOp>(op)) {
+    rewriter.setInsertionPoint(commit);
+    Value mask = pred;
+    Value currentPred = commit.getPred();
+    if (currentPred) {
+      mask = getPredMask(rewriter, currentPred.getType(), currentPred, pred);
+    }
+    commit.getPredMutable().assign(mask);
+    return op;
+  }
   if (auto storeOp = dyn_cast<tt::StoreOp>(op)) {
     rewriter.setInsertionPoint(storeOp);
     Value mask = getPredMask(rewriter, storeOp.getPtr().getType(),
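
The new TCGen5CommitOp case follows the same predication pattern as the neighbouring branches: start from the incoming loop predicate and, if the commit already carries a predicate, fold the two together with getPredMask (which, by analogy with its other uses here, is assumed to combine both predicates). A hypothetical Python sketch of that rule, for illustration only:

def combine_predicates(current_pred, loop_pred):
    # No existing predicate: take the loop predicate as-is.
    if current_pred is None:
        return loop_pred
    # Otherwise both predicates must hold (what getPredMask is assumed to build).
    return current_pred and loop_pred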

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 2 additions & 2 deletions
@@ -35,15 +35,15 @@ struct AutomaticWarpSpecialization
 void AutomaticWarpSpecialization::runOnOperation() {
   OpPassManager pm;
   pm.addPass(createTritonGPUPartitionScheduling());
+  pm.addPass(createNVWSInsertAref());
   pm.addPass(createTritonGPULoadMMASpecialization({numStages}));
   pm.addPass(createTritonGPURewritePartitionDependencies());
   // `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.
   // FIXME: Re-enable integer range analysis once it is fixed.
   // pm.addPass(arith::createIntRangeOptimizationsPass());
   pm.addPass(createSCCPPass());
   pm.addPass(createCSEPass());
-  pm.addPass(createNVWSAssignStagePhase());
-  pm.addPass(createNVWSLowerAref());
+  pm.addPass(createNVWSLowerAref({numStages}));
   pm.addPass(createTritonGPUPartitionLoops());
   pm.addPass(createNVWSLowerWarpGroup());
   if (failed(runPipeline(pm, getOperation())))

python/test/gluon/test_core.py

Lines changed: 23 additions & 0 deletions
@@ -811,3 +811,26 @@ def kernel(N, out):
     out = torch.empty(1, dtype=torch.int32, device="cuda")
     compiled_kernel = kernel.warmup(N=100, out=out, grid=(1, ))
     assert compiled_kernel.asm["llir"].count("define") == 1
+
+
+@pytest.mark.skipif(not is_hip_cdna3() and not is_hip_cdna4(), reason="Requires CDNA3 or CDNA4")
+def test_inline_with_amdgpu_dialect():
+
+    @gluon.jit
+    def buffer_load(x, offsets):
+        return ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets)
+
+    @gluon.jit
+    def kernel(x, y):
+        layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[64], warps_per_cta=[4],
+                                                    order=[0])
+        offsets = ttgl.arange(0, 64, layout=layout)
+
+        a = buffer_load(x, offsets)
+        ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets)
+
+    input = torch.arange(64, device="cuda").to(torch.int32)
+    output = torch.empty_like(input)
+
+    compiled_kernel = kernel.warmup(input, output, grid=(1, ))
+    assert compiled_kernel.asm["ttgir"].count("tt.func private") == 0

python/test/unit/runtime/test_autotuner.py

Lines changed: 29 additions & 0 deletions
@@ -448,3 +448,32 @@ def grid(meta):
     warp_size = triton.runtime.driver.active.get_current_target().warp_size
     assert exception_out_of_resource is not None and f"out of resource: threads, Required: {128 * warp_size}" in str(
         exception_out_of_resource)
+
+
+def test_prune_all_configs(device):
+    N = 1024
+    src = torch.randn(N, device=device)
+    dst = torch.empty(N, device=device)
+
+    def early_config_prune(configs, named_args, **kwargs):
+        return []
+
+    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
+
+    prune_configs_by = {'early_config_prune': early_config_prune}
+
+    @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by)
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        tl.store(dst + offsets, x, mask=offsets < N)
+
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
+    try:
+        _kernel[grid](dst, src, N=N)
+        pytest.fail("Expected exception was not thrown.")
+    except triton.TritonError as e:
+        assert e is not None and str(
+            e
+        ) == "Autotuner error: No valid autotuner configs after pruning. `early_config_prune` should return at least one config."

python/triton/runtime/autotuner.py

Lines changed: 10 additions & 3 deletions
@@ -10,7 +10,7 @@
 
 from .. import knobs
 from .jit import KernelInterface, JITFunction
-from .errors import OutOfResources, PTXASError
+from .errors import OutOfResources, PTXASError, AutotunerError
 from .driver import driver
 from .cache import get_cache_manager, triton_key
 from triton._C.libtriton import get_cache_invalidating_env_vars
@@ -25,7 +25,9 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
             'perf_model': performance model used to predicate running time with different configs, returns running time
             'top_k': number of configs to bench
-            'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
+            'early_config_prune': a function used to prune configs. It should have the signature
+                `prune_configs_by( configs: List[triton.Config], named_args: Dict[str, Any], **kwargs: Dict[str, Any]) -> List[triton.Config]:`
+                and return pruned configs. It should return at least one config.
         """
         if not configs:
             self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)]
@@ -259,6 +261,9 @@ def prune_configs(self, kwargs: Dict) -> List[Config]:
         pruned_configs = self.configs
         if self.early_config_prune:
             pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+            if not pruned_configs:
+                raise AutotunerError(
+                    "No valid autotuner configs after pruning. `early_config_prune` should return at least one config.")
         if self.perf_model:
             top_k = self.configs_top_k
             if isinstance(top_k, float) and top_k <= 1.0:
@@ -406,7 +411,9 @@ def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
     :param prune_configs_by: a dict of functions that are used to prune configs, fields:
         'perf_model': performance model used to predicate running time with different configs, returns running time
         'top_k': number of configs to bench
-        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
+        'early_config_prune': a function used to prune configs. It should have the signature
+            `prune_configs_by( configs: List[triton.Config], named_args: Dict[str, Any], **kwargs: Dict[str, Any]) -> List[triton.Config]:`
+            and return pruned configs. It should return at least one config.
     :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
     :type reset_to_zero: list[str]
     :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
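
The behavioural change above is that an empty result from `early_config_prune` now raises `AutotunerError` instead of continuing with no configs (the new test_prune_all_configs exercises exactly this). A minimal usage sketch of a pruning hook that satisfies the documented signature and always returns at least one config — the kernel and pruning rule here are hypothetical, shown only for illustration:

import triton
import triton.language as tl


def keep_small_blocks(configs, named_args, **kwargs):
    # Keep only small block sizes, but never return an empty list:
    # an empty result would now raise AutotunerError.
    kept = [c for c in configs if c.kwargs['BLOCK_SIZE'] <= 64]
    return kept or configs[:1]


@triton.autotune(
    configs=[triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})],
    key=['N'],
    prune_configs_by={'early_config_prune': keep_small_blocks},
)
@triton.jit
def _copy(dst, src, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    tl.store(dst + offsets, tl.load(src + offsets, mask=mask), mask=mask)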

python/triton/runtime/errors.py

Lines changed: 10 additions & 0 deletions
@@ -34,3 +34,13 @@ def __init__(self, error_message: Optional[str] = None):
     def __str__(self) -> str:
         error_message = self.error_message or ""
         return f"PTXAS error: {error_message}"
+
+
+class AutotunerError(TritonError):
+
+    def __init__(self, error_message: Optional[str] = None):
+        self.error_message = error_message
+
+    def __str__(self) -> str:
+        error_message = self.error_message or ""
+        return f"Autotuner error: {error_message}"

python/triton_kernels/tests/test_matmul.py

Lines changed: 57 additions & 2 deletions
@@ -1,6 +1,7 @@
 # isort: off
 # fmt: off
 from dataclasses import dataclass, fields, replace
+import itertools
 import pytest
 import torch
 from typing import Union
@@ -20,7 +21,7 @@
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4
 
 # ---------------
 # initialize data
@@ -471,14 +472,68 @@ def round_x(x, idx):
         tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"
 
 
+# Test that we don't use unsupported block sizes.
+@pytest.mark.parametrize("m", [8, 16, 32, 64, 128])
+@pytest.mark.parametrize("n", [8, 16, 32, 64, 128])
+@pytest.mark.parametrize("k", [8, 16, 32, 64, 128])
+def test_small_batch_matmul(m, n, k):
+    if is_hip():
+        pytest.skip("Not fully tested on AMD")
+    if is_xpu():
+        pytest.xfail("Enable: https://github.com/intel/intel-xpu-backend-for-triton/issues/5092")
+
+    if m * n * k > 16384:
+        pytest.skip()
+
+    BATCH_SIZE = 10000
+
+    def _make_tensor(shape, dtype, trans):
+        if trans:
+            shape = (shape[0], shape[2], shape[1])
+        t = alloc_rand(shape, "cuda", dtype)
+        return t.transpose(1, 2) if trans else t
+
+    for x_transpose, w_transpose, bias, dtype in itertools.product(
+        (False, True),
+        (False, True),
+        (False, True),
+        (torch.float16, torch.bfloat16, torch.float8_e5m2),
+    ):
+        if (
+            torch.cuda.get_device_capability()[0] < 10
+            and dtype is torch.float8_e5m2
+            and (not w_transpose)
+        ):
+            continue  # Not supported
+
+        x = _make_tensor((BATCH_SIZE, m, k), dtype, x_transpose)
+        w = _make_tensor((BATCH_SIZE, k, n), dtype, w_transpose)
+        bias = _make_tensor((BATCH_SIZE, n), torch.float32, False) if bias else None
+        tri_y = matmul_ogs(x, w, bias)
+
+        # ref_y = matmul_ogs_torch(x.float(), w.float(), bias)
+
+        # This is faster than matmul_ogs_torch.
+        ref_y = torch.bmm(x.float(), w.float())
+        if bias is not None:
+            ref_y += bias[:, None, :]
+
+        assert_close(
+            ref_y,
+            tri_y,
+            maxtol=4e-1 if dtype is torch.float8_e5m2 else None,
+            rmstol=4e-2 if dtype is torch.float8_e5m2 else None,
+        )
+
+
 def test_set_idle_sms():
     if not is_cuda():
         pytest.skip("Only supported on CUDA")
     from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags
     num_idle_sms = 24
     matmul_ogs_set_idle_sms(num_idle_sms)
     flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
-                           1024, 1024, 1024, None, True, False, 1)
+                           1, 1024, 1024, 1024, None, True, False, 1)
     assert flags.idle_sms == num_idle_sms
 
 
python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 1 addition & 1 deletion
@@ -368,7 +368,7 @@ def matmul_ogs(x, w, bias,
     can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
     can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
-                               M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
+                               batch_size, M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
     )
     if not can_use_fused_scatter and opt_flags.fused_scatter:
         raise InapplicableConstraint("Fused scatter is not supported")
