Skip to content

Commit ceba007

Browse files
committed
improve error handling
1 parent 627047a commit ceba007

File tree

1 file changed

+79
-52
lines changed

1 file changed

+79
-52
lines changed

src/surfaces/search_data_collection/data_collector.py

Lines changed: 79 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -130,56 +130,75 @@ def collect_search_data(self,
130130
if verbose:
131131
print(f"Found {existing_count} existing evaluations")
132132

133-
# Create wrapper function that handles dataset resolution
134-
original_objective = test_function.pure_objective_function
135-
wrapper_objective = self._create_dataset_wrapper(original_objective)
136-
test_function.pure_objective_function = wrapper_objective
137-
138133
# Collect evaluations
139134
start_time = time.time()
140135
collected_count = 0
141136
batch_evaluations = []
137+
evaluation_errors = 0
142138

143-
try:
144-
for i, parameters in enumerate(self.grid_generator.generate_grid_iterator(processed_search_space)):
145-
# Check if evaluation already exists
146-
existing = self.data_manager.lookup_evaluation(function_name, parameters)
147-
if existing is not None:
148-
continue
149-
150-
# Evaluate function with timing
151-
eval_start = time.time()
139+
for i, parameters in enumerate(self.grid_generator.generate_grid_iterator(processed_search_space)):
140+
# Check if evaluation already exists
141+
existing = self.data_manager.lookup_evaluation(function_name, parameters)
142+
if existing is not None:
143+
continue
144+
145+
# Resolve dataset string identifiers to functions before evaluation
146+
resolved_parameters = self._resolve_dataset_parameters(parameters)
147+
148+
# Evaluate function with timing
149+
eval_start = time.time()
150+
try:
151+
score = test_function.pure_objective_function(resolved_parameters)
152+
except Exception as e:
153+
evaluation_errors += 1
154+
error_msg = f"Evaluation failed for parameters {parameters}: {type(e).__name__}: {e}"
155+
if verbose:
156+
print(f"WARNING: {error_msg}")
157+
# Log detailed error information for debugging
158+
if verbose and hasattr(e, '__traceback__'):
159+
import traceback
160+
print(f"Full traceback: {traceback.format_exc()}")
161+
continue
162+
eval_time = time.time() - eval_start
163+
164+
# Add to batch (store the processed parameters for database consistency)
165+
batch_evaluations.append((parameters, score, eval_time))
166+
collected_count += 1
167+
168+
# Store batch when it reaches batch_size
169+
if len(batch_evaluations) >= batch_size:
152170
try:
153-
score = test_function.pure_objective_function(parameters)
154-
except Exception as e:
155-
if verbose:
156-
print(f"Error evaluating parameters {parameters}: {e}")
157-
continue
158-
eval_time = time.time() - eval_start
159-
160-
# Add to batch
161-
batch_evaluations.append((parameters, score, eval_time))
162-
collected_count += 1
163-
164-
# Store batch when it reaches batch_size
165-
if len(batch_evaluations) >= batch_size:
166171
self.data_manager.store_batch(function_name, parameter_names, batch_evaluations)
167172
batch_evaluations = []
168-
173+
except Exception as e:
174+
error_msg = f"Database storage failed: {type(e).__name__}: {e}"
169175
if verbose:
170-
progress = (i + 1) / total_combinations * 100
171-
print(f"Progress: {progress:.1f}% ({i + 1}/{total_combinations})")
172-
173-
# Store remaining evaluations
174-
if batch_evaluations:
175-
self.data_manager.store_batch(function_name, parameter_names, batch_evaluations)
176+
print(f"ERROR: {error_msg}")
177+
# This is a critical error - we should not continue
178+
raise RuntimeError(f"Critical database error during data collection: {error_msg}") from e
179+
180+
if verbose:
181+
progress = (i + 1) / total_combinations * 100
182+
print(f"Progress: {progress:.1f}% ({i + 1}/{total_combinations})")
176183

177-
except Exception as e:
178-
# Re-raise the exception after cleanup
179-
raise
180-
finally:
181-
# Always restore original objective function
182-
test_function.pure_objective_function = original_objective
184+
# Store remaining evaluations
185+
if batch_evaluations:
186+
try:
187+
self.data_manager.store_batch(function_name, parameter_names, batch_evaluations)
188+
except Exception as e:
189+
error_msg = f"Database storage failed for final batch: {type(e).__name__}: {e}"
190+
if verbose:
191+
print(f"ERROR: {error_msg}")
192+
raise RuntimeError(f"Critical database error during final storage: {error_msg}") from e
193+
194+
# Report evaluation errors if any occurred
195+
if evaluation_errors > 0:
196+
error_rate = (evaluation_errors / total_combinations) * 100
197+
warning_msg = f"WARNING: {evaluation_errors}/{total_combinations} evaluations failed ({error_rate:.1f}% error rate)"
198+
if verbose:
199+
print(warning_msg)
200+
if error_rate > 50: # More than 50% failures is concerning
201+
raise RuntimeError(f"High error rate in evaluations: {warning_msg}. Check your search space and function implementation.")
183202

184203
total_time = time.time() - start_time
185204

@@ -321,22 +340,30 @@ def _process_search_space(self, search_space: Dict[str, List[Any]]) -> Dict[str,
321340

322341
return processed
323342

324-
def _create_dataset_wrapper(self, original_objective: Callable) -> Callable:
325-
"""
326-
Create a wrapper function that resolves dataset string identifiers back to functions.
343+
def _resolve_dataset_parameters(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
327344
"""
328-
def wrapper(params):
329-
processed_params = params.copy()
345+
Resolve dataset string identifiers back to callable functions.
346+
347+
Args:
348+
parameters: Parameter dictionary that may contain dataset string identifiers
330349
331-
# Resolve dataset string back to function if present
332-
if "dataset" in processed_params:
333-
dataset_value = processed_params["dataset"]
334-
if isinstance(dataset_value, str):
335-
processed_params["dataset"] = self.dataset_registry.get(dataset_value)
350+
Returns:
351+
Parameter dictionary with dataset strings resolved to functions
336352
337-
return original_objective(processed_params)
353+
Raises:
354+
ValueError: If a dataset identifier cannot be resolved
355+
"""
356+
resolved_params = parameters.copy()
357+
358+
if "dataset" in resolved_params:
359+
dataset_value = resolved_params["dataset"]
360+
if isinstance(dataset_value, str):
361+
try:
362+
resolved_params["dataset"] = self.dataset_registry.get(dataset_value)
363+
except ValueError as e:
364+
raise ValueError(f"Failed to resolve dataset '{dataset_value}': {e}") from e
338365

339-
return wrapper
366+
return resolved_params
340367

341368
def register_dataset(self, name: str, dataset_func: Callable):
342369
"""Register a custom dataset function."""

0 commit comments

Comments
 (0)