|
9 | 9 | - Group: Can group and manage a group of runs. |
10 | 10 | """ |
11 | 11 |
|
12 | | -from typing import Any, Dict, Iterator, List, Optional, Tuple |
| 12 | +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union |
13 | 13 |
|
14 | 14 | from copy import deepcopy |
15 | 15 |
|
16 | 16 | import numpy as np |
17 | 17 |
|
18 | 18 | from deepcave.runs import AbstractRun, NotMergeableError, check_equality |
| 19 | +from deepcave.runs.objective import Objective |
19 | 20 | from deepcave.utils.hash import string_to_hash |
20 | 21 |
|
21 | 22 |
|
@@ -286,74 +287,91 @@ def get_model(self, config_id: int) -> Optional[Any]: |
286 | 287 | run_id, config_id = self._original_config_mapping[config_id] |
287 | 288 | return self.runs[run_id].get_model(config_id) |
288 | 289 |
|
289 | | - # Types dont match superclass |
290 | | - def get_trajectory(self, *args, **kwargs): # type: ignore |
def get_trajectory(
    self,
    objective: Objective,
    budget: Optional[Union[int, float]] = None,
    seed: Optional[int] = None,
) -> Tuple[List[float], List[float], List[float], List[int], List[int]]:
    """
    Build the incumbent trajectory for the given objective, budget, and seed.

    The history is walked in order of trial end time. After each trial the
    incumbent over all trials seen so far is queried, and a point is recorded
    only when that incumbent cost improves on the best cost recorded so far.

    Parameters
    ----------
    objective : Objective
        Objective to calculate the trajectory for. Its ``optimize`` attribute
        decides the direction: ``"lower"`` means smaller costs are better,
        anything else means larger costs are better.
    budget : Optional[Union[int, float]], optional
        Budget to calculate the trajectory for. If None, the result of
        ``get_highest_budget()`` is used. By default None.
    seed : Optional[int], optional
        Seed forwarded unchanged to ``get_incumbent``. By default None.

    Returns
    -------
    Tuple[List[float], List[float], List[float], List[int], List[int]]
        times : List[float]
            End times of the trials at which the incumbent improved.
        costs_mean : List[float]
            Incumbent cost at each recorded point.
        costs_std : List[float]
            Standard deviation per recorded point. This implementation always
            records 0.0.
        ids : List[int]
            Positions (within the history) of the trials that triggered an
            improvement.
        config_ids : List[int]
            Config ids of those trials.
    """
    if budget is None:
        budget = self.get_highest_budget()

    # History positions ordered by end time. sorted() is stable, so trials
    # with equal end times keep their original history order.
    end_times = [trial.end_time for trial in self.history]
    ordered_ids = sorted(range(len(end_times)), key=end_times.__getitem__)

    # The objective can be minimized or maximized; start from the worst
    # possible value for the respective direction.
    minimize = objective.optimize == "lower"
    best_cost = np.inf if minimize else -np.inf

    times: List[float] = []
    costs_mean: List[float] = []
    costs_std: List[float] = []
    ids: List[int] = []
    config_ids: List[int] = []

    for seen, trial_id in enumerate(ordered_ids, start=1):
        # Incumbent over every trial that finished up to and including this one.
        try:
            _, cost = self.get_incumbent(
                objectives=objective,
                budget=budget,
                seed=seed,
                selected_ids=ordered_ids[:seen],
            )
        except RuntimeError:
            # No incumbent could be determined for this prefix; skip the trial.
            continue

        improved = cost < best_cost if minimize else cost > best_cost
        if not improved:
            continue

        best_cost = cost
        trial = self.history[trial_id]

        costs_mean.append(cost)
        costs_std.append(0.0)
        times.append(trial.end_time)
        ids.append(trial_id)
        config_ids.append(trial.config_id)

    return times, costs_mean, costs_std, ids, config_ids
0 commit comments