Skip to content

Commit dbef496

Browse files
author
Sarah Krebs
committed
Allow multiple seeds (pareto front)
1 parent db6f5b3 commit dbef496

File tree

2 files changed

+133
-44
lines changed

2 files changed

+133
-44
lines changed

deepcave/plugins/objective/pareto_front.py

Lines changed: 89 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -183,20 +183,38 @@ def get_filter_layout(register: Callable) -> List[Any]:
183183
The layouts for the filter block.
184184
"""
185185
return [
186-
html.Div(
186+
dbc.Row(
187187
[
188-
dbc.Label("Show all configurations"),
189-
help_button(
190-
"Additionally to the pareto front, also the other configurations "
191-
"are displayed. This makes it easier to see the performance "
192-
"differences."
188+
dbc.Col(
189+
[
190+
dbc.Label("Show all configurations"),
191+
help_button(
192+
"Additionally to the pareto front, also the other configurations "
193+
"are displayed. This makes it easier to see the performance "
194+
"differences."
195+
),
196+
dbc.Select(
197+
id=register("show_all", ["value", "options"]),
198+
placeholder="Select ...",
199+
),
200+
],
201+
md=6,
193202
),
194-
dbc.Select(
195-
id=register("show_all", ["value", "options"]),
196-
placeholder="Select ...",
203+
dbc.Col(
204+
[
205+
dbc.Label("Show error bars"),
206+
help_button(
207+
"Show error bars In the case of non-deterministic runs with "
208+
"multiple seeds evaluated per configuration."
209+
),
210+
dbc.Select(
211+
id=register("show_error", ["value", "options"]),
212+
placeholder="Select ...",
213+
),
214+
],
215+
md=6,
197216
),
198217
],
199-
className="mb-3",
200218
),
201219
dbc.Row(
202220
[
@@ -251,6 +269,7 @@ def load_inputs(self) -> Dict[str, Dict[str, Any]]:
251269
"value": self.budget_options[-1]["value"],
252270
},
253271
"show_all": {"options": get_select_options(binary=True), "value": "false"},
272+
"show_error": {"options": get_select_options(binary=True), "value": "false"},
254273
"show_runs": {"options": get_select_options(binary=True), "value": "true"},
255274
"show_groups": {"options": get_select_options(binary=True), "value": "true"},
256275
}
@@ -291,22 +310,29 @@ def process(run, inputs) -> Dict[str, Any]: # type: ignore
291310
objective_id_2 = inputs["objective_id_2"]
292311
objective_2 = run.get_objective(objective_id_2)
293312

294-
points: Union[List, np.ndarray] = []
295-
config_ids: Union[List, np.ndarray] = []
296-
for config_id, costs in run.get_all_costs(budget, statuses=[Status.SUCCESS]).items():
297-
points += [[costs[objective_id_1], costs[objective_id_2]]]
298-
config_ids += [config_id]
313+
points_avg: Union[List, np.ndarray] = []
314+
points_std: Union[List, np.ndarray] = []
315+
config_ids: Union[List, np.ndarray] = list(
316+
run.get_configs(budget, statuses=[Status.SUCCESS]).keys()
317+
)
318+
319+
for config_id in config_ids:
320+
avg_costs, std_costs = run.get_avg_costs(config_id, budget, statuses=[Status.SUCCESS])
321+
points_avg += [[avg_costs[objective_id_1], avg_costs[objective_id_2]]]
322+
points_std += [[std_costs[objective_id_1], std_costs[objective_id_2]]]
299323

300-
points = np.array(points)
324+
points_avg = np.array(points_avg)
325+
points_std = np.array(points_std)
301326
config_ids = np.array(config_ids)
302327

303328
# Sort the points s.t. x axis is monotonically increasing
304-
sorted_idx = np.argsort(points[:, 0])
305-
points = points[sorted_idx]
329+
sorted_idx = np.argsort(points_avg[:, 0])
330+
points_avg = points_avg[sorted_idx]
331+
points_std = points_std[sorted_idx]
306332
config_ids = config_ids[sorted_idx]
307333

308-
is_front: np.ndarray = np.ones(points.shape[0], dtype=bool)
309-
for point_idx, costs in enumerate(points):
334+
is_front: np.ndarray = np.ones(points_avg.shape[0], dtype=bool)
335+
for point_idx, costs in enumerate(points_avg):
310336
if is_front[point_idx]:
311337
# Keep any point with a lower/upper cost
312338
# This loop is a little bit complicated than
@@ -316,9 +342,9 @@ def process(run, inputs) -> Dict[str, Any]: # type: ignore
316342
select = None
317343
for idx, (objective, cost) in enumerate(zip([objective_1, objective_2], costs)):
318344
if objective.optimize == "upper":
319-
select2 = np.any(points[is_front][:, idx, np.newaxis] > [cost], axis=1)
345+
select2 = np.any(points_avg[is_front][:, idx, np.newaxis] > [cost], axis=1)
320346
else:
321-
select2 = np.any(points[is_front][:, idx, np.newaxis] < [cost], axis=1)
347+
select2 = np.any(points_avg[is_front][:, idx, np.newaxis] < [cost], axis=1)
322348

323349
if select is None:
324350
select = select2
@@ -331,7 +357,8 @@ def process(run, inputs) -> Dict[str, Any]: # type: ignore
331357
is_front[point_idx] = True
332358

333359
return {
334-
"points": points.tolist(),
360+
"points_avg": points_avg.tolist(),
361+
"points_std": points_std.tolist(),
335362
"pareto_points": is_front.tolist(),
336363
"config_ids": config_ids.tolist(),
337364
}
@@ -380,6 +407,7 @@ def load_outputs(runs, inputs, outputs) -> go.Figure: # type: ignore
380407
The output figure.
381408
"""
382409
show_all = inputs["show_all"]
410+
show_error = inputs["show_error"]
383411

384412
traces = []
385413
for idx, run in enumerate(runs):
@@ -392,31 +420,49 @@ def load_outputs(runs, inputs, outputs) -> go.Figure: # type: ignore
392420
if run.prefix != "group" and not show_runs:
393421
continue
394422

395-
points = np.array(outputs[run.id]["points"])
423+
points_avg = np.array(outputs[run.id]["points_avg"])
424+
points_std = np.array(outputs[run.id]["points_std"])
396425
config_ids = outputs[run.id]["config_ids"]
397426
pareto_config_ids = []
398427

399-
x, y = [], []
400-
x_pareto, y_pareto = [], []
428+
x, y, x_std, y_std = [], [], [], []
429+
x_pareto, y_pareto, x_pareto_std, y_pareto_std = [], [], [], []
401430

402431
pareto_points = outputs[run.id]["pareto_points"]
403432
for point_idx, pareto in enumerate(pareto_points):
404433
if pareto:
405-
x_pareto += [points[point_idx][0]]
406-
y_pareto += [points[point_idx][1]]
434+
x_pareto += [points_avg[point_idx][0]]
435+
y_pareto += [points_avg[point_idx][1]]
436+
x_pareto_std += [points_std[point_idx][0]]
437+
y_pareto_std += [points_std[point_idx][1]]
407438
pareto_config_ids += [config_ids[point_idx]]
408439
else:
409-
x += [points[point_idx][0]]
410-
y += [points[point_idx][1]]
440+
x += [points_avg[point_idx][0]]
441+
y += [points_avg[point_idx][1]]
442+
x_std += [points_std[point_idx][0]]
443+
y_std += [points_std[point_idx][1]]
411444

412445
color = get_color(idx, alpha=0.1)
413446
color_pareto = get_color(idx)
414447

415448
if show_all:
449+
error_x = (
450+
dict(array=x_std, color="rgba(0, 0, 0, 0.3)")
451+
if show_error and not all(value == 0.0 for value in x_std)
452+
else None
453+
)
454+
error_y = (
455+
dict(array=y_std, color="rgba(0, 0, 0, 0.3)")
456+
if show_error and not all(value == 0.0 for value in y_std)
457+
else None
458+
)
459+
416460
traces.append(
417461
go.Scatter(
418462
x=x,
419463
y=y,
464+
error_x=error_x,
465+
error_y=error_y,
420466
name=run.name,
421467
mode="markers",
422468
showlegend=False,
@@ -443,10 +489,23 @@ def load_outputs(runs, inputs, outputs) -> go.Figure: # type: ignore
443489
get_hovertext_from_config(run, config_id) for config_id in pareto_config_ids
444490
]
445491

492+
error_pareto_x = (
493+
dict(array=x_pareto_std, color="rgba(0, 0, 0, 0.3)")
494+
if show_error and not all(value == 0.0 for value in x_pareto_std)
495+
else None
496+
)
497+
error_pareto_y = (
498+
dict(array=y_pareto_std, color="rgba(0, 0, 0, 0.3)")
499+
if show_error and not all(value == 0.0 for value in y_pareto_std)
500+
else None
501+
)
502+
446503
traces.append(
447504
go.Scatter(
448505
x=x_pareto,
449506
y=y_pareto,
507+
error_x=error_pareto_x,
508+
error_y=error_pareto_y,
450509
name=run.name,
451510
line_shape=line_shape,
452511
showlegend=True,

deepcave/runs/__init__.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -387,21 +387,29 @@ def get_objective_names(self) -> List[str]:
387387
return [obj.name for obj in self.get_objectives()]
388388

389389
def get_configs(
390-
self, budget: Optional[Union[int, float]] = None, seed: Optional[int] = None
390+
self,
391+
budget: Optional[Union[int, float]] = None,
392+
seed: Optional[int] = None,
393+
statuses: Optional[Union[Status, List[Status]]] = None,
391394
) -> Dict[int, Configuration]:
392395
"""
393396
Get configurations of the run.
394397
395398
Optionally, only configurations which were evaluated
396-
on the passed budget are considered.
399+
on the passed budget, seed, and status are considered.
397400
398401
Parameters
399402
----------
400403
budget : Optional[Union[int, float]]
401-
Considered budget.
402-
By default, None (all configurations are included).
404+
Budget to select the configs. If no budget is given, all seeds are considered.
405+
By default None.
403406
seed: Optional[int]
404-
Considered seed. By default None (all configurations are included).
407+
Seed to select the configs. If no seed is given, all seeds are considered.
408+
By default None.
409+
statuses : Optional[Union[Status, List[Status]]]
410+
Only selected stati are considered. If no status is given, all stati are considered.
411+
By default None.
412+
405413
406414
Returns
407415
-------
@@ -426,6 +434,13 @@ def get_configs(
426434
if seed != trial.seed:
427435
continue
428436

437+
if statuses is not None:
438+
if isinstance(statuses, Status):
439+
statuses = [statuses]
440+
441+
if trial.status not in statuses:
442+
continue
443+
429444
if (config_id := trial.config_id) not in configs:
430445
config = self.get_config(config_id)
431446
configs[config_id] = config
@@ -673,8 +688,11 @@ def _process_costs(self, costs: List[float]) -> List[float]:
673688
return new_costs
674689

675690
def get_avg_costs(
676-
self, config_id: int, budget: Optional[Union[int, float]] = None
677-
) -> List[float]:
691+
self,
692+
config_id: int,
693+
budget: Optional[Union[int, float]] = None,
694+
statuses: Optional[Union[Status, List[Status]]] = None,
695+
) -> Tuple[List[float], List[float]]:
678696
"""
679697
Get average costs over all seeds for a config and budget.
680698
@@ -685,31 +703,40 @@ def get_avg_costs(
685703
budget : Optional[Union[int, float]]
686704
Budget to get the costs from the configuration id for. By default, None. If budget is
687705
None, the highest budget is chosen.
706+
statuses : Optional[Union[Status, List[Status]]]
707+
Only selected stati are considered. If no status is given, all stati are considered.
708+
By default None.
688709
689710
Returns
690711
-------
691712
List[float]
692713
List of average cost values for the given config_id and budget.
714+
List[float]
715+
List of std cost values for the given config_id and budget.
693716
"""
694717
objectives = self.get_objectives()
695718

696719
# Budget might not be evaluated
697-
config_costs = self.get_costs(config_id, budget)
720+
config_costs = self.get_costs(config_id, budget, statuses=statuses)
698721

699-
avg_costs = []
722+
avg_costs, std_costs = [], []
700723
for idx in range(len(objectives)):
701724
costs = [values[idx] for values in config_costs.values() if values[idx] is not None]
702725
avg_costs.append(float(np.mean(costs)))
703-
return avg_costs
726+
std_costs.append(float(np.std(costs)))
727+
return avg_costs, std_costs
704728

705729
def get_costs(
706-
self, config_id: int, budget: Optional[Union[int, float]] = None, seed: Optional[int] = None
730+
self,
731+
config_id: int,
732+
budget: Optional[Union[int, float]] = None,
733+
seed: Optional[int] = None,
734+
statuses: Optional[Union[Status, List[Status]]] = None,
707735
) -> Dict[int, List[float]]:
708736
"""
709737
Return the costs of a configuration.
710738
711-
In case of multi-objective, multiple costs are
712-
returned.
739+
In case of multi-objective, multiple costs are returned.
713740
714741
Parameters
715742
----------
@@ -721,6 +748,9 @@ def get_costs(
721748
seed : Optional[int], optional
722749
Seed to get the costs from the configuration id for. By default None. If no seed is
723750
given, all seeds are considered.
751+
statuses : Optional[Union[Status, List[Status]]]
752+
Only selected stati are considered. If no status is given, all stati are considered.
753+
By default None.
724754
725755
Returns
726756
-------
@@ -739,7 +769,7 @@ def get_costs(
739769

740770
if config_id not in self.configs:
741771
raise ValueError("Configuration id was not found.")
742-
costs = self.get_all_costs(budget=budget, seed=seed)
772+
costs = self.get_all_costs(budget=budget, seed=seed, statuses=statuses)
743773
if config_id not in costs:
744774
if seed is not None:
745775
raise RuntimeError(

0 commit comments

Comments
 (0)