3
3
import os
4
4
from unittest .mock import patch
5
5
6
- from ols .utils .environments import configure_gradio_ui_envs
6
+ from ols .app .models .config import OLSConfig , ReferenceContent
7
+ from ols .utils .environments import configure_gradio_ui_envs , configure_hugging_face_envs
7
8
8
9
9
10
@patch .dict (os .environ , {"GRADIO_ANALYTICS_ENABLED" : "" , "MPLCONFIGDIR" : "" })
@@ -19,3 +20,40 @@ def test_configure_gradio_ui_envs():
19
20
# expected environment variables
20
21
assert os .environ .get ("GRADIO_ANALYTICS_ENABLED" , None ) == "false"
21
22
assert os .environ .get ("MPLCONFIGDIR" , None ) != ""
23
+
24
+
25
+ @patch .dict (os .environ , {"TRANSFORMERS_CACHE" : "" , "TRANSFORMERS_OFFLINE" : "" })
26
+ def test_configure_hugging_face_env_no_reference_content_set ():
27
+ """Test the function configure_hugging_face_envs."""
28
+ # setup before tested function is called
29
+ assert os .environ .get ("TRANSFORMERS_CACHE" , None ) == ""
30
+ assert os .environ .get ("TRANSFORMERS_OFFLINE" , None ) == ""
31
+
32
+ ols_config = OLSConfig ()
33
+ ols_config .reference_content = None
34
+
35
+ # call the tested function
36
+ configure_hugging_face_envs (ols_config )
37
+
38
+ # expected environment variables
39
+ assert os .environ .get ("TRANSFORMERS_CACHE" , None ) == ""
40
+ assert os .environ .get ("TRANSFORMERS_OFFLINE" , None ) == ""
41
+
42
+
43
+ @patch .dict (os .environ , {"TRANSFORMERS_CACHE" : "" , "TRANSFORMERS_OFFLINE" : "" })
44
+ def test_configure_hugging_face_env_reference_content_set ():
45
+ """Test the function configure_hugging_face_envs."""
46
+ # setup before tested function is called
47
+ assert os .environ .get ("TRANSFORMERS_CACHE" , None ) == ""
48
+ assert os .environ .get ("TRANSFORMERS_OFFLINE" , None ) == ""
49
+
50
+ ols_config = OLSConfig ()
51
+ ols_config .reference_content = ReferenceContent ()
52
+ ols_config .reference_content .embeddings_model_path = "foo"
53
+
54
+ # call the tested function
55
+ configure_hugging_face_envs (ols_config )
56
+
57
+ # expected environment variables
58
+ assert os .environ .get ("TRANSFORMERS_CACHE" , None ) == "foo"
59
+ assert os .environ .get ("TRANSFORMERS_OFFLINE" , None ) == "1"
0 commit comments