diff --git a/openevolve/config.py b/openevolve/config.py
index dbcb9cef..64553b55 100644
--- a/openevolve/config.py
+++ b/openevolve/config.py
@@ -5,7 +5,7 @@
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import yaml
 
@@ -19,6 +19,9 @@ class LLMModelConfig:
     api_key: Optional[str] = None
     name: str = None
 
+    # Custom LLM client
+    init_client: Optional[Callable] = None
+
     # Weight for model in ensemble
     weight: float = 1.0
 
diff --git a/openevolve/llm/ensemble.py b/openevolve/llm/ensemble.py
index f5db7be2..749b46aa 100644
--- a/openevolve/llm/ensemble.py
+++ b/openevolve/llm/ensemble.py
@@ -21,7 +21,7 @@ def __init__(self, models_cfg: List[LLMModelConfig]):
         self.models_cfg = models_cfg
 
         # Initialize models from the configuration
-        self.models = [OpenAILLM(model_cfg) for model_cfg in models_cfg]
+        self.models = [model_cfg.init_client(model_cfg) if model_cfg.init_client else OpenAILLM(model_cfg) for model_cfg in models_cfg]
 
         # Extract and normalize model weights
         self.weights = [model.weight for model in models_cfg]
diff --git a/tests/test_llm_ensemble.py b/tests/test_llm_ensemble.py
index 72e9c134..f4a2f3ff 100644
--- a/tests/test_llm_ensemble.py
+++ b/tests/test_llm_ensemble.py
@@ -2,10 +2,11 @@
 Tests for LLMEnsemble in openevolve.llm.ensemble
 """
 
+from typing import Any, Dict, List
 import unittest
 
 from openevolve.llm.ensemble import LLMEnsemble
 from openevolve.config import LLMModelConfig
-
+from openevolve.llm.base import LLMInterface
 class TestLLMEnsemble(unittest.TestCase):
     def test_weighted_sampling(self):
@@ -34,5 +35,32 @@
         self.assertEqual(len(sampled_models), len(models))
 
 
+
+class TestEnsembleInit(unittest.TestCase):
+    class MyCustomLLM(LLMInterface):
+        def __init__(self, model, some_field):
+            self.model = model
+            self.some_field = some_field
+
+        async def generate(self, prompt: str, **kwargs) -> str:
+            return "custom-generate"
+
+        async def generate_with_context(self, system_message: str, messages: List[Dict[str, str]], **kwargs) -> str:
+            return "custom-generate-with-context"
+
+    def init_custom_llm(self, model_cfg):
+        return self.MyCustomLLM(model=model_cfg.name, some_field="value")
+
+    def test_ensemble_initialization(self):
+        models = [
+            LLMModelConfig(name="a"),
+            LLMModelConfig(name="b", init_client=self.init_custom_llm),
+        ]
+        ensemble = LLMEnsemble(models)
+        self.assertEqual(len(ensemble.models), len(models))
+        self.assertEqual(ensemble.models[0].model, "a")
+        self.assertEqual(ensemble.models[1].model, "b")
+        self.assertEqual(ensemble.models[1].some_field, "value")
+
 if __name__ == "__main__":
     unittest.main()