Skip to content

Commit 54a70e4

Browse files
author
Sarah Krebs
committed
Adapt ablation and importance api examples
1 parent 4b9c577 commit 54a70e4

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

examples/api/ablation_paths.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313

1414
if __name__ == "__main__":
1515
# Instantiate the run
16-
run = DeepCAVERun.from_path(Path("logs/DeepCAVE/minimal/run_2"))
16+
run = DeepCAVERun.from_path(Path("logs/DeepCAVE/minimal/run_1"))
1717

1818
objective_id1 = run.get_objective_ids()[0]
19-
objective_id2 = run.get_objective_ids()[1]
20-
budget_id = run.get_budget_ids()[1]
19+
objective_id2 = None # replace with run.get_objective_ids()[1] for multi-objective importance
20+
budget_id = run.get_budget_ids()[0]
2121

2222
# Instantiate the plugin
2323
plugin = AblationPaths()
2424
inputs = plugin.generate_inputs(
2525
objective_id1=objective_id1,
26-
objective_id2=None, # replace with objective_id2 for multi-objective importance
26+
objective_id2=objective_id2,
2727
budget_id=budget_id,
2828
n_hps=100,
2929
n_trees=100,
@@ -35,4 +35,5 @@
3535
# Finally, you can load the figure. Here, the filter variables play a role.
3636
figure1, figure2 = plugin.load_outputs(run, inputs, outputs)
3737
figure1.write_image("examples/api/ablation_paths_performance.png", scale=2.0)
38-
figure2.write_image("examples/api/ablation_paths_improvement.png", scale=2.0)
38+
if not objective_id2:
39+
figure2.write_image("examples/api/ablation_paths_improvement.png", scale=2.0)

examples/api/importances.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616
run = DeepCAVERun.from_path(Path("logs/DeepCAVE/minimal/run_2"))
1717

1818
objective_id1 = run.get_objective_ids()[0]
19-
objective_id2 = run.get_objective_ids()[1]
19+
objective_id2 = None # replace with run.get_objective_ids()[1] for multi-objective importance
2020
budget_ids = run.get_budget_ids()
2121

2222
# Instantiate the plugin
2323
plugin = Importances()
2424
inputs = plugin.generate_inputs(
2525
hyperparameter_names=list(run.configspace.keys()),
2626
objective_id1=objective_id1,
27-
objective_id2=None, # replace with objective_id2 for multi-objective importance
27+
objective_id2=objective_id2,
2828
budget_ids=budget_ids,
2929
method="global",
3030
n_hps=3,

0 commit comments

Comments
 (0)