Skip to content

Commit 9f93976

Browse files
authored
Don't use fork method if it's not available on the platform (#5051)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent e82dfd9 commit 9f93976

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

python/test/unit/runtime/test_subproc.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from triton.compiler import ASTSource
88

99
target = triton.runtime.driver.active.get_current_target()
10+
start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'
1011

1112

1213
def compile_fn(attrs):
@@ -27,8 +28,8 @@ def kernel_sub(a, b, o, N: tl.constexpr):
2728

2829
def test_compile_in_subproc() -> None:
2930
config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
30-
multiprocessing.set_start_method('fork')
31-
proc = multiprocessing.Process(target=compile_fn, args=(config, ))
31+
mp_ctx = multiprocessing.get_context(start_method)
32+
proc = mp_ctx.Process(target=compile_fn, args=(config, ))
3233
proc.start()
3334
proc.join()
3435
assert proc.exitcode == 0
@@ -49,8 +50,8 @@ def kernel_dot(Z):
4950

5051
def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
5152
config = AttrsDescriptor.from_hints({0: 16})
52-
assert multiprocessing.get_start_method() == 'fork'
53-
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, ))
53+
mp_ctx = multiprocessing.get_context(start_method)
54+
proc = mp_ctx.Process(target=compile_fn_dot, args=(config, ))
5455
proc.start()
5556
proc.join()
5657
assert proc.exitcode == 0
@@ -92,8 +93,8 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
9293

9394
# stage 2.p
9495
shutil.rmtree(fresh_triton_cache)
95-
assert multiprocessing.get_start_method() == 'fork'
96-
proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, ))
96+
mp_ctx = multiprocessing.get_context(start_method)
97+
proc = mp_ctx.Process(target=compile_empty_kernel_with_gc, args=(config, ))
9798

9899
# stage 3.c
99100
proc.start()

0 commit comments

Comments
 (0)