Skip to content

Commit 91694e0

Browse files
committed
adding stats list to snpe deployment
1 parent 3ab80c9 commit 91694e0

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

app/flows/run_snpe.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,7 @@ def log_task(
199199
posterior_samples: torch.Tensor,
200200
input_parameters: RootCalibrationModel,
201201
observed_values: list,
202+
statistics_list: list[SummaryStatisticsModel],
202203
names: list[str],
203204
limits: list[tuple],
204205
simulation_uuid: str,
@@ -220,6 +221,8 @@ def log_task(
220221
The root calibration data model.
221222
observed_values (list):
222223
The list of observed_values.
224+
statistics_list (list[SummaryStatisticsModel]):
225+
The list of summary statistics.
223226
names (list[str]):
224227
The parameter names.
225228
limits (list[tuple]):
@@ -327,12 +330,14 @@ def log_task(
327330
parameter_specs, input_parameters
328331
)
329332

333+
statistics_list = [statistic.dict() for statistic in statistics_list]
330334
parameter_intervals["inference_type"] = "summary_statistics"
331335
artifacts = {}
332336
for obj, name in [
333337
(inference, "inference"),
334338
(posterior, "posterior"),
335339
(parameter_intervals, "parameter_intervals"),
340+
(statistics_list, "statistics_list"),
336341
]:
337342
outfile = osp.join(outdir, f"{time_now}-{TASK}_{name}.pkl")
338343
artifacts[name] = outfile
@@ -380,6 +385,7 @@ def run_snpe(input_parameters: RootCalibrationModel, simulation_uuid: str) -> No
380385
posterior_samples,
381386
input_parameters,
382387
observed_values,
388+
statistics_list,
383389
names,
384390
limits,
385391
simulation_uuid,

deeprootgen/calibration/model_versioning.py

Lines changed: 17 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
This module defines MLflow compatible models for versioning and deployment as microservices.
44
"""
55

6+
from typing import Any
7+
68
import bentoml
79
import mlflow
810
import numpy as np
@@ -253,14 +255,14 @@ def load_context(self, context: Context) -> None:
253255
"""
254256
import joblib
255257

256-
loaded_data = context.artifacts["inference"]
257-
self.inference = joblib.load(loaded_data)
258-
259-
loaded_data = context.artifacts["posterior"]
260-
self.posterior = joblib.load(loaded_data)
258+
def load_data(k: str) -> Any:
259+
artifact = context.artifacts[k]
260+
return joblib.load(artifact)
261261

262-
loaded_data = context.artifacts["parameter_intervals"]
263-
self.parameter_intervals = joblib.load(loaded_data)
262+
self.inference = load_data("inference")
263+
self.posterior = load_data("posterior")
264+
self.parameter_intervals = load_data("parameter_intervals")
265+
self.statistics_list = load_data("statistics_list")
264266

265267
def predict(
266268
self, context: Context, model_input: pd.DataFrame, params: dict | None = None
@@ -283,12 +285,14 @@ def predict(
283285
pd.DataFrame:
284286
The model prediction.
285287
"""
286-
if (
287-
self.inference is None
288-
or self.posterior is None
289-
or self.parameter_intervals is None
290-
):
291-
raise ValueError(f"The {self.task} calibrator has not been loaded.")
288+
for prop in [
289+
self.inference,
290+
self.posterior,
291+
self.parameter_intervals,
292+
self.statistics_list,
293+
]:
294+
if prop is None:
295+
raise ValueError(f"The {self.task} calibrator has not been loaded.")
292296

293297
observed_values = model_input["statistic_value"].values
294298
posterior_samples = self.posterior.sample((50,), x=observed_values)

0 commit comments

Comments (0)