77from triton .compiler import ASTSource
88
99target = triton .runtime .driver .active .get_current_target ()
10+ start_method = 'fork' if 'fork' in multiprocessing .get_all_start_methods () else 'spawn'
1011
1112
1213def compile_fn (attrs ):
@@ -27,8 +28,8 @@ def kernel_sub(a, b, o, N: tl.constexpr):
2728
2829def test_compile_in_subproc () -> None :
2930 config = AttrsDescriptor .from_hints ({i : 16 for i in range (4 )})
30- multiprocessing .set_start_method ( 'fork' )
31- proc = multiprocessing .Process (target = compile_fn , args = (config , ))
31+ mp_ctx = multiprocessing .get_context ( start_method )
32+ proc = mp_ctx .Process (target = compile_fn , args = (config , ))
3233 proc .start ()
3334 proc .join ()
3435 assert proc .exitcode == 0
@@ -49,8 +50,8 @@ def kernel_dot(Z):
4950
5051def test_compile_in_forked_subproc (fresh_triton_cache ) -> None :
5152 config = AttrsDescriptor .from_hints ({0 : 16 })
52- assert multiprocessing .get_start_method () == 'fork'
53- proc = multiprocessing .Process (target = compile_fn_dot , args = (config , ))
53+ mp_ctx = multiprocessing .get_context ( start_method )
54+ proc = mp_ctx .Process (target = compile_fn_dot , args = (config , ))
5455 proc .start ()
5556 proc .join ()
5657 assert proc .exitcode == 0
@@ -92,8 +93,8 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
9293
9394 # stage 2.p
9495 shutil .rmtree (fresh_triton_cache )
95- assert multiprocessing .get_start_method () == 'fork'
96- proc = multiprocessing .Process (target = compile_empty_kernel_with_gc , args = (config , ))
96+ mp_ctx = multiprocessing .get_context ( start_method )
97+ proc = mp_ctx .Process (target = compile_empty_kernel_with_gc , args = (config , ))
9798
9899 # stage 3.c
99100 proc .start ()
0 commit comments