33
44import triton
55import triton .language as tl
6- from triton .backends .compiler import AttrsDescriptor
76from triton .compiler import ASTSource
87
98target = triton .runtime .driver .active .get_current_target ()
109start_method = 'fork' if 'fork' in multiprocessing .get_all_start_methods () else 'spawn'
1110
1211
13- def compile_fn (attrs ):
12+ def compile_fn ():
1413
1514 @triton .jit
1615 def kernel_sub (a , b , o , N : tl .constexpr ):
@@ -21,21 +20,19 @@ def kernel_sub(a, b, o, N: tl.constexpr):
2120 fn = kernel_sub ,
2221 constexprs = {'N' : 32 },
2322 signature = {'a' : "*fp32" , 'b' : "*fp32" , 'o' : "*fp32" , 'N' : 'constexpr' },
24- attrs = attrs ,
2523 )
2624 triton .compile (src = src , target = target )
2725
2826
2927def test_compile_in_subproc () -> None :
30- config = AttrsDescriptor .from_hints ({i : 16 for i in range (4 )})
3128 mp_ctx = multiprocessing .get_context (start_method )
32- proc = mp_ctx .Process (target = compile_fn , args = ( config , ) )
29+ proc = mp_ctx .Process (target = compile_fn )
3330 proc .start ()
3431 proc .join ()
3532 assert proc .exitcode == 0
3633
3734
38- def compile_fn_dot (attrs ):
35+ def compile_fn_dot ():
3936
4037 @triton .jit
4138 def kernel_dot (Z ):
@@ -44,28 +41,27 @@ def kernel_dot(Z):
4441 z = tl .dot (z , z )
4542 tl .store (Z + offs , z )
4643
47- src = ASTSource (fn = kernel_dot , signature = {'Z' : "*fp32" }, attrs = attrs , constexprs = {} )
44+ src = ASTSource (fn = kernel_dot , signature = {'Z' : "*fp32" })
4845 triton .compile (src = src , target = target )
4946
5047
5148def test_compile_in_forked_subproc (fresh_triton_cache ) -> None :
52- config = AttrsDescriptor .from_hints ({0 : 16 })
5349 mp_ctx = multiprocessing .get_context (start_method )
54- proc = mp_ctx .Process (target = compile_fn_dot , args = ( config , ) )
50+ proc = mp_ctx .Process (target = compile_fn_dot )
5551 proc .start ()
5652 proc .join ()
5753 assert proc .exitcode == 0
5854
5955
60- def compile_empty_kernel_with_gc (attrs ):
56+ def compile_empty_kernel_with_gc ():
6157
6258 @triton .jit
6359 def empty_kernel ():
6460 pass
6561
6662 import gc
6763 gc .collect ()
68- src = ASTSource (fn = empty_kernel , signature = {}, attrs = attrs , constexprs = {} )
64+ src = ASTSource (fn = empty_kernel , signature = {})
6965 triton .compile (src = src , target = target )
7066
7167
@@ -88,13 +84,12 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
8884 gc .disable ()
8985
9086 # stage 1.p
91- config = AttrsDescriptor .from_hints ({0 : 16 })
92- compile_empty_kernel_with_gc (config )
87+ compile_empty_kernel_with_gc ()
9388
9489 # stage 2.p
9590 shutil .rmtree (fresh_triton_cache , ignore_errors = True )
9691 mp_ctx = multiprocessing .get_context (start_method )
97- proc = mp_ctx .Process (target = compile_empty_kernel_with_gc , args = ( config , ) )
92+ proc = mp_ctx .Process (target = compile_empty_kernel_with_gc )
9893
9994 # stage 3.c
10095 proc .start ()
0 commit comments