Skip to content

Commit 41ee368

Browse files
author
Sarah Krebs
committed
Correct LPI importance tests
1 parent ebff8a6 commit 41ee368

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
- Fix errors due to changing inputs before runselection (#64).
2222
- For fANOVA, remove constant hyperparameters from configspace (#9).
2323
- For PCP, show hyperparameters with highest importance closest to the cost (i.e. right).
24+
- Add init files to all test directories.
25+
- Correct LPI importance tests.
2426

2527
## Version-Updates
2628
- Black version from 23.1.0 to 23.3.0

tests/test_evaluators/test_lpi.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from deepcave.evaluators.lpi import LPI as Evaluator
44
from deepcave.runs import AbstractRun
5-
from deepcave.runs.converters.smac3v1 import SMAC3v1Run
5+
from deepcave.runs.converters.smac3v2 import SMAC3v2Run
66

77

88
class TestLPI(unittest.TestCase):
99
def setUp(self):
1010
# Initiate run here
11-
self.run: AbstractRun = SMAC3v1Run.from_path("logs/SMAC3v1/mlp/run_1")
11+
self.run: AbstractRun = SMAC3v2Run.from_path("logs/SMAC3v2/mlp/run_1")
1212
self.hp_names = self.run.configspace.get_hyperparameter_names()
1313
self.evaluator = Evaluator(self.run)
1414

@@ -17,10 +17,28 @@ def test(self):
1717
objective = self.run.get_objective(0)
1818

1919
# Calculate
20-
self.evaluator.calculate(objective, budget)
21-
importance_dict = self.evaluator.get_importances(self.hp_names)
20+
self.evaluator.calculate(objective, budget, seed=0)
21+
importances = self.evaluator.get_importances(self.hp_names)
2222

23-
print(importance_dict)
23+
self.evaluator.calculate(objective, budget, seed=42)
24+
importances2 = self.evaluator.get_importances(self.hp_names)
25+
26+
# Different seed: Different results
27+
assert importances["batch_size"][1] != importances2["batch_size"][1]
28+
29+
def test_seed(self):
30+
budget = self.run.get_budget(0)
31+
objective = self.run.get_objective(0)
32+
33+
# Calculate
34+
self.evaluator.calculate(objective, budget, seed=0)
35+
importances = self.evaluator.get_importances(self.hp_names)
36+
37+
self.evaluator.calculate(objective, budget, seed=0)
38+
importances2 = self.evaluator.get_importances(self.hp_names)
39+
40+
# Same seed: Same results
41+
assert importances["batch_size"][1] == importances2["batch_size"][1]
2442

2543

2644
if __name__ == "__main__":

0 commit comments

Comments
 (0)