Skip to content

Commit 56f2c48

Browse files
Chris Elion and Ervin T authored
[bug-fix] Set number of threads based on allocated CPU count in Docker containers (#4471) (#4478)
* Set num threads properly for Docker
* Pylint-friendly logic
* Use f.read().rstrip()
* Change function names

Co-authored-by: Ervin T <[email protected]>
1 parent f493e4a commit 56f2c48

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import Optional
2+
3+
import os
4+
5+
6+
def get_num_threads_to_use() -> Optional[int]:
    """
    Choose how many threads should be used.

    The result is half of the available CPUs, clamped to the range 1..4:
    four threads are enough for most problems, while smaller machines
    should scale down below that. (By default, PyTorch uses 1/2 of the
    available cores.)

    :return: The thread count, or None when the number of available CPUs
        could not be determined.
    """
    available = _get_num_available_cpus()
    if available is None:
        # Unknown CPU count: let the caller decide what to do.
        return None

    half = available // 2
    capped = min(half, 4)
    # Never go below one thread, even on single-core (or sub-core) allocations.
    return max(capped, 1)
14+
15+
16+
def _get_num_available_cpus() -> Optional[int]:
    """
    Work out how many CPUs are available, preferring cgroup limits.

    The cgroup CFS period and quota are read first so that Docker
    containers restricted to a subset of cores report their allocation.
    When either value is not strictly positive (for example the -1 that
    _read_in_integer_file returns for a missing file), the value of
    os.cpu_count() is returned instead.

    :return: quota // period when both are positive (this floors, so a
        quota smaller than the period yields 0); otherwise whatever
        os.cpu_count() reports, which may be None.
    """
    cfs_period = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.cfs_period_us")
    cfs_quota = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.cfs_quota_us")

    if cfs_period <= 0 or cfs_quota <= 0:
        # No usable cgroup limit: fall back to the CPU count seen by Python.
        return os.cpu_count()

    return int(cfs_quota // cfs_period)
27+
28+
29+
def _read_in_integer_file(filename: str) -> int:
    """
    Read a file whose whole content is a single integer.

    This is a best-effort probe: any failure to obtain an integer is
    reported through the -1 sentinel rather than an exception, so callers
    can treat "no usable value" uniformly.

    :param filename: Path of the file to read.
    :return: The integer stored in the file, or -1 if the file is missing,
        cannot be opened or read, or does not contain a valid integer.
    """
    try:
        with open(filename) as f:
            return int(f.read().rstrip())
    except (OSError, ValueError):
        # OSError covers a missing file (FileNotFoundError) as well as
        # permission and other I/O errors; ValueError covers empty or
        # non-numeric content. None of these should crash the caller.
        return -1

ml-agents/mlagents/torch_utils/torch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import os
22

3+
from mlagents.torch_utils import cpu_utils
4+
35
# Detect availability of torch package here.
46
# NOTE: this try/except is temporary until torch is required for ML-Agents.
57
try:
68
# This should be the only place that we import torch directly.
79
# Everywhere else is caught by the banned-modules setting for flake8
810
import torch # noqa I201
911

10-
torch.set_num_interop_threads(2)
12+
torch.set_num_threads(cpu_utils.get_num_threads_to_use())
1113
os.environ["KMP_BLOCKTIME"] = "0"
1214

1315
# Known PyLint compatibility with PyTorch https://github.com/pytorch/pytorch/issues/701

0 commit comments

Comments
 (0)