Skip to content

Commit aff8ba7

Browse files
committed
simplify Config.from_dict with dacite
1 parent e71e1e1 commit aff8ba7

File tree

2 files changed

+11
-35
lines changed

2 files changed

+11
-35
lines changed

openevolve/config.py

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
1010

11+
import dacite
1112
import yaml
1213

1314
if TYPE_CHECKING:
@@ -418,46 +419,20 @@ def from_yaml(cls, path: Union[str, Path]) -> "Config":
418419

419420
@classmethod
420421
def from_dict(cls, config_dict: Dict[str, Any]) -> "Config":
421-
"""Create configuration from a dictionary"""
422-
# Handle nested configurations
423-
config = Config()
424-
425-
# Update top-level fields
426-
for key, value in config_dict.items():
427-
if key not in ["llm", "prompt", "database", "evaluator", "evolution_trace"] and hasattr(
428-
config, key
429-
):
430-
setattr(config, key, value)
431-
432-
# Update nested configs
433-
if "llm" in config_dict:
434-
llm_dict = config_dict["llm"]
435-
if "models" in llm_dict:
436-
llm_dict["models"] = [LLMModelConfig(**m) for m in llm_dict["models"]]
437-
if "evaluator_models" in llm_dict:
438-
llm_dict["evaluator_models"] = [
439-
LLMModelConfig(**m) for m in llm_dict["evaluator_models"]
440-
]
441-
config.llm = LLMConfig(**llm_dict)
442-
if "prompt" in config_dict:
443-
config.prompt = PromptConfig(**config_dict["prompt"])
444-
if "database" in config_dict:
445-
config.database = DatabaseConfig(**config_dict["database"])
446-
447-
# Ensure database inherits the random seed if not explicitly set
448-
if config.database.random_seed is None and config.random_seed is not None:
449-
config.database.random_seed = config.random_seed
450-
if "evaluator" in config_dict:
451-
config.evaluator = EvaluatorConfig(**config_dict["evaluator"])
452-
if "evolution_trace" in config_dict:
453-
config.evolution_trace = EvolutionTraceConfig(**config_dict["evolution_trace"])
454422
if "diff_pattern" in config_dict:
455-
# Validate it's a valid regex
456423
try:
457424
re.compile(config_dict["diff_pattern"])
458425
except re.error as e:
459426
raise ValueError(f"Invalid regex pattern in diff_pattern: {e}")
460-
config.diff_pattern = config_dict["diff_pattern"]
427+
428+
config: Config = dacite.from_dict(
429+
data_class=cls,
430+
data=config_dict,
431+
config=dacite.Config(cast=[List, Union], forward_references={"LLMInterface": Any}),
432+
)
433+
434+
if config.database.random_seed is None and config.random_seed is not None:
435+
config.database.random_seed = config.random_seed
461436

462437
return config
463438

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies = [
1616
"numpy>=1.22.0",
1717
"tqdm>=4.64.0",
1818
"flask",
19+
"dacite>=1.9.2",
1920
]
2021

2122
[project.optional-dependencies]

0 commit comments

Comments
 (0)