Skip to content

Commit 252d690

Browse files
authored
Merge pull request #306 from The-Strategy-Unit/299_add_metadata_to_saved_params
add metadata to saved params json file
2 parents 022aa8b + 16541b7 commit 252d690

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

docker_run.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,23 +139,30 @@ def _upload_results_json(self, results_file: str, metadata: dict) -> None:
139139
overwrite=True,
140140
)
141141

142-
def _upload_results_files(self, files: list) -> None:
142+
def _upload_results_files(self, files: list, metadata: dict) -> None:
143143
"""Upload the results
144144
145145
once the model has run, upload the files (parquet for model results and json for model params) to blob storage
146146
147147
:param files: list of files to be uploaded
148148
:type files: list
149+
:param metadata: the metadata to attach to the blob
150+
:type metadata: dict
149151
150152
"""
151153
container = self._get_container("results")
152154
for file in files:
153155
filename = file[8:]
156+
if file.endswith(".json"):
157+
metadata_to_use = metadata
158+
else:
159+
metadata_to_use = None
154160
with open(file, "rb") as f:
155161
container.upload_blob(
156162
f"aggregated-model-results/{self._app_version}/{filename}",
157163
f.read(),
158164
overwrite=True,
165+
metadata=metadata_to_use,
159166
)
160167

161168
def _upload_full_model_results(self) -> None:
@@ -203,7 +210,7 @@ def finish(
203210
if not isinstance(v, dict) and not isinstance(v, list)
204211
}
205212
self._upload_results_json(results_file, metadata)
206-
self._upload_results_files(saved_files)
213+
self._upload_results_files(saved_files, metadata)
207214
if save_full_model_results:
208215
self._upload_full_model_results()
209216
self._cleanup()

tests/test_docker_run.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -229,23 +229,38 @@ def test_RunWithAzureStorage_upload_results_json(mock_run_with_azure_storage, mo
229229
)
230230

231231

232-
def test_RunWithAzureStorage_upload_results_files(
233-
mock_run_with_azure_storage, mocker
234-
):
232+
def test_RunWithAzureStorage_upload_results_files(mock_run_with_azure_storage, mocker):
235233
# arrange
236234
s = mock_run_with_azure_storage
237235

238236
m = mocker.patch("docker_run.RunWithAzureStorage._get_container")
237+
metadata = {"k": "v"}
239238

240239
# act
241240
with patch("builtins.open", mock_open(read_data="data")) as mock_file:
242-
s._upload_results_files(["results/filename"])
241+
s._upload_results_files(["results/filename", "results/filename.json"], metadata)
243242

244243
# assert
245-
mock_file.assert_called_once_with("results/filename", "rb")
244+
mock_file.call_args_list == [
245+
call("results/filename", "rb"),
246+
call("results/filename.json", "rb"),
247+
]
246248
m.assert_called_once_with("results")
247-
m().upload_blob.assert_called_once_with(
248-
"aggregated-model-results/dev/filename", "data", overwrite=True
249+
m().upload_blob.assert_has_calls(
250+
[
251+
call(
252+
"aggregated-model-results/dev/filename",
253+
"data",
254+
overwrite=True,
255+
metadata=None,
256+
),
257+
call(
258+
"aggregated-model-results/dev/filename.json",
259+
"data",
260+
overwrite=True,
261+
metadata=metadata,
262+
),
263+
]
249264
)
250265

251266

@@ -319,7 +334,7 @@ def test_RunWithAzureStorage_finish_save_full_model_results_false(
319334

320335
# assert
321336
m1.assert_called_once_with("results_file", metadata)
322-
m2.assert_called_once_with(["saved_files"])
337+
m2.assert_called_once_with(["saved_files"], metadata)
323338
m3.assert_not_called()
324339
m4.assert_called_once()
325340

@@ -346,7 +361,7 @@ def test_RunWithAzureStorage_finish_save_full_model_results_true(
346361

347362
# assert
348363
m1.assert_called_once_with("results_file", metadata)
349-
m2.assert_called_once_with(["saved_files"])
364+
m2.assert_called_once_with(["saved_files"], metadata)
350365
m3.assert_called_once()
351366
m4.assert_called_once()
352367

0 commit comments

Comments
 (0)