Skip to content

Commit 78b2ef8

Browse files
committed
Add GPU passthrough for Docker sandboxes
1 parent 55cae5c commit 78b2ef8

File tree

5 files changed

+183
-0
lines changed

5 files changed

+183
-0
lines changed

rock-conf/rock-local.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,9 @@ ray:
77
warmup:
88
images:
99
- "python:3.11"
10+
11+
runtime:
12+
enable_gpu_passthrough: true
13+
gpu_allocation_mode: "round_robin"
14+
gpu_count_per_sandbox: 1
15+
gpu_device_request: "all"

rock/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ class RuntimeConfig:
168168
use_standard_spec_only: bool = False
169169
metrics_endpoint: str = ""
170170
user_defined_tags: dict = field(default_factory=dict)
171+
enable_gpu_passthrough: bool = False
172+
gpu_device_request: str = "all"
173+
gpu_allocation_mode: str = "fixed"
174+
gpu_count_per_sandbox: int = 1
171175

172176
def __post_init__(self) -> None:
173177
# Convert dict to StandardSpec if needed

rock/deployments/docker.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import datetime
3+
import fcntl
34
import os
45
import random
56
import shlex
@@ -64,6 +65,7 @@ def __init__(
6465
self._stop_time = datetime.datetime.now() + datetime.timedelta(minutes=self._config.auto_clear_time)
6566
self._check_stop_task = None
6667
self._container_name = None
68+
self._resolved_gpu_spec: str | None = None
6769
self._service_status = PersistedServiceStatus()
6870
if self._config.container_name:
6971
self.set_container_name(self._config.container_name)
@@ -169,6 +171,96 @@ def _build_runtime_args(self) -> list[str]:
169171
]
170172
return ["--privileged"]
171173

174+
def _detect_gpu_count(self) -> int:
175+
"""Detect the number of GPUs visible on the Docker host."""
176+
try:
177+
out = subprocess.check_output(
178+
["nvidia-smi", "--list-gpus"],
179+
text=True,
180+
stderr=subprocess.DEVNULL,
181+
)
182+
return len([line for line in out.splitlines() if line.strip()])
183+
except Exception:
184+
return 0
185+
186+
def _resolve_round_robin_gpu_spec(self, gpu_count_per_sandbox: int) -> str | None:
187+
"""Allocate device ids in round-robin across host GPUs."""
188+
total_gpus = self._detect_gpu_count()
189+
if total_gpus <= 0:
190+
logger.warning("GPU round-robin requested but no GPUs detected on host")
191+
return None
192+
193+
per_sandbox = max(1, min(int(gpu_count_per_sandbox), total_gpus))
194+
counter_path = os.getenv("ROCK_GPU_COUNTER_PATH", "/tmp/rock_gpu_rr_counter")
195+
os.makedirs(os.path.dirname(counter_path) or ".", exist_ok=True)
196+
197+
with open(counter_path, "a+", encoding="utf-8") as fp:
198+
fcntl.flock(fp.fileno(), fcntl.LOCK_EX)
199+
try:
200+
fp.seek(0)
201+
raw = fp.read().strip()
202+
counter = int(raw) if raw.isdigit() else 0
203+
start = counter % total_gpus
204+
next_counter = counter + per_sandbox
205+
fp.seek(0)
206+
fp.truncate()
207+
fp.write(str(next_counter))
208+
fp.flush()
209+
finally:
210+
fcntl.flock(fp.fileno(), fcntl.LOCK_UN)
211+
212+
device_ids = [(start + i) % total_gpus for i in range(per_sandbox)]
213+
return "device=" + ",".join(str(i) for i in device_ids)
214+
215+
def _build_gpu_args(self) -> list[str]:
216+
"""Build GPU-related docker args from runtime config and ROCK_* env vars."""
217+
self._resolved_gpu_spec = None
218+
if any(arg == "--gpus" or arg.startswith("--gpus=") for arg in self._config.docker_args):
219+
return []
220+
221+
runtime_enabled = bool(getattr(self._config.runtime_config, "enable_gpu_passthrough", False))
222+
env_enabled = os.getenv("ROCK_ENABLE_GPU_PASSTHROUGH", "").strip().lower() in {"1", "true", "yes", "on"}
223+
if not (runtime_enabled or env_enabled):
224+
return []
225+
226+
runtime_mode = str(getattr(self._config.runtime_config, "gpu_allocation_mode", "")).strip().lower()
227+
mode = runtime_mode or os.getenv("ROCK_GPU_ALLOCATION_MODE", "fixed").strip().lower() or "fixed"
228+
229+
gpu_spec: str | None
230+
if mode == "round_robin":
231+
runtime_count = int(getattr(self._config.runtime_config, "gpu_count_per_sandbox", 1) or 1)
232+
env_count_raw = os.getenv("ROCK_GPU_COUNT_PER_SANDBOX", "").strip()
233+
env_count = int(env_count_raw) if env_count_raw.isdigit() else None
234+
per_sandbox = env_count or runtime_count
235+
gpu_spec = self._resolve_round_robin_gpu_spec(per_sandbox)
236+
if not gpu_spec:
237+
return []
238+
logger.info(f"GPU pass-through round-robin enabled: --gpus {gpu_spec}")
239+
else:
240+
runtime_gpu_spec = str(getattr(self._config.runtime_config, "gpu_device_request", "")).strip()
241+
gpu_spec = runtime_gpu_spec or (os.getenv("ROCK_GPU_DEVICE_REQUEST", "all").strip() or "all")
242+
logger.info(f"GPU pass-through fixed mode enabled: --gpus {gpu_spec}")
243+
244+
self._resolved_gpu_spec = gpu_spec
245+
return ["--gpus", gpu_spec]
246+
247+
def _build_gpu_env_args(self) -> list[str]:
248+
"""Inject visibility env vars for deterministic GPU assignment."""
249+
if not self._resolved_gpu_spec:
250+
return []
251+
if self._resolved_gpu_spec == "all":
252+
return []
253+
if self._resolved_gpu_spec.startswith("device="):
254+
devices = self._resolved_gpu_spec.split("=", 1)[1]
255+
if devices:
256+
return [
257+
"-e",
258+
f"NVIDIA_VISIBLE_DEVICES={devices}",
259+
"-e",
260+
f"CUDA_VISIBLE_DEVICES={devices}",
261+
]
262+
return []
263+
172264
def _get_rocklet_start_cmd(self) -> list[str]:
173265
cmd = self._runtime_env.get_rocklet_start_cmd()
174266

@@ -342,15 +434,19 @@ async def start(self):
342434

343435
time.sleep(random.randint(0, 5))
344436
runtime_args = self._build_runtime_args()
437+
gpu_args = self._build_gpu_args()
438+
gpu_env_args = self._build_gpu_env_args()
345439
cmds = [
346440
"docker",
347441
"run",
348442
"--entrypoint",
349443
"",
350444
*env_arg,
445+
*gpu_env_args,
351446
*rm_arg,
352447
*volume_args,
353448
*runtime_args,
449+
*gpu_args,
354450
"-p",
355451
f"{self._config.port}:{Port.PROXY}",
356452
"-p",

tests/unit/rocklet/test_docker_deployment.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
Command,
88
CreateBashSessionRequest,
99
)
10+
from rock.config import RuntimeConfig
1011
from rock.deployments.config import DockerDeploymentConfig, get_deployment
12+
from rock.deployments.docker import DockerDeployment
13+
1114

1215
@pytest.mark.need_docker
1316
async def test_docker_deployment(container_name):
@@ -63,3 +66,58 @@ def test_docker_deployment_config_platform():
6366
config = DockerDeploymentConfig(platform="linux/amd64", docker_args=["--platform", "linux/amd64"])
6467
with pytest.raises(ValueError):
6568
config = DockerDeploymentConfig(platform="linux/amd64", docker_args=["--platform=linux/amd64"])
69+
70+
71+
def test_build_gpu_args_disabled_by_default(monkeypatch):
72+
deployment = DockerDeployment(runtime_config=RuntimeConfig())
73+
74+
monkeypatch.delenv("ROCK_ENABLE_GPU_PASSTHROUGH", raising=False)
75+
76+
assert deployment._build_gpu_args() == []
77+
assert deployment._build_gpu_env_args() == []
78+
79+
80+
def test_build_gpu_args_fixed_mode_from_runtime():
81+
deployment = DockerDeployment(
82+
runtime_config=RuntimeConfig(
83+
enable_gpu_passthrough=True,
84+
gpu_allocation_mode="fixed",
85+
gpu_device_request="device=2",
86+
)
87+
)
88+
89+
assert deployment._build_gpu_args() == ["--gpus", "device=2"]
90+
assert deployment._build_gpu_env_args() == [
91+
"-e",
92+
"NVIDIA_VISIBLE_DEVICES=2",
93+
"-e",
94+
"CUDA_VISIBLE_DEVICES=2",
95+
]
96+
97+
98+
def test_build_gpu_args_skips_when_docker_args_already_set():
99+
deployment = DockerDeployment(
100+
docker_args=["--gpus", "all"],
101+
runtime_config=RuntimeConfig(enable_gpu_passthrough=True),
102+
)
103+
104+
assert deployment._build_gpu_args() == []
105+
106+
107+
def test_build_gpu_args_round_robin(monkeypatch, tmp_path):
108+
deployment = DockerDeployment(
109+
runtime_config=RuntimeConfig(
110+
enable_gpu_passthrough=True,
111+
gpu_allocation_mode="round_robin",
112+
gpu_count_per_sandbox=2,
113+
)
114+
)
115+
116+
monkeypatch.setattr(deployment, "_detect_gpu_count", lambda: 4)
117+
monkeypatch.setenv("ROCK_GPU_COUNTER_PATH", str(tmp_path / "gpu_rr_counter"))
118+
119+
first = deployment._resolve_round_robin_gpu_spec(2)
120+
second = deployment._resolve_round_robin_gpu_spec(2)
121+
122+
assert first == "device=0,1"
123+
assert second == "device=2,3"

tests/unit/test_config.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,25 @@ async def test_runtime_config():
2222
assert runtime_config.max_allowed_spec.cpus == 16
2323
assert runtime_config.standard_spec.memory == "8g"
2424
assert runtime_config.standard_spec.cpus == 2
25+
assert runtime_config.enable_gpu_passthrough is False
26+
assert runtime_config.gpu_device_request == "all"
27+
assert runtime_config.gpu_allocation_mode == "fixed"
28+
assert runtime_config.gpu_count_per_sandbox == 1
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_runtime_config_gpu_fields():
33+
runtime_config = RuntimeConfig(
34+
enable_gpu_passthrough=True,
35+
gpu_device_request="device=1",
36+
gpu_allocation_mode="round_robin",
37+
gpu_count_per_sandbox=2,
38+
)
39+
40+
assert runtime_config.enable_gpu_passthrough is True
41+
assert runtime_config.gpu_device_request == "device=1"
42+
assert runtime_config.gpu_allocation_mode == "round_robin"
43+
assert runtime_config.gpu_count_per_sandbox == 2
2544

2645
config_full = {
2746
"standard_spec": {

0 commit comments

Comments
 (0)