Skip to content

Commit 17aa1d9

Browse files
EZoniRemiLehe
andauthored
Add start/end date selectors for experiment filtering (#370)
Co-authored-by: Remi Lehe <remi.lehe@normalesup.org>
1 parent 1996f52 commit 17aa1d9

File tree

5 files changed

+63
-6
lines changed

5 files changed

+63
-6
lines changed

dashboard/app.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,10 @@ def update(
105105
ctrl.figure_update(fig)
106106

107107

108-
@state.change("experiment")
108+
@state.change(
109+
"experiment",
110+
"experiment_date_range",
111+
)
109112
def update_on_change_experiment(**kwargs):
110113
# skip if triggered on server ready (all state variables marked as modified)
111114
if len(state.modified_keys) == 1:
@@ -361,14 +364,24 @@ def gui_setup():
361364
# add toolbar components
362365
with layout.toolbar:
363366
vuetify.VSpacer()
367+
# experiment selector
364368
vuetify.VSelect(
365369
v_model=("experiment",),
366370
label="Experiments",
367371
items=(experiments,),
368372
dense=True,
369373
hide_details=True,
370374
prepend_icon="mdi-atom",
371-
style="max-width: 250px",
375+
style="max-width: 250px; margin-right: 14px;",
376+
)
377+
# date range selector for experiment filtering
378+
vuetify.VDateInput(
379+
v_model=("experiment_date_range",),
380+
label="Date range",
381+
multiple="range",
382+
dense=True,
383+
hide_details=True,
384+
style="max-width: 250px; margin-right: 14px;",
372385
)
373386
# set up router view
374387
with layout.content:

dashboard/model_manager.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from lume_model.models.ensemble import NNEnsemble
1212
from lume_model.models.gp_model import GPModel
1313
from trame.widgets import vuetify3 as vuetify
14-
from utils import verify_input_variables, timer, load_config_dict
14+
from utils import verify_input_variables, timer, load_config_dict, create_date_filter
1515
from error_manager import add_error
1616
from sfapi_manager import monitor_sfapi_job
1717
from state_manager import state
@@ -190,9 +190,13 @@ async def training_kernel(self):
190190
client_id=state.sfapi_client_id, secret=state.sfapi_key
191191
) as client:
192192
perlmutter = await client.compute(Machine.perlmutter)
193-
# Upload the config.yaml to nersc
193+
# upload the configuration file to NERSC
194194
config_dict = load_config_dict(state.experiment)
195195
config_dict["simulation_calibration"] = state.simulation_calibration
196+
# add date range filter to the configuration dictionary
197+
date_filter = create_date_filter(state.experiment_date_range)
198+
config_dict["date_filter"] = date_filter
199+
# define the target path on NERSC
196200
target_path = "/global/cfs/cdirs/m558/superfacility/model_training"
197201
[target_path] = await perlmutter.ls(target_path, directory=True)
198202
with tempfile.TemporaryDirectory() as temp_dir:

dashboard/state_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
22
from trame.app import get_server
3+
from trame.widgets import vuetify3 as vuetify
34

45

56
EXPERIMENTS_PATH = Path.cwd().parent / "experiments/"
@@ -8,6 +9,7 @@
89
server = get_server(client_type="vue3")
910
state = server.state
1011
ctrl = server.controller
12+
vuetify.enable_lab() # Enable Labs components
1113

1214

1315
def initialize_state():
@@ -23,6 +25,7 @@ def initialize_state():
2325
][0]
2426
print(f"Setting default experiment to {default_experiment}...")
2527
state.experiment = default_experiment
28+
state.experiment_date_range = []
2629
# ML model
2730
state.model_type = "Neural Network (single)"
2831
state.model_training = False

dashboard/utils.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,47 @@ def load_variables(experiment):
7070
return (input_variables, output_variables, simulation_calibration)
7171

7272

73+
def create_date_filter(experiment_date_range):
74+
# build date filter if date range is set
75+
date_filter = {}
76+
if experiment_date_range:
77+
start_date = pd.to_datetime(experiment_date_range[0].to_datetime())
78+
start_date = start_date.to_pydatetime().replace(hour=0, minute=0, second=0)
79+
# VDateInput returns exclusive end date for date ranges:
80+
# - subtract 1 day for multi-date ranges with different start/end dates
81+
# - do not subtract anything (use end date as is) for single-date ranges
82+
end_date = pd.to_datetime(experiment_date_range[-1].to_datetime())
83+
end_date_correction = (
84+
pd.Timedelta(days=0)
85+
if len(experiment_date_range) == 1
86+
else pd.Timedelta(days=1)
87+
)
88+
end_date = end_date - end_date_correction
89+
end_date = end_date.to_pydatetime().replace(hour=23, minute=59, second=59)
90+
# remove timezone info to match naive datetime in database
91+
start_date = (
92+
start_date.replace(tzinfo=None) if start_date.tzinfo else start_date
93+
)
94+
end_date = end_date.replace(tzinfo=None) if end_date.tzinfo else end_date
95+
date_filter = {
96+
"date": {
97+
"$gte": start_date,
98+
"$lte": end_date,
99+
}
100+
}
101+
print(f"Filtering data between {start_date.date()} and {end_date.date()}...")
102+
return date_filter
103+
104+
73105
@timer
74106
def load_data(db):
75107
print("Loading data from database...")
108+
# create date filter if date range is set
109+
date_filter = create_date_filter(state.experiment_date_range)
76110
# load experiment and simulation data points in dataframes
77-
exp_data = pd.DataFrame(db[state.experiment].find({"experiment_flag": 1}))
111+
exp_data = pd.DataFrame(
112+
db[state.experiment].find({"experiment_flag": 1, **date_filter})
113+
)
78114
sim_data = pd.DataFrame(db[state.experiment].find({"experiment_flag": 0}))
79115
# Store '_id', 'date' as string
80116
for key in ["_id", "date"]:

ml/train_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,8 @@ def write_model(model, model_type, experiment, db):
439439

440440
# Extract experimental and simulation data from the database as pandas dataframe
441441
db = connect_to_db(config_dict)
442-
df_exp = pd.DataFrame(db[experiment].find({"experiment_flag": 1}))
442+
date_filter = config_dict.get("date_filter", {})
443+
df_exp = pd.DataFrame(db[experiment].find({"experiment_flag": 1, **date_filter}))
443444
df_sim = pd.DataFrame(db[experiment].find({"experiment_flag": 0}))
444445

445446
# Apply simulation calibration to the simulation data

0 commit comments

Comments
 (0)