Skip to content

Commit 2e77b72

Browse files
authored
Merge pull request #390 from ExaWorks/proper_singleton_threads
Proper singleton threads
2 parents b2e42cb + fe17112 commit 2e77b72

File tree

9 files changed

+182
-14
lines changed

9 files changed

+182
-14
lines changed

requirements-tests.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55

66
pytest >= 6.2.0
77
requests >= 2.25.1
8-
pytest-cov
8+
pytest-cov
9+
pytest-timeout

src/psij/executors/local.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
import time
88
from abc import ABC, abstractmethod
99
from types import FrameType
10-
from typing import Optional, Dict, List, Type, Tuple
10+
from typing import Optional, Dict, List, Tuple, Type, cast
1111

1212
import psutil
1313

1414
from psij import InvalidJobException, SubmitException, Launcher
1515
from psij import Job, JobSpec, JobExecutorConfig, JobState, JobStatus
1616
from psij import JobExecutor
17+
from psij.utils import SingletonThread
1718

1819
logger = logging.getLogger(__name__)
1920

@@ -22,7 +23,14 @@ def _handle_sigchld(signum: int, frame: Optional[FrameType]) -> None:
2223
_ProcessReaper.get_instance()._handle_sigchld()
2324

2425

25-
_REAPER_SLEEP_TIME = 0.2
26+
if threading.current_thread() != threading.main_thread():
27+
logger.warning('The psij module is being imported from a non-main thread. This prevents the'
28+
'use of signals in the local executor, which will slow things down a bit.')
29+
else:
30+
signal.signal(signal.SIGCHLD, _handle_sigchld)
31+
32+
33+
_REAPER_SLEEP_TIME = 0.1
2634

2735

2836
class _ProcessEntry(ABC):
@@ -110,18 +118,11 @@ def _get_env(spec: JobSpec) -> Optional[Dict[str, str]]:
110118
return spec.environment
111119

112120

113-
class _ProcessReaper(threading.Thread):
114-
_instance: Optional['_ProcessReaper'] = None
115-
_lock = threading.RLock()
121+
class _ProcessReaper(SingletonThread):
116122

117123
@classmethod
118124
def get_instance(cls: Type['_ProcessReaper']) -> '_ProcessReaper':
119-
with cls._lock:
120-
if cls._instance is None:
121-
cls._instance = _ProcessReaper()
122-
cls._instance.start()
123-
signal.signal(signal.SIGCHLD, _handle_sigchld)
124-
return cls._instance
125+
return cast('_ProcessReaper', super().get_instance())
125126

126127
def __init__(self) -> None:
127128
super().__init__(name='Local Executor Process Reaper', daemon=True)
@@ -198,8 +199,15 @@ class LocalJobExecutor(JobExecutor):
198199
This job executor is intended to be used when there is no resource manager, only
199200
the operating system. Or when there is a resource manager, but it should be ignored.
200201
201-
Limitations: in Linux, attached jobs always appear to complete with a zero exit code regardless
202+
Limitations:
203+
- In Linux, attached jobs always appear to complete with a zero exit code regardless
202204
of the actual exit code.
205+
- Instantiation of a local executor from both parent process and a `fork()`-ed process
206+
is not guaranteed to work. In general, using `fork()` and multi-threading in Linux is unsafe,
207+
as suggested by the `fork()` man page. While PSI/J attempts to minimize problems that can
208+
arise when `fork()` is combined with threads (which are used by PSI/J), no guarantees can be
209+
made and the chances of unexpected behavior are high. Please do not use PSI/J with `fork()`.
210+
If you do, please be mindful that support for using PSI/J with `fork()` will be limited.
203211
"""
204212

205213
def __init__(self, url: Optional[str] = None,

src/psij/utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import os
2+
import threading
13
from pathlib import Path
2-
from typing import Optional
4+
from typing import Optional, Type, Dict
35
import sys
46

57

@@ -19,3 +21,32 @@ def path_object_to_full_path(obj: Optional[object]) -> Optional[str]:
1921
sys.exit("This type " + type(obj).__name__
2022
+ " for a path is not supported, use pathlib instead")
2123
return p
24+
25+
26+
class SingletonThread(threading.Thread):
27+
"""
28+
A convenience class to return a thread that is guaranteed to be unique to this process.
29+
30+
This is intended to work with fork() to ensure that each os.getpid() value is associated with
31+
at most one thread. This is not safe. The safe thing, as pointed out by the fork() man page,
32+
is to not use fork() with threads. However, this is here in an attempt to make it slightly
33+
safer for when users really really want to take the risk against all advice.
34+
"""
35+
36+
_instances: Dict[int, 'SingletonThread'] = {}
37+
_lock = threading.RLock()
38+
39+
@classmethod
40+
def get_instance(cls: Type['SingletonThread']) -> 'SingletonThread':
41+
"""Returns a started instance of this thread.
42+
43+
The instance is guaranteed to be unique for this process. This method also guarantees
44+
that a forked process will get a separate instance of this thread from the parent.
45+
"""
46+
with cls._lock:
47+
my_pid = os.getpid()
48+
if my_pid not in cls._instances:
49+
instance = cls()
50+
cls._instances[my_pid] = instance
51+
instance.start()
52+
return cls._instances[my_pid]

tests/test_issue_387_1.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import os
2+
import subprocess
3+
import sys
4+
5+
import pytest
6+
7+
8+
@pytest.mark.timeout(5)
9+
def test_issue_387_1() -> None:
10+
subprocess.run([sys.executable, os.path.abspath(__file__)[:-2] + '.run'],
11+
shell=True, check=True, capture_output=True)

tests/test_issue_387_1.run

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# This test must be run in a separate process, since it is sensitive
2+
# to threading and process context.
3+
4+
import logging
5+
import os
6+
from multiprocessing import Process, set_start_method
7+
from threading import Thread
8+
9+
import psij
10+
11+
12+
def func():
13+
# Get logs from PSI/J
14+
logger = logging.getLogger()
15+
logger.setLevel("DEBUG")
16+
logger.addHandler(logging.StreamHandler())
17+
18+
exe = psij.JobExecutor.get_instance("local")
19+
job = psij.Job(psij.JobSpec("test", "echo", arguments=["foo"]))
20+
exe.submit(job)
21+
print(job, flush=True)
22+
23+
# This hangs on the second round
24+
job.wait()
25+
print(job, flush=True)
26+
27+
28+
def fn():
29+
print(os.getpid())
30+
31+
32+
if __name__ == "__main__":
33+
Thread(target=fn).start()
34+
set_start_method("fork")
35+
p = Process(target=func)
36+
p.start()
37+
p.join()
38+
39+
print("Done, and again...")
40+
41+
p = Process(target=func)
42+
p.start()
43+
p.join()

tests/test_issue_387_2.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import os
2+
import subprocess
3+
import sys
4+
5+
import pytest
6+
7+
8+
@pytest.mark.timeout(5)
9+
def test_issue_387_2() -> None:
10+
subprocess.run([sys.executable, os.path.abspath(__file__)[:-2] + '.run'],
11+
shell=True, check=True, capture_output=True)

tests/test_issue_387_2.run

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# This test must be run in a separate process, since it is sensitive
2+
# to threading and process context.
3+
4+
import logging
5+
from threading import Thread
6+
import psij
7+
8+
9+
def func():
10+
# Get logs from PSI/J
11+
logger = logging.getLogger()
12+
logger.setLevel("DEBUG")
13+
logger.addHandler(logging.StreamHandler())
14+
15+
exe = psij.JobExecutor.get_instance("local")
16+
job = psij.Job(psij.JobSpec("test", "echo", arguments=["foo"]))
17+
exe.submit(job)
18+
19+
# This hangs
20+
job.wait()
21+
22+
if __name__ == "__main__":
23+
p = Thread(target=func)
24+
p.start()
25+
p.join()
26+

tests/test_issue_387_3.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import os
2+
import subprocess
3+
import sys
4+
5+
import pytest
6+
7+
8+
@pytest.mark.timeout(5)
9+
def test_issue_387_3() -> None:
10+
subprocess.run([sys.executable, os.path.abspath(__file__)[:-2] + '.run'],
11+
shell=True, check=True, capture_output=True)

tests/test_issue_387_3.run

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# This test must be run in a separate process, since it is sensitive
2+
# to threading and process context.
3+
4+
import logging
5+
from threading import Thread
6+
import psij
7+
8+
9+
def func():
10+
# Get logs from PSI/J
11+
logger = logging.getLogger()
12+
logger.setLevel("DEBUG")
13+
logger.addHandler(logging.StreamHandler())
14+
15+
exe = psij.JobExecutor.get_instance("local")
16+
job = psij.Job(psij.JobSpec("test", "echo", arguments=["foo"]))
17+
exe.submit(job)
18+
19+
# This hangs
20+
job.wait()
21+
22+
if __name__ == "__main__":
23+
p = Thread(target=func)
24+
p.start()
25+
p.join()
26+

0 commit comments

Comments
 (0)