Skip to content

Commit 6ceda93

Browse files
committed
OLS-1229: HuggingFace environment settings
Signed-off-by: Pavel Tisnovsky <[email protected]>
1 parent 2a1aa43 commit 6ceda93

File tree

3 files changed

+56
-17
lines changed

3 files changed

+56
-17
lines changed

ols/utils/environments.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import os
44
import tempfile
55

6+
import ols.app.models.config as config_model
7+
68

79
def configure_gradio_ui_envs() -> None:
810
"""Configure GradioUI framework environment variables."""
@@ -15,3 +17,17 @@ def configure_gradio_ui_envs() -> None:
1517
# Fixes: https://issues.redhat.com/browse/OLS-301
1618
tempdir = os.path.join(tempfile.gettempdir(), "matplotlib")
1719
os.environ["MPLCONFIGDIR"] = tempdir
20+
21+
22+
def configure_hugging_face_envs(ols_config: config_model.OLSConfig) -> None:
23+
"""Configure HuggingFace library environment variables."""
24+
if (
25+
ols_config
26+
and hasattr(ols_config, "reference_content")
27+
and hasattr(ols_config.reference_content, "embeddings_model_path")
28+
and ols_config.reference_content.embeddings_model_path
29+
):
30+
os.environ["TRANSFORMERS_CACHE"] = str(
31+
ols_config.reference_content.embeddings_model_path
32+
)
33+
os.environ["TRANSFORMERS_OFFLINE"] = "1"

runner.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,13 @@
55
import threading
66
from pathlib import Path
77

8-
import ols.app.models.config as config_model
98
from ols.runners.uvicorn import start_uvicorn
109
from ols.src.auth.auth import use_k8s_auth
1110
from ols.utils.certificates import generate_certificates_file
12-
from ols.utils.environments import configure_gradio_ui_envs
11+
from ols.utils.environments import configure_gradio_ui_envs, configure_hugging_face_envs
1312
from ols.utils.logging_configurator import configure_logging
1413

1514

16-
def configure_hugging_face_envs(ols_config: config_model.OLSConfig) -> None:
17-
"""Configure HuggingFace library environment variables."""
18-
if (
19-
ols_config
20-
and hasattr(ols_config, "reference_content")
21-
and hasattr(ols_config.reference_content, "embeddings_model_path")
22-
and ols_config.reference_content.embeddings_model_path
23-
):
24-
os.environ["TRANSFORMERS_CACHE"] = str(
25-
ols_config.reference_content.embeddings_model_path
26-
)
27-
os.environ["TRANSFORMERS_OFFLINE"] = "1"
28-
29-
3015
def load_index():
3116
"""Load the index."""
3217
# accessing the config's rag_index property will trigger the loading

tests/unit/utils/test_environments.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import os
44
from unittest.mock import patch
55

6-
from ols.utils.environments import configure_gradio_ui_envs
6+
from ols.app.models.config import OLSConfig, ReferenceContent
7+
from ols.utils.environments import configure_gradio_ui_envs, configure_hugging_face_envs
78

89

910
@patch.dict(os.environ, {"GRADIO_ANALYTICS_ENABLED": "", "MPLCONFIGDIR": ""})
@@ -19,3 +20,40 @@ def test_configure_gradio_ui_envs():
1920
# expected environment variables
2021
assert os.environ.get("GRADIO_ANALYTICS_ENABLED", None) == "false"
2122
assert os.environ.get("MPLCONFIGDIR", None) != ""
23+
24+
25+
@patch.dict(os.environ, {"TRANSFORMERS_CACHE": "", "TRANSFORMERS_OFFLINE": ""})
26+
def test_configure_hugging_face_env_no_reference_content_set():
27+
"""Test the function configure_hugging_face_envs."""
28+
# setup before tested function is called
29+
assert os.environ.get("TRANSFORMERS_CACHE", None) == ""
30+
assert os.environ.get("TRANSFORMERS_OFFLINE", None) == ""
31+
32+
ols_config = OLSConfig()
33+
ols_config.reference_content = None
34+
35+
# call the tested function
36+
configure_hugging_face_envs(ols_config)
37+
38+
# expected environment variables
39+
assert os.environ.get("TRANSFORMERS_CACHE", None) == ""
40+
assert os.environ.get("TRANSFORMERS_OFFLINE", None) == ""
41+
42+
43+
@patch.dict(os.environ, {"TRANSFORMERS_CACHE": "", "TRANSFORMERS_OFFLINE": ""})
44+
def test_configure_hugging_face_env_reference_content_set():
45+
"""Test the function configure_hugging_face_envs."""
46+
# setup before tested function is called
47+
assert os.environ.get("TRANSFORMERS_CACHE", None) == ""
48+
assert os.environ.get("TRANSFORMERS_OFFLINE", None) == ""
49+
50+
ols_config = OLSConfig()
51+
ols_config.reference_content = ReferenceContent()
52+
ols_config.reference_content.embeddings_model_path = "foo"
53+
54+
# call the tested function
55+
configure_hugging_face_envs(ols_config)
56+
57+
# expected environment variables
58+
assert os.environ.get("TRANSFORMERS_CACHE", None) == "foo"
59+
assert os.environ.get("TRANSFORMERS_OFFLINE", None) == "1"

0 commit comments

Comments
 (0)