1111
1212
@dataclass
class LLMModelConfig:
    """Configuration for a single LLM model.

    Every field except ``weight`` defaults to ``None`` so that unset values
    can later be filled in from the shared, top-level LLM configuration
    (see the enclosing config's ``update_model_params``).
    """

    # API configuration.
    # Fixed annotations: fields defaulting to None are Optional, not bare
    # str/float/int as originally written.
    api_base: Optional[str] = None
    api_key: Optional[str] = None
    name: Optional[str] = None

    # Relative weight of this model when sampling from an ensemble
    weight: float = 1.0

    # Generation parameters
    system_message: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None

    # Request parameters
    timeout: Optional[int] = None
    retries: Optional[int] = None
    retry_delay: Optional[int] = None
36+
37+ @dataclass
38+ class LLMConfig (LLMModelConfig ):
39+ """Configuration for LLM models"""
2440
2541 # API configuration
2642 api_base : str = "https://api.openai.com/v1"
27- api_key : Optional [ str ] = None
43+ name : str = "gpt-4o"
2844
2945 # Generation parameters
46+ system_message : Optional [str ] = (
47+ "You are an expert coder helping to improve programs through evolution."
48+ )
3049 temperature : float = 0.7
3150 top_p : float = 0.95
3251 max_tokens : int = 4096
@@ -36,13 +55,69 @@ class LLMConfig:
3655 retries : int = 3
3756 retry_delay : int = 5
3857
58+ # n-model configuration for evolution LLM ensemble
59+ models : List [LLMModelConfig ] = field (default_factory = lambda : [LLMModelConfig ()])
60+
61+ # n-model configuration for evaluator LLM ensemble
62+ evaluator_models : List [LLMModelConfig ] = field (default_factory = lambda : [])
63+
 64+     # Backwards compatibility with primary_model(_weight) options
65+ primary_model : str = "gemini-2.0-flash-lite"
66+ primary_model_weight : float = 0.8
67+ secondary_model : str = "gemini-2.0-flash"
68+ secondary_model_weight : float = 0.2
69+
    def __post_init__(self):
        """Post-initialization to set up model configurations.

        Performs three fix-up passes after dataclass init:
        1. Maps the legacy ``primary_model(_weight)`` / ``secondary_model(_weight)``
           options onto the first two entries of ``models``.
        2. Defaults ``evaluator_models`` to the evolution ``models`` when none
           were configured.
        3. Pushes shared top-level settings (API, sampling, retry) down into
           every model config via ``update_model_params``.
        """
        # Handle backward compatibility for primary_model(_weight) and secondary_model(_weight).
        if (self.primary_model or self.primary_model_weight) and len(self.models) < 1:
            # Ensure we have a primary model
            self.models.append(LLMModelConfig())
        if self.primary_model:
            self.models[0].name = self.primary_model
        if self.primary_model_weight:
            self.models[0].weight = self.primary_model_weight

        if (self.secondary_model or self.secondary_model_weight) and len(self.models) < 2:
            # Ensure we have a second model
            self.models.append(LLMModelConfig())
        if self.secondary_model:
            self.models[1].name = self.secondary_model
        if self.secondary_model_weight:
            self.models[1].weight = self.secondary_model_weight

        # If no evaluator models are defined, use the same models as for evolution.
        # NOTE(review): .copy() is shallow — evaluator_models shares the same
        # LLMModelConfig objects as models, so mutating one list's entries
        # mutates the other's. Confirm this aliasing is intended.
        if not self.evaluator_models or len(self.evaluator_models) < 1:
            self.evaluator_models = self.models.copy()

        # Update models with shared configuration values. update_model_params
        # only fills fields the per-model config left as None.
        shared_config = {
            "api_base": self.api_base,
            "api_key": self.api_key,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
            "timeout": self.timeout,
            "retries": self.retries,
            "retry_delay": self.retry_delay,
        }
        self.update_model_params(shared_config)
105+
106+ def update_model_params (self , args : Dict [str , Any ], overwrite : bool = False ) -> None :
107+ """Update model parameters for all models"""
108+ for model in self .models + self .evaluator_models :
109+ for key , value in args .items ():
110+ if overwrite or getattr (model , key , None ) is None :
111+ setattr (model , key , value )
112+
39113
40114@dataclass
41115class PromptConfig :
42116 """Configuration for prompt generation"""
43117
44118 template_dir : Optional [str ] = None
45119 system_message : str = "You are an expert coder helping to improve programs through evolution."
120+ evaluator_system_message : str = """You are an expert code reviewer."""
46121
47122 # Number of examples to include in the prompt
48123 num_top_programs : int = 3
@@ -155,7 +230,14 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "Config":
155230
156231 # Update nested configs
157232 if "llm" in config_dict :
158- config .llm = LLMConfig (** config_dict ["llm" ])
233+ llm_dict = config_dict ["llm" ]
234+ if "models" in llm_dict :
235+ llm_dict ["models" ] = [LLMModelConfig (** m ) for m in llm_dict ["models" ]]
236+ if "evaluator_models" in llm_dict :
237+ llm_dict ["evaluator_models" ] = [
238+ LLMModelConfig (** m ) for m in llm_dict ["evaluator_models" ]
239+ ]
240+ config .llm = LLMConfig (** llm_dict )
159241 if "prompt" in config_dict :
160242 config .prompt = PromptConfig (** config_dict ["prompt" ])
161243 if "database" in config_dict :
@@ -176,10 +258,8 @@ def to_dict(self) -> Dict[str, Any]:
176258 "random_seed" : self .random_seed ,
177259 # Component configurations
178260 "llm" : {
179- "primary_model" : self .llm .primary_model ,
180- "primary_model_weight" : self .llm .primary_model_weight ,
181- "secondary_model" : self .llm .secondary_model ,
182- "secondary_model_weight" : self .llm .secondary_model_weight ,
261+ "models" : self .llm .models ,
262+ "evaluator_models" : self .llm .evaluator_models ,
183263 "api_base" : self .llm .api_base ,
184264 "temperature" : self .llm .temperature ,
185265 "top_p" : self .llm .top_p ,
@@ -191,6 +271,7 @@ def to_dict(self) -> Dict[str, Any]:
191271 "prompt" : {
192272 "template_dir" : self .prompt .template_dir ,
193273 "system_message" : self .prompt .system_message ,
274+ "evaluator_system_message" : self .prompt .evaluator_system_message ,
194275 "num_top_programs" : self .prompt .num_top_programs ,
195276 "num_diverse_programs" : self .prompt .num_diverse_programs ,
196277 "use_template_stochasticity" : self .prompt .use_template_stochasticity ,
@@ -245,16 +326,17 @@ def to_yaml(self, path: Union[str, Path]) -> None:
def load_config(config_path: Optional[Union[str, Path]] = None) -> Config:
    """Load configuration from a YAML file or use defaults"""
    # Start from the YAML file when one exists, otherwise from defaults.
    if config_path and os.path.exists(config_path):
        config = Config.from_yaml(config_path)
    else:
        config = Config()

    # Use environment variables if available
    env_api_key = os.environ.get("OPENAI_API_KEY")
    env_api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
    config.llm.update_model_params({"api_key": env_api_key, "api_base": env_api_base})

    # Make the system message available to the individual models, in case it
    # is not provided from the prompt sampler.
    config.llm.update_model_params({"system_message": config.prompt.system_message})

    return config
0 commit comments