Skip to content

Commit 96bda72

Browse files
committed
Merge commit '67dc6270aa8905b193d72756d5d54b3dccf9e168'
2 parents 4c4709d + 67dc627 commit 96bda72

File tree

14 files changed

+220
-482
lines changed

14 files changed

+220
-482
lines changed

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan
6868
fn()
6969
end_event.record()
7070
synchronize()
71+
# wait_on_sycl_queue?
7172
estimate_ms = start_event.elapsed_time(end_event) / 5
7273

7374
# The cache is also maintained in `triton_do_bench` function,

python/test/unit/language/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4519,7 +4519,7 @@ def test_value_specialization(value: int, value_type: str, device) -> None:
45194519

45204520
def repr(specialization):
45214521
ty = specialization.signature["value1"]
4522-
cst = '_'.join([k for k, v in specialization.constants.items() if v == 1])
4522+
cst = '_'.join([k for k, v in specialization.constants.items() if isinstance(k, str) and v == 1])
45234523
return f"kernel_{ty}_{cst}"
45244524

45254525
@triton.jit(repr=repr)

python/test/unit/language/test_tuple.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1):
8181
def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2):
8282
tl.static_assert(N1 is None)
8383
tl.static_assert(tuple1[1][1] is None)
84+
tl.static_assert(tuple1[1][3] == 4)
8485
tl.store(Ptr + 0, tl.load(tuple1[0]))
8586
tl.store(Ptr + 1, tuple1[1][0])
8687
tl.store(Ptr + 2, tl.load(tuple1[1][2]))
@@ -95,6 +96,6 @@ def test_serialize(device="xpu"):
9596
y0 = torch.tensor([10], dtype=torch.int32, device=device)
9697
z = torch.empty((10, ), dtype=torch.int32, device=device)
9798
# we want to check that JIT specialization propagates to tuples:
98-
_tuple_serialize[(1, )](z, None, (x0, (1, None, x1)), 20, 1, (y0, ))
99+
_tuple_serialize[(1, )](z, None, (x0, (1, None, x1, tl.constexpr(4))), 20, 1, (y0, ))
99100
ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device)
100101
assert torch.equal(z, ref)

python/test/unit/runtime/test_bindings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def walk_fn(op):
6868
constexprs={kernel.arg_names[i]: arg
6969
for i, arg in enumerate(args)
7070
if not isinstance(arg, torch.Tensor)},
71-
attrs=backend.get_attrs_descriptor(kernel.params, args),
7271
)
7372

7473
context = triton._C.libtriton.ir.context()

python/test/unit/runtime/test_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ def kernel_add(a):
563563

564564
def cache_hook(*args, **kwargs):
565565
nonlocal pointer_range_32
566-
pointer_range_32 = kwargs["compile"]["configs"][0].pointer_range_32
566+
pointer_range_32 = [k for k, v in kwargs["compile"]["configs"][0].items() if ['tt.pointer_range', 32] in v]
567567

568568
JITFunction.cache_hook = cache_hook
569569
# In warmup we assume that the pointer range is 32 bits

python/test/unit/runtime/test_subproc.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33

44
import triton
55
import triton.language as tl
6-
from triton.backends.compiler import AttrsDescriptor
76
from triton.compiler import ASTSource
87

98
target = triton.runtime.driver.active.get_current_target()
109
start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'
1110

1211

13-
def compile_fn(attrs):
12+
def compile_fn():
1413

1514
@triton.jit
1615
def kernel_sub(a, b, o, N: tl.constexpr):
@@ -21,21 +20,19 @@ def kernel_sub(a, b, o, N: tl.constexpr):
2120
fn=kernel_sub,
2221
constexprs={'N': 32},
2322
signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32", 'N': 'constexpr'},
24-
attrs=attrs,
2523
)
2624
triton.compile(src=src, target=target)
2725

2826

2927
def test_compile_in_subproc() -> None:
30-
config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
3128
mp_ctx = multiprocessing.get_context(start_method)
32-
proc = mp_ctx.Process(target=compile_fn, args=(config, ))
29+
proc = mp_ctx.Process(target=compile_fn)
3330
proc.start()
3431
proc.join()
3532
assert proc.exitcode == 0
3633

3734

38-
def compile_fn_dot(attrs):
35+
def compile_fn_dot():
3936

4037
@triton.jit
4138
def kernel_dot(Z):
@@ -44,28 +41,27 @@ def kernel_dot(Z):
4441
z = tl.dot(z, z)
4542
tl.store(Z + offs, z)
4643

47-
src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constexprs={})
44+
src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"})
4845
triton.compile(src=src, target=target)
4946

5047

5148
def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
52-
config = AttrsDescriptor.from_hints({0: 16})
5349
mp_ctx = multiprocessing.get_context(start_method)
54-
proc = mp_ctx.Process(target=compile_fn_dot, args=(config, ))
50+
proc = mp_ctx.Process(target=compile_fn_dot)
5551
proc.start()
5652
proc.join()
5753
assert proc.exitcode == 0
5854

5955

60-
def compile_empty_kernel_with_gc(attrs):
56+
def compile_empty_kernel_with_gc():
6157

6258
@triton.jit
6359
def empty_kernel():
6460
pass
6561

6662
import gc
6763
gc.collect()
68-
src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constexprs={})
64+
src = ASTSource(fn=empty_kernel, signature={})
6965
triton.compile(src=src, target=target)
7066

7167

@@ -88,13 +84,12 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
8884
gc.disable()
8985

9086
# stage 1.p
91-
config = AttrsDescriptor.from_hints({0: 16})
92-
compile_empty_kernel_with_gc(config)
87+
compile_empty_kernel_with_gc()
9388

9489
# stage 2.p
9590
shutil.rmtree(fresh_triton_cache, ignore_errors=True)
9691
mp_ctx = multiprocessing.get_context(start_method)
97-
proc = mp_ctx.Process(target=compile_empty_kernel_with_gc, args=(config, ))
92+
proc = mp_ctx.Process(target=compile_empty_kernel_with_gc)
9893

9994
# stage 3.c
10095
proc.start()

0 commit comments

Comments
 (0)