Skip to content

Commit a2a13f8

Browse files
committed
adding deployment for surrogate
1 parent b974c85 commit a2a13f8

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

app/conf/deployments_form/common.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ components:
534534
children: Neural density calibrator
535535
color: primary
536536
className: me-1
537-
- id: call-surrogate-modelling-button
537+
- id: call-surrogate-button
538538
label: Surrogate model
539539
help: Make a web request to the deployed surrogate model calibrator
540540
class_name: dash_bootstrap_components.Button

app/pages/deployments_root_system.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,53 @@ def call_snpe(n_clicks: int | list[int], statistics_inputs: list[float]) -> Call
487487
return dcc.send_file(outfile), True, f"Calling {task} calibrator"
488488

489489

490+
@callback(
491+
Output(f"{PAGE_ID}-download-results", "data", allow_duplicate=True),
492+
Output(f"{PAGE_ID}-load-toast", "is_open", allow_duplicate=True),
493+
Output(f"{PAGE_ID}-load-toast", "children", allow_duplicate=True),
494+
Input({"index": f"{PAGE_ID}-call-surrogate-button", "type": ALL}, "n_clicks"),
495+
State({"type": f"{PAGE_ID}-parameters", "index": ALL}, "value"),
496+
prevent_initial_call=True,
497+
)
498+
def call_surrogate(
499+
n_clicks: int | list[int], statistics_inputs: list[float]
500+
) -> Callable:
501+
"""Call the surrogate model endpoint.
502+
503+
Args:
504+
n_clicks (int | list[int]):
505+
The number of form clicks.
506+
statistics_inputs (list[float]):
507+
The list of summary statistic values.
508+
509+
Returns:
510+
Callable:
511+
The form data.
512+
"""
513+
if n_clicks is None or len(n_clicks) == 0: # type: ignore
514+
return no_update
515+
516+
if n_clicks[0] is None or n_clicks[0] == 0: # type: ignore
517+
return no_update
518+
519+
endpoint = os.environ.get(
520+
"DEPLOYMENT_SURROGATE_INTERNAL_LINK", "http://surrogate:3000"
521+
)
522+
523+
app = get_app()
524+
statistics_form = app.settings[FORM_NAME]
525+
summary_statistics = {}
526+
for i, child in enumerate(statistics_form.components["parameters"]["children"]):
527+
statistic_name = child["param"]
528+
statistic_value = statistics_inputs[i]
529+
summary_statistics[statistic_name] = statistic_value
530+
531+
json = {"data": [summary_statistics]}
532+
task = "surrogate"
533+
outfile = endpoint_predict(task, endpoint, json)
534+
return dcc.send_file(outfile), True, f"Calling {task} calibrator"
535+
536+
490537
######################################
491538
# Layout
492539
######################################

0 commit comments

Comments
 (0)