Commit 3af2bcf

Apply suggestions from code review
- Capture `container_uri` from the environment variable before running the test, and remove the default value to prevent issues when testing
- Remove `--num_train_epochs=-1`, as it is not required since `max_steps` is already specified
- Rename `test_transformers` to `test_huggingface_inference_toolkit`
1 parent 7c4bf87 commit 3af2bcf
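
All four test files apply the same fail-fast pattern. A minimal sketch of that pattern, using a hypothetical `EXAMPLE_DLC` variable in place of the real `INFERENCE_DLC`/`TRAINING_DLC`/`TEI_DLC`/`TGI_DLC` names:

# A minimal sketch (hypothetical variable name) of the guard introduced here:
# read the image URI from the environment up front and fail fast when unset,
# instead of silently falling back to a hard-coded default image.
import os

container_uri = os.getenv("EXAMPLE_DLC", None)
if container_uri is None or container_uri == "":
    assert False, "EXAMPLE_DLC environment variable is not set"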

4 files changed: 26 additions, 27 deletions

tests/pytorch/inference/test_huggingface_inference_toolkit.py

Lines changed: 6 additions & 7 deletions
@@ -48,24 +48,23 @@
         ),
     ],
 )
-def test_transformers(
+def test_huggingface_inference_toolkit(
     caplog: pytest.LogCaptureFixture,
     hf_model_id: str,
     hf_task: str,
     prediction_payload: dict,
 ) -> None:
     caplog.set_level(logging.INFO)
 
+    container_uri = os.getenv("INFERENCE_DLC", None)
+    if container_uri is None or container_uri == "":
+        assert False, "INFERENCE_DLC environment variable is not set"
+
     client = docker.from_env()
 
     logging.info(f"Starting container for {hf_model_id}...")
     container = client.containers.run(
-        os.getenv(
-            "INFERENCE_DLC",
-            "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-inference-cpu.2-2.transformers.4-44.ubuntu2204.py311"
-            if not CUDA_AVAILABLE
-            else "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-inference-cu121.2-2.transformers.4-44.ubuntu2204.py311",
-        ),
+        container_uri,
         ports={"8080": 8080},
         environment={
             "HF_MODEL_ID": hf_model_id,

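For reference, a self-contained sketch of how the docker-py call reads once the default image is gone; the `detach=True` flag and the model id below are assumptions for illustration, not taken from the test:

import os

import docker

container_uri = os.environ["INFERENCE_DLC"]  # raises KeyError if unset
client = docker.from_env()
container = client.containers.run(
    container_uri,
    ports={"8080": 8080},
    environment={"HF_MODEL_ID": "distilbert-base-uncased"},  # hypothetical model
    detach=True,  # assumption: run in the background so the test can poll it
)
container.stop()
container.remove()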
tests/pytorch/training/test_trl.py

Lines changed: 10 additions & 10 deletions
@@ -19,14 +19,15 @@ def test_trl(caplog: pytest.LogCaptureFixture, tmp_path: PosixPath) -> None:
     """Adapted from https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py"""
     caplog.set_level(logging.INFO)
 
+    container_uri = os.getenv("TRAINING_DLC", None)
+    if container_uri is None or container_uri == "":
+        assert False, "TRAINING_DLC environment variable is not set"
+
     client = docker.from_env()
 
     logging.info("Running the container for TRL...")
     container = client.containers.run(
-        os.getenv(
-            "TRAINING_DLC",
-            "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-training-cu121.2-3.transformers.4-42.ubuntu2204.py310",
-        ),
+        container_uri,
         command=[
             "trl",
             "sft",
@@ -38,7 +39,6 @@ def test_trl(caplog: pytest.LogCaptureFixture, tmp_path: PosixPath) -> None:
             "--gradient_accumulation_steps=1",
             "--output_dir=/opt/huggingface/trained_model",
             "--logging_steps=1",
-            "--num_train_epochs=-1",
             "--max_steps=10",
             "--gradient_checkpointing",
         ],
@@ -81,14 +81,15 @@ def test_trl_peft(caplog: pytest.LogCaptureFixture, tmp_path: PosixPath) -> None
     """Adapted from https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py"""
     caplog.set_level(logging.INFO)
 
+    container_uri = os.getenv("TRAINING_DLC", None)
+    if container_uri is None or container_uri == "":
+        assert False, "TRAINING_DLC environment variable is not set"
+
     client = docker.from_env()
 
     logging.info("Running the container for TRL...")
     container = client.containers.run(
-        os.getenv(
-            "TRAINING_DLC",
-            "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-training-cu121.2-3.transformers.4-42.ubuntu2204.py310",
-        ),
+        container_uri,
         command=[
             "trl",
             "sft",
@@ -100,7 +101,6 @@ def test_trl_peft(caplog: pytest.LogCaptureFixture, tmp_path: PosixPath) -> None
             "--gradient_accumulation_steps=1",
             "--output_dir=/opt/huggingface/trained_model",
             "--logging_steps=1",
-            "--num_train_epochs=-1",
             "--max_steps=10",
             "--gradient_checkpointing",
             "--use_peft",

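Dropping `--num_train_epochs=-1` is safe because, in transformers' `TrainingArguments` (which the `trl sft` CLI builds on), a positive `max_steps` takes precedence over `num_train_epochs`. A minimal sketch of that behavior:

# Assumes only the documented TrainingArguments precedence: a positive
# max_steps overrides num_train_epochs, so forcing epochs to -1 is redundant.
from transformers import TrainingArguments

args = TrainingArguments(output_dir="/tmp/example", max_steps=10)
print(args.max_steps)         # 10 -> training stops after 10 optimizer steps
print(args.num_train_epochs)  # left at its default and effectively ignored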
tests/tei/test_tei.py

Lines changed: 5 additions & 6 deletions
@@ -33,18 +33,17 @@ def test_text_embeddings_inference(
 ) -> None:
     caplog.set_level(logging.INFO)
 
+    container_uri = os.getenv("TEI_DLC", None)
+    if container_uri is None or container_uri == "":
+        assert False, "TEI_DLC environment variable is not set"
+
     client = docker.from_env()
 
     logging.info(
         f"Starting container for {text_embeddings_router_kwargs.get('MODEL_ID', None)}..."
     )
     container = client.containers.run(
-        os.getenv(
-            "TEI_DLC",
-            "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-text-embeddings-inference-cpu.1-2"
-            if not CUDA_AVAILABLE
-            else "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-text-embeddings-inference-cu122.1-4.ubuntu2204",
-        ),
+        container_uri,
         # TODO: update once the TEI DLC is updated, as the current one is still on revision:
         # https://github.com/huggingface/Google-Cloud-Containers/blob/517b8728725f6249774dcd46ee8d7ede8d95bb70/containers/tei/cpu/1.2.2/Dockerfile
         # and it exposes port 80 and uses the /data directory instead of /tmp

tests/tgi/test_tgi.py

Lines changed: 5 additions & 4 deletions
@@ -42,6 +42,10 @@ def test_text_generation_inference(
 ) -> None:
     caplog.set_level(logging.INFO)
 
+    container_uri = os.getenv("TGI_DLC", None)
+    if container_uri is None or container_uri == "":
+        assert False, "TGI_DLC environment variable is not set"
+
     client = docker.from_env()
 
     # If the GPU compute capability is lower than 8.0 (Ampere), then set `USE_FLASH_ATTENTION=false`
@@ -56,10 +60,7 @@ def test_text_generation_inference(
         f"Starting container for {text_generation_launcher_kwargs.get('MODEL_ID', None)}..."
     )
     container = client.containers.run(
-        os.getenv(
-            "TGI_DLC",
-            "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-text-generation-inference-cu121.2-2.ubuntu2204.py310",
-        ),
+        container_uri,
         ports={8080: 8080},
         environment=text_generation_launcher_kwargs,
         healthcheck={

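The hunk ends at the `healthcheck` argument. For context, docker-py accepts a dict mirroring the Dockerfile `HEALTHCHECK` fields, with durations given in nanoseconds; a sketch with an assumed probe command (the test's actual healthcheck is not shown in this hunk):

ONE_SECOND_NS = 1_000_000_000  # docker-py expects durations in nanoseconds

healthcheck = {
    "test": ["CMD", "curl", "-s", "http://localhost:8080/health"],  # assumed probe
    "interval": 30 * ONE_SECOND_NS,
    "timeout": 30 * ONE_SECOND_NS,
    "retries": 3,
}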