Skip to content

Commit 1176945

Browse files
authored
Merge pull request #453 from The-Strategy-Unit/fix_config
2 parents b6d0773 + 7075f5c commit 1176945

File tree

6 files changed

+122
-59
lines changed

6 files changed

+122
-59
lines changed

src/nhp/docker/__main__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import threading
77

8-
from nhp.docker import config
8+
from nhp.docker.config import Config
99
from nhp.docker.run import RunWithAzureStorage, RunWithLocalStorage
1010
from nhp.model.run import run_all
1111

@@ -37,7 +37,7 @@ def parse_args():
3737
return parser.parse_args()
3838

3939

40-
def main():
40+
def main(config: Config = Config()):
4141
"""The main method."""
4242
args = parse_args()
4343

@@ -50,7 +50,7 @@ def main():
5050
if args.local_storage:
5151
runner = RunWithLocalStorage(args.params_file)
5252
else:
53-
runner = RunWithAzureStorage(args.params_file, config.APP_VERSION)
53+
runner = RunWithAzureStorage(args.params_file, config)
5454

5555
logging.info("running model for: %s", args.params_file)
5656
logging.info("container timeout: %ds", config.CONTAINER_TIMEOUT_SECONDS)
@@ -73,12 +73,13 @@ def init():
7373
"""Method for calling main."""
7474
if __name__ == "__main__":
7575
exc = None
76+
config = Config()
7677
try:
7778
# start a timer to kill the container if we reach a timeout
7879
t = threading.Timer(config.CONTAINER_TIMEOUT_SECONDS, _exit_container)
7980
t.start()
8081
# run the model
81-
main()
82+
main(config)
8283
except Exception as e:
8384
logging.error("An error occurred: %s", str(e))
8485
exc = e

src/nhp/docker/config.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,44 @@
11
"""config values for docker container."""
22

3+
import os
4+
35
import dotenv
46

5-
__config_values = dotenv.dotenv_values()
67

8+
class Config:
9+
"""Configuration class for Docker container."""
10+
11+
__DEFAULT_CONTAINER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
12+
13+
def __init__(self):
14+
"""Configuration settings for the Docker container."""
15+
dotenv.load_dotenv()
16+
17+
self._app_version = os.environ.get("APP_VERSION", "dev")
18+
self._data_version = os.environ.get("DATA_VERSION", "dev")
19+
self._storage_account = os.environ.get("STORAGE_ACCOUNT")
20+
21+
self._container_timeout_seconds = os.environ.get("CONTAINER_TIMEOUT_SECONDS")
22+
23+
@property
24+
def APP_VERSION(self) -> str:
25+
"""What is the version of the app?"""
26+
return self._app_version
727

8-
APP_VERSION: str = __config_values.get("APP_VERSION", "dev") # type: ignore
9-
DATA_VERSION: str = __config_values.get("DATA_VERSION", "dev") # type: ignore
28+
@property
29+
def DATA_VERSION(self) -> str:
30+
"""What version of the data are we using?"""
31+
return self._data_version
1032

11-
STORAGE_ACCOUNT: str | None = __config_values.get("STORAGE_ACCOUNT", None)
33+
@property
34+
def STORAGE_ACCOUNT(self) -> str:
35+
"""What is the name of the storage account?"""
36+
if self._storage_account is None:
37+
raise ValueError("STORAGE_ACCOUNT environment variable must be set")
38+
return self._storage_account
1239

13-
__DEFAULT_CONTAINER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
14-
CONTAINER_TIMEOUT_SECONDS = int(
15-
__config_values.get("CONTAINER_TIMEOUT_SECONDS", __DEFAULT_CONTAINER_TIMEOUT_SECONDS) # type: ignore
16-
)
40+
@property
41+
def CONTAINER_TIMEOUT_SECONDS(self) -> int:
42+
"""How long should the container run before timing out?"""
43+
t = self._container_timeout_seconds
44+
return self.__DEFAULT_CONTAINER_TIMEOUT_SECONDS if t is None else int(t)

src/nhp/docker/run.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from azure.storage.blob import BlobServiceClient
1313
from azure.storage.filedatalake import DataLakeServiceClient
1414

15-
from nhp.docker import config
15+
from nhp.docker.config import Config
1616
from nhp.model.helpers import load_params
1717
from nhp.model.run import noop_progress_callback
1818

@@ -51,27 +51,35 @@ def progress_callback(self) -> Callable[[Any], Callable[[Any], None]]:
5151
class RunWithAzureStorage:
5252
"""Methods for running with azure storage."""
5353

54-
def __init__(self, filename: str, app_version: str = "dev"):
54+
def __init__(self, filename: str, config: Config = Config()):
5555
"""Initialise RunWithAzureStorage.
5656
5757
:param filename:
5858
:type filename: str
59-
:param app_version: the version of the app, where we will load data from. defaults to "dev"
60-
:type app_version: str, optional
59+
:param config: The configuration for the run
60+
:type config: Config
6161
"""
6262
logging.getLogger("azure.storage.common.storageclient").setLevel(logging.WARNING)
6363
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(
6464
logging.WARNING
6565
)
66+
self._config = config
6667

67-
self._app_version = re.sub("(\\d+\\.\\d+)\\..*", "\\1", app_version)
68+
self._app_version = re.sub("(\\d+\\.\\d+)\\..*", "\\1", config.APP_VERSION)
69+
70+
self._blob_storage_account_url = (
71+
f"https://{self._config.STORAGE_ACCOUNT}.blob.core.windows.net"
72+
)
73+
self._adls_storage_account_url = (
74+
f"https://{self._config.STORAGE_ACCOUNT}.dfs.core.windows.net"
75+
)
6876

6977
self.params = self._get_params(filename)
7078
self._get_data(self.params["start_year"], self.params["dataset"])
7179

7280
def _get_container(self, container_name: str):
7381
return BlobServiceClient(
74-
account_url=f"https://{config.STORAGE_ACCOUNT}.blob.core.windows.net",
82+
account_url=self._blob_storage_account_url,
7583
credential=DefaultAzureCredential(),
7684
).get_container_client(container_name)
7785

@@ -103,11 +111,11 @@ def _get_data(self, year: str, dataset: str) -> None:
103111
"""
104112
logging.info("downloading data (%s / %s)", year, dataset)
105113
fs_client = DataLakeServiceClient(
106-
account_url=f"https://{config.STORAGE_ACCOUNT}.dfs.core.windows.net",
114+
account_url=self._adls_storage_account_url,
107115
credential=DefaultAzureCredential(),
108116
).get_file_system_client("data")
109117

110-
version = config.DATA_VERSION
118+
version = self._config.DATA_VERSION
111119

112120
paths = [p.name for p in fs_client.get_paths(version, recursive=False)]
113121

tests/unit/nhp/docker/test___main__.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""test docker run."""
22

33
import time
4-
from unittest.mock import patch
4+
from unittest.mock import Mock, patch
55

66
import pytest
77

8-
from nhp.docker import config
98
from nhp.docker.__main__ import main, parse_args
109

1110

@@ -93,6 +92,12 @@ def test_main_azure(mocker):
9392
rwls = mocker.patch("nhp.docker.__main__.RunWithLocalStorage")
9493
rwas = mocker.patch("nhp.docker.__main__.RunWithAzureStorage")
9594

95+
config = Mock()
96+
config.APP_VERSION = "dev"
97+
config.DATA_VERSION = "dev"
98+
config.CONTAINER_TIMEOUT_SECONDS = 3600
99+
config.STORAGE_ACCOUNT = "sa"
100+
96101
params = {
97102
"model_runs": 256,
98103
"start_year": 2019,
@@ -108,11 +113,11 @@ def test_main_azure(mocker):
108113
)
109114

110115
# act
111-
main()
116+
main(config)
112117

113118
# assert
114119
rwls.assert_not_called()
115-
rwas.assert_called_once_with("params.json", "dev")
120+
rwas.assert_called_once_with("params.json", config)
116121

117122
s = rwas()
118123
ru_m.assert_called_once_with(params, "data", s.progress_callback(), False)
@@ -121,6 +126,9 @@ def test_main_azure(mocker):
121126

122127
def test_init(mocker):
123128
"""It should run the main method if __name__ is __main__."""
129+
config = mocker.patch("nhp.docker.__main__.Config")
130+
config().CONTAINER_TIMEOUT_SECONDS = 3600
131+
124132
import nhp.docker.__main__ as r
125133

126134
main_mock = mocker.patch("nhp.docker.__main__.main")
@@ -130,31 +138,33 @@ def test_init(mocker):
130138

131139
with patch.object(r, "__name__", "__main__"):
132140
r.init() # should call main
133-
main_mock.assert_called_once()
141+
main_mock.assert_called_once_with(config())
134142

135143

136144
def test_init_timeout_call_exit(mocker):
137-
config.CONTAINER_TIMEOUT_SECONDS = 0.1
145+
config = mocker.patch("nhp.docker.__main__.Config")
146+
config().CONTAINER_TIMEOUT_SECONDS = 0.1
138147

139148
import nhp.docker.__main__ as r
140149

141150
main_mock = mocker.patch("nhp.docker.__main__.main")
142151
exit_container_mock = mocker.patch("nhp.docker.__main__._exit_container")
143-
main_mock.side_effect = lambda: time.sleep(0.2)
152+
main_mock.side_effect = lambda *args, **kwargs: time.sleep(0.2)
144153
with patch.object(r, "__name__", "__main__"):
145154
r.init()
146155

147156
exit_container_mock.assert_called_once()
148157

149158

150159
def test_init_timeout_dont_call_exit(mocker):
151-
config.CONTAINER_TIMEOUT_SECONDS = 0.1
152-
153160
import nhp.docker.__main__ as r
154161

162+
config = mocker.patch("nhp.docker.__main__.Config")
163+
config().CONTAINER_TIMEOUT_SECONDS = 0.1
164+
155165
main_mock = mocker.patch("nhp.docker.__main__.main")
156166
exit_container_mock = mocker.patch("nhp.docker.__main__._exit_container")
157-
main_mock.side_effect = lambda: time.sleep(0.02)
167+
main_mock.side_effect = lambda *args, **kwargs: time.sleep(0.02)
158168
with patch.object(r, "__name__", "__main__"):
159169
r.init()
160170

tests/unit/nhp/docker/test_config.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,57 @@
1-
import importlib
1+
import os
22
from unittest.mock import patch
33

4-
import nhp.docker.config as c
4+
import pytest
5+
6+
from nhp.docker.config import Config
57

68

79
def test_config_sets_values_from_envvars(mocker):
810
# arrange
9-
mocker.patch(
10-
"dotenv.dotenv_values",
11-
return_value={
11+
mocker.patch("dotenv.load_dotenv")
12+
13+
# act
14+
with patch.dict(
15+
os.environ,
16+
{
1217
"APP_VERSION": "app version",
1318
"DATA_VERSION": "data version",
1419
"STORAGE_ACCOUNT": "storage account",
1520
"CONTAINER_TIMEOUT_SECONDS": "123",
1621
},
17-
)
18-
19-
# act
20-
importlib.reload(c)
22+
):
23+
config = Config()
2124

2225
# assert
23-
assert c.APP_VERSION == "app version"
24-
assert c.DATA_VERSION == "data version"
25-
assert c.STORAGE_ACCOUNT == "storage account"
26-
assert c.CONTAINER_TIMEOUT_SECONDS == 123
26+
assert config.APP_VERSION == "app version"
27+
assert config.DATA_VERSION == "data version"
28+
assert config.STORAGE_ACCOUNT == "storage account"
29+
assert config.CONTAINER_TIMEOUT_SECONDS == 123
2730

2831

2932
def test_config_uses_default_values(mocker):
3033
# arrange
31-
mocker.patch("dotenv.dotenv_values", return_value={})
34+
mocker.patch("dotenv.load_dotenv")
3235

3336
# act
34-
importlib.reload(c)
37+
config = Config()
3538

3639
# assert
37-
assert c.APP_VERSION == "dev"
38-
assert c.DATA_VERSION == "dev"
39-
assert not c.STORAGE_ACCOUNT
40-
assert c.CONTAINER_TIMEOUT_SECONDS == 3600
40+
assert config.APP_VERSION == "dev"
41+
assert config.DATA_VERSION == "dev"
42+
43+
with pytest.raises(ValueError, match="STORAGE_ACCOUNT environment variable must be set"):
44+
config.STORAGE_ACCOUNT
45+
46+
assert config.CONTAINER_TIMEOUT_SECONDS == 3600
4147

4248

4349
def test_config_calls_dotenv_load(mocker):
4450
# arrange
45-
m = mocker.patch("dotenv.dotenv_values", return_value={})
51+
m = mocker.patch("dotenv.load_dotenv")
4652

4753
# act
48-
importlib.reload(c)
54+
config = Config()
4955

5056
# assert
5157
m.assert_called_once()

tests/unit/nhp/docker/test_run.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,25 @@ def mock_run_with_azure_storage():
5959
rwas.params = {}
6060
rwas._app_version = "dev"
6161

62+
rwas._config = Mock()
63+
rwas._config.APP_VERSION = "dev"
64+
rwas._config.DATA_VERSION = "dev"
65+
rwas._config.STORAGE_ACCOUNT = "sa"
66+
rwas._config.CONTAINER_TIMEOUT_SECONDS = 3600
67+
68+
rwas._blob_storage_account_url = "https://sa.blob.core.windows.net"
69+
rwas._adls_storage_account_url = "https://sa.dfs.core.windows.net"
70+
6271
return rwas
6372

6473

65-
@pytest.mark.parametrize(
66-
"args, expected_version", [(["filename"], "dev"), (["filename", "v0.3.5"], "v0.3")]
67-
)
68-
def test_RunWithAzureStorage_init(mocker, args, expected_version):
74+
@pytest.mark.parametrize("actual_version, expected_version", [("dev", "dev"), ("v0.3.5", "v0.3")])
75+
def test_RunWithAzureStorage_init(mocker, actual_version, expected_version):
6976
# arrange
77+
config = mocker.patch("nhp.docker.run.Config")
78+
config().APP_VERSION = actual_version
79+
config().STORAGE_ACCOUNT = "sa"
80+
7081
expected_params = {"start_year": 2020, "dataset": "synthetic"}
7182
gpm = mocker.patch(
7283
"nhp.docker.run.RunWithAzureStorage._get_params",
@@ -75,11 +86,13 @@ def test_RunWithAzureStorage_init(mocker, args, expected_version):
7586
gdm = mocker.patch("nhp.docker.run.RunWithAzureStorage._get_data")
7687

7788
# act
78-
s = RunWithAzureStorage(*args) # type: ignore
89+
s = RunWithAzureStorage("filename", config())
7990

8091
# assert
8192
assert s._app_version == expected_version
8293
assert s.params == expected_params
94+
assert s._blob_storage_account_url == "https://sa.blob.core.windows.net"
95+
assert s._adls_storage_account_url == "https://sa.dfs.core.windows.net"
8396

8497
gpm.assert_called_once_with("filename")
8598

@@ -97,7 +110,6 @@ def test_RunWithAzureStorage_get_container(mock_run_with_azure_storage, mocker):
97110

98111
dac_m = mocker.patch("nhp.docker.run.DefaultAzureCredential", return_value="cred")
99112

100-
config.STORAGE_ACCOUNT = "sa"
101113
# act
102114
actual = s._get_container("container")
103115

@@ -163,8 +175,6 @@ def paths_helper(i):
163175

164176
dac_m = mocker.patch("nhp.docker.run.DefaultAzureCredential", return_value="cred")
165177

166-
config.STORAGE_ACCOUNT = "sa"
167-
168178
# act
169179
with patch("builtins.open", mock_open()) as mock_file:
170180
s._get_data(2020, "synthetic")

0 commit comments

Comments
 (0)