|
8 | 8 | from pathlib import Path |
9 | 9 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union |
10 | 10 |
|
| 11 | +import dacite |
11 | 12 | import yaml |
12 | 13 |
|
13 | 14 | if TYPE_CHECKING: |
@@ -418,46 +419,20 @@ def from_yaml(cls, path: Union[str, Path]) -> "Config": |
418 | 419 |
|
419 | 420 | @classmethod |
420 | 421 | def from_dict(cls, config_dict: Dict[str, Any]) -> "Config": |
421 | | - """Create configuration from a dictionary""" |
422 | | - # Handle nested configurations |
423 | | - config = Config() |
424 | | - |
425 | | - # Update top-level fields |
426 | | - for key, value in config_dict.items(): |
427 | | - if key not in ["llm", "prompt", "database", "evaluator", "evolution_trace"] and hasattr( |
428 | | - config, key |
429 | | - ): |
430 | | - setattr(config, key, value) |
431 | | - |
432 | | - # Update nested configs |
433 | | - if "llm" in config_dict: |
434 | | - llm_dict = config_dict["llm"] |
435 | | - if "models" in llm_dict: |
436 | | - llm_dict["models"] = [LLMModelConfig(**m) for m in llm_dict["models"]] |
437 | | - if "evaluator_models" in llm_dict: |
438 | | - llm_dict["evaluator_models"] = [ |
439 | | - LLMModelConfig(**m) for m in llm_dict["evaluator_models"] |
440 | | - ] |
441 | | - config.llm = LLMConfig(**llm_dict) |
442 | | - if "prompt" in config_dict: |
443 | | - config.prompt = PromptConfig(**config_dict["prompt"]) |
444 | | - if "database" in config_dict: |
445 | | - config.database = DatabaseConfig(**config_dict["database"]) |
446 | | - |
447 | | - # Ensure database inherits the random seed if not explicitly set |
448 | | - if config.database.random_seed is None and config.random_seed is not None: |
449 | | - config.database.random_seed = config.random_seed |
450 | | - if "evaluator" in config_dict: |
451 | | - config.evaluator = EvaluatorConfig(**config_dict["evaluator"]) |
452 | | - if "evolution_trace" in config_dict: |
453 | | - config.evolution_trace = EvolutionTraceConfig(**config_dict["evolution_trace"]) |
454 | 422 | if "diff_pattern" in config_dict: |
455 | | - # Validate it's a valid regex |
456 | 423 | try: |
457 | 424 | re.compile(config_dict["diff_pattern"]) |
458 | 425 | except re.error as e: |
459 | 426 | raise ValueError(f"Invalid regex pattern in diff_pattern: {e}") |
460 | | - config.diff_pattern = config_dict["diff_pattern"] |
| 427 | + |
| 428 | + config: Config = dacite.from_dict( |
| 429 | + data_class=cls, |
| 430 | + data=config_dict, |
| 431 | + config=dacite.Config(cast=[List, Union], forward_references={"LLMInterface": Any}), |
| 432 | + ) |
| 433 | + |
| 434 | + if config.database.random_seed is None and config.random_seed is not None: |
| 435 | + config.database.random_seed = config.random_seed |
461 | 436 |
|
462 | 437 | return config |
463 | 438 |
|
|
0 commit comments