Commit fa3b178

Remove runtime=nvidia and enable interactive mode (docker run -it ...)

1 parent 83e2c95 commit fa3b178

File tree (4 files changed, +33 −31 lines):

- tests/pytorch/inference/test_huggingface_inference_toolkit.py
- tests/pytorch/training/test_trl.py
- tests/tei/test_tei.py
- tests/tgi/test_tgi.py
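All four tests apply the same two changes: the legacy `runtime="nvidia"` option is dropped in favour of `device_requests` built from `docker.types.DeviceRequest`, and `tty=True` plus `stdin_open=True` make the container behave like `docker run -it ...`. The following is a minimal, self-contained sketch of that pattern with the Docker SDK for Python; the image name and command are illustrative placeholders, not values taken from these tests.

```python
import docker
from docker.types import DeviceRequest

client = docker.from_env()

container = client.containers.run(
    "nvidia/cuda:12.3.2-base-ubuntu22.04",  # placeholder image, not from the repo
    command="nvidia-smi",
    detach=True,
    # Equivalent of `docker run -it ...`
    tty=True,
    stdin_open=True,
    # Request all visible GPUs; replaces the legacy runtime="nvidia" flag
    device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
)

container.wait()                  # block until the container exits
print(container.logs().decode())  # e.g. the nvidia-smi output
container.remove()
```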

tests/pytorch/inference/test_huggingface_inference_toolkit.py

Lines changed: 7 additions & 9 deletions
@@ -58,13 +58,6 @@ def test_transformers(
 
     client = docker.from_env()
 
-    cuda_kwargs = {}
-    if CUDA_AVAILABLE:
-        cuda_kwargs = {
-            "runtime": "nvidia",
-            "device_requests": [DeviceRequest(count=-1, capabilities=[["gpu"]])],
-        }
-
     logging.info(f"Starting container for {hf_model_id}...")
     container = client.containers.run(
         os.getenv(
@@ -91,8 +84,13 @@ def test_transformers(
         },
         platform="linux/amd64",
         detach=True,
-        # Extra kwargs related to the CUDA devices
-        **cuda_kwargs,
+        # Enable interactive mode
+        tty=True,
+        stdin_open=True,
+        # Extra `device_requests` related to the CUDA devices if any
+        device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])]
+        if CUDA_AVAILABLE
+        else None,
     )
 
     # Start log streaming in a separate thread
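The last context line above refers to log streaming, which is outside this diff. A typical docker-py approach (an assumed sketch, not necessarily this test file's exact implementation) streams the container logs from a background thread:

```python
import logging
import threading


def stream_logs(container) -> None:
    # Follow the container's stdout/stderr until the container stops
    for line in container.logs(stream=True, follow=True):
        logging.info(line.decode("utf-8", errors="replace").rstrip())


# Start log streaming in a separate thread so the test can keep interacting
# with the container (e.g. sending requests) while its logs are captured.
log_thread = threading.Thread(target=stream_logs, args=(container,), daemon=True)
log_thread.start()
```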

tests/pytorch/training/test_trl.py

Lines changed: 8 additions & 4 deletions
@@ -57,15 +57,17 @@ def test_trl(caplog: pytest.LogCaptureFixture, tmp_path: PosixPath) -> None:
         },
         platform="linux/amd64",
         detach=True,
+        # Enable interactive mode
+        tty=True,
+        stdin_open=True,
         # Mount the volume from the `tmp_path` to the `/opt/huggingface/trained_model`
         volumes={
             f"{tmp_path}/": {
                 "bind": "/opt/huggingface/trained_model",
                 "mode": "rw",
             }
         },
-        # Extra kwargs related to the CUDA devices
-        runtime="nvidia",
+        # Extra `device_requests` related to the CUDA devices
         device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
     )
 
@@ -131,15 +133,17 @@ def test_trl_peft(caplog: pytest.LogCaptureFixture, tmp_path: PosixPath) -> None
         },
         platform="linux/amd64",
         detach=True,
+        # Enable interactive mode
+        tty=True,
+        stdin_open=True,
         # Mount the volume from the `tmp_path` to the `/opt/huggingface/trained_model`
         volumes={
             f"{tmp_path}/": {
                 "bind": "/opt/huggingface/trained_model",
                 "mode": "rw",
             }
         },
-        # Extra kwargs related to the CUDA devices
-        runtime="nvidia",
+        # Extra `device_requests` related to the CUDA devices
         device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
     )
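Both TRL tests bind-mount the pytest `tmp_path` directory to `/opt/huggingface/trained_model` inside the container, so whatever the training job writes there survives on the host. A hedged sketch of how such a test can wait for completion and inspect the mount afterwards (the exact assertions are assumptions, not necessarily what these tests do):

```python
from pathlib import Path

# Block until the training container exits and fail fast on a non-zero status code
result = container.wait()
assert result["StatusCode"] == 0, f"training container failed: {result}"

# Files written to /opt/huggingface/trained_model inside the container
# appear under the bind-mounted tmp_path on the host
trained_files = list(Path(tmp_path).iterdir())
assert trained_files, "expected the trained model to be written to the mounted volume"
```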

tests/tei/test_tei.py

Lines changed: 7 additions & 9 deletions
@@ -35,13 +35,6 @@ def test_text_embeddings_inference(
 
     client = docker.from_env()
 
-    cuda_kwargs = {}
-    if CUDA_AVAILABLE:
-        cuda_kwargs = {
-            "runtime": "nvidia",
-            "device_requests": [DeviceRequest(count=-1, capabilities=[["gpu"]])],
-        }
-
     logging.info(
         f"Starting container for {text_embeddings_router_kwargs.get('MODEL_ID', None)}..."
     )
@@ -66,8 +59,13 @@ def test_text_embeddings_inference(
         },
         platform="linux/amd64",
         detach=True,
-        # Extra kwargs related to the CUDA devices
-        **cuda_kwargs,
+        # Enable interactive mode
+        tty=True,
+        stdin_open=True,
+        # Extra `device_requests` related to the CUDA devices if any
+        device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])]
+        if CUDA_AVAILABLE
+        else None,
     )
     logging.info(f"Container {container.id} started...")  # type: ignore
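The conditional `device_requests=... if CUDA_AVAILABLE else None` relies on a `CUDA_AVAILABLE` flag defined elsewhere in the test suite (passing `None` simply means no GPU request is made). How that flag is computed is not shown in this diff; one common way to derive it without importing a deep-learning framework, offered purely as an assumption, is:

```python
import shutil
import subprocess


def _cuda_available() -> bool:
    # Treat CUDA as available when nvidia-smi exists and reports at least one GPU
    if shutil.which("nvidia-smi") is None:
        return False
    try:
        result = subprocess.run(
            ["nvidia-smi", "--list-gpus"], capture_output=True, text=True, timeout=10
        )
    except (OSError, subprocess.SubprocessError):
        return False
    return result.returncode == 0 and bool(result.stdout.strip())


CUDA_AVAILABLE = _cuda_available()
```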

tests/tgi/test_tgi.py

Lines changed: 11 additions & 9 deletions
@@ -53,17 +53,19 @@ def test_text_generation_inference(
         ),
         ports={8080: 8080},
         environment=text_generation_launcher_kwargs,
-        # healthcheck={
-        #     "test": ["CMD", "curl", "-s", "http://localhost:8080/health"],
-        #     "interval": int(30 * 1e9),
-        #     "timeout": int(30 * 1e9),
-        #     "retries": 3,
-        #     "start_period": int(30 * 1e9),
-        # },
-        # platform="linux/amd64",
+        healthcheck={
+            "test": ["CMD", "curl", "-s", "http://localhost:8080/health"],
+            "interval": int(30 * 1e9),
+            "timeout": int(30 * 1e9),
+            "retries": 3,
+            "start_period": int(30 * 1e9),
+        },
+        platform="linux/amd64",
         detach=True,
+        # Enable interactive mode
+        tty=True,
+        stdin_open=True,
         # Extra kwargs related to the CUDA devices
-        runtime="nvidia",
         device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
     )
     logging.info(f"Container {container.id} started...")  # type: ignore
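The re-enabled `healthcheck` uses Docker's nanosecond-based durations, hence `int(30 * 1e9)` for a 30-second interval, timeout, and start period. With it in place, a test can poll the reported health state instead of sleeping for a fixed time; a minimal polling helper under that assumption (not necessarily this test's actual logic):

```python
import time


def wait_until_healthy(container, timeout: float = 300.0, interval: float = 5.0) -> None:
    """Poll the Docker healthcheck status until the container reports 'healthy'."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        container.reload()  # refresh container.attrs from the Docker daemon
        status = container.attrs.get("State", {}).get("Health", {}).get("Status")
        if status == "healthy":
            return
        time.sleep(interval)
    raise TimeoutError(f"Container {container.id} did not become healthy within {timeout}s")
```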
