Skip to content

Commit 0b1c3b5

Browse files
committed
add new implementation for data_collector module
1 parent 67879ef commit 0b1c3b5

File tree

4 files changed

+187
-0
lines changed

4 files changed

+187
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .search_data_collector import SearchDataCollector
2+
from .search_data_lookup import SearchDataLookup
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
from itertools import product
3+
from typing import Dict, List, Tuple, Any
4+
import time
5+
import os
6+
7+
8+
class GridGenerator:
9+
"""
10+
Generates parameter grids from search space definitions.
11+
12+
This class takes a search space dictionary and creates all possible
13+
combinations of parameters for grid search evaluation.
14+
"""
15+
16+
def __init__(self, search_space: Dict[str, List[Any]]):
17+
self.search_space = search_space
18+
self.param_names = list(search_space.keys())
19+
self.param_values = [search_space[name] for name in self.param_names]
20+
21+
def generate_grid(self) -> Tuple[np.ndarray, List[str]]:
22+
"""
23+
Generate complete parameter grid.
24+
25+
Returns:
26+
grid: numpy array of shape (n_combinations, n_params)
27+
param_names: list of parameter names
28+
"""
29+
# Create all combinations
30+
combinations = list(product(*self.param_values))
31+
grid = np.array(combinations)
32+
33+
return grid, self.param_names
34+
35+
def get_grid_size(self) -> int:
36+
"""Calculate total number of grid points."""
37+
return np.prod([len(values) for values in self.param_values])
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Dict, Any, List
2+
import numpy as np
3+
import time
4+
5+
from .grid_generator import GridGenerator
6+
7+
8+
class SearchDataCollector:
9+
"""
10+
Collects search data by evaluating objective functions on parameter grids.
11+
12+
This class handles the expensive computation of evaluating ML models
13+
across a parameter grid and saves the results for future use.
14+
"""
15+
16+
def __init__(self, objective_function, search_space: Dict[str, List[Any]]):
17+
self.objective_function = objective_function
18+
self.search_space = search_space
19+
self.grid_generator = GridGenerator(search_space)
20+
21+
def collect(self, verbose: bool = True) -> Dict[str, np.ndarray]:
22+
"""
23+
Evaluate objective function on entire grid.
24+
25+
Returns dictionary containing:
26+
- 'parameters': parameter grid
27+
- 'scores': objective function values
28+
- 'times': evaluation times in seconds
29+
- 'param_names': parameter names
30+
"""
31+
grid, param_names = self.grid_generator.generate_grid()
32+
n_points = len(grid)
33+
34+
scores = np.zeros(n_points)
35+
times = np.zeros(n_points)
36+
37+
for i, params in enumerate(grid):
38+
if verbose and i % 100 == 0:
39+
print(f"Evaluating point {i+1}/{n_points}")
40+
41+
# Convert to dictionary for objective function
42+
param_dict = {name: value for name, value in zip(param_names, params)}
43+
44+
# Time the evaluation
45+
start_time = time.perf_counter()
46+
scores[i] = self.objective_function(param_dict)
47+
times[i] = time.perf_counter() - start_time
48+
49+
return {
50+
"parameters": grid,
51+
"scores": scores,
52+
"times": times,
53+
"param_names": np.array(param_names, dtype="U"), # Unicode string array
54+
}
55+
56+
def save(self, filepath: str, verbose: bool = True):
57+
"""Collect data and save to file."""
58+
data = self.collect(verbose=verbose)
59+
60+
# Add metadata
61+
data["search_space_keys"] = np.array(list(self.search_space.keys()), dtype="U")
62+
data["search_space_sizes"] = np.array(
63+
[len(v) for v in self.search_space.values()]
64+
)
65+
66+
# Save as compressed numpy archive
67+
np.savez_compressed(filepath, **data)
68+
69+
if verbose:
70+
print(f"Saved search data to {filepath}")
71+
print(f"Total points: {len(data['scores'])}")
72+
print(f"Total evaluation time: {np.sum(data['times']):.2f} seconds")
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
from typing import Dict, Any, List
3+
4+
5+
class SearchDataLookup:
6+
"""
7+
Provides fast lookup of pre-computed objective function values.
8+
9+
This class loads search data from disk and provides O(1) lookup
10+
for parameter combinations that were evaluated during grid search.
11+
"""
12+
13+
def __init__(self, filepath: str):
14+
self.filepath = filepath
15+
self._load_data()
16+
self._build_lookup_table()
17+
18+
def _load_data(self):
19+
"""Load search data from file."""
20+
if not os.path.exists(self.filepath):
21+
raise FileNotFoundError(f"Search data file not found: {self.filepath}")
22+
23+
data = np.load(self.filepath)
24+
self.parameters = data["parameters"]
25+
self.scores = data["scores"]
26+
self.times = data["times"]
27+
self.param_names = data["param_names"]
28+
29+
# Reconstruct search space structure
30+
self.search_space_keys = data["search_space_keys"]
31+
self.search_space_sizes = data["search_space_sizes"]
32+
33+
def _build_lookup_table(self):
34+
"""Build hash table for fast parameter lookup."""
35+
self.lookup_table = {}
36+
37+
for i, params in enumerate(self.parameters):
38+
# Create hashable key from parameters
39+
key = tuple(params)
40+
self.lookup_table[key] = {"score": self.scores[i], "time": self.times[i]}
41+
42+
def evaluate(self, param_dict: Dict[str, Any]) -> float:
43+
"""
44+
Look up objective function value for given parameters.
45+
46+
Args:
47+
param_dict: Dictionary of parameter names to values
48+
49+
Returns:
50+
Objective function value
51+
52+
Raises:
53+
KeyError: If parameter combination not found in search data
54+
"""
55+
# Convert dict to tuple in correct order
56+
param_values = [param_dict[name] for name in self.param_names]
57+
key = tuple(param_values)
58+
59+
if key not in self.lookup_table:
60+
raise KeyError(
61+
f"Parameter combination not found in search data: {param_dict}"
62+
)
63+
64+
return self.lookup_table[key]["score"]
65+
66+
def get_evaluation_time(self, param_dict: Dict[str, Any]) -> float:
67+
"""Get the original evaluation time for given parameters."""
68+
param_values = [param_dict[name] for name in self.param_names]
69+
key = tuple(param_values)
70+
71+
if key not in self.lookup_table:
72+
raise KeyError(
73+
f"Parameter combination not found in search data: {param_dict}"
74+
)
75+
76+
return self.lookup_table[key]["time"]

0 commit comments

Comments
 (0)