Skip to content

Commit 139cbc7

Browse files
authored
Merge pull request #254 from mmalmrud/feat-custom-llm-client
Add LLM config option to allow the use of custom LLM clients
2 parents 7b673a7 + 4ad23fc commit 139cbc7

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

openevolve/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
from dataclasses import dataclass, field
77
from pathlib import Path
8-
from typing import Any, Dict, List, Optional, Union
8+
from typing import Any, Callable, Dict, List, Optional, Union
99

1010
import yaml
1111

@@ -19,6 +19,9 @@ class LLMModelConfig:
1919
api_key: Optional[str] = None
2020
name: str = None
2121

22+
# Custom LLM client
23+
init_client: Optional[Callable] = None
24+
2225
# Weight for model in ensemble
2326
weight: float = 1.0
2427

openevolve/llm/ensemble.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, models_cfg: List[LLMModelConfig]):
2121
self.models_cfg = models_cfg
2222

2323
# Initialize models from the configuration
24-
self.models = [OpenAILLM(model_cfg) for model_cfg in models_cfg]
24+
self.models = [model_cfg.init_client(model_cfg) if model_cfg.init_client else OpenAILLM(model_cfg) for model_cfg in models_cfg]
2525

2626
# Extract and normalize model weights
2727
self.weights = [model.weight for model in models_cfg]

tests/test_llm_ensemble.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
Tests for LLMEnsemble in openevolve.llm.ensemble
33
"""
44

5+
from typing import Any, Dict, List
56
import unittest
67
from openevolve.llm.ensemble import LLMEnsemble
78
from openevolve.config import LLMModelConfig
8-
9+
from openevolve.llm.base import LLMInterface
910

1011
class TestLLMEnsemble(unittest.TestCase):
1112
def test_weighted_sampling(self):
@@ -34,5 +35,32 @@ def test_weighted_sampling(self):
3435
self.assertEqual(len(sampled_models), len(models))
3536

3637

38+
39+
class TestEnsembleInit(unittest.TestCase):
40+
class MyCustomLLM(LLMInterface):
41+
def __init__(self, model, some_field):
42+
self.model = model
43+
self.some_field = some_field
44+
45+
async def generate(self, prompt: str, **kwargs) -> str:
46+
return "custom-generate"
47+
48+
async def generate_with_context(self, system_message: str, messages: List[Dict[str, str]], **kwargs) -> str:
49+
return "custom-generate-with-context"
50+
51+
def init_custom_llm(self, model_cfg):
52+
return self.MyCustomLLM(model=model_cfg.name, some_field="value")
53+
54+
def test_ensemble_initialization(self):
55+
models = [
56+
LLMModelConfig(name="a"),
57+
LLMModelConfig(name="b", init_client=self.init_custom_llm),
58+
]
59+
ensemble = LLMEnsemble(models)
60+
self.assertEqual(len(ensemble.models), len(models))
61+
self.assertEqual(ensemble.models[0].model, "a")
62+
self.assertEqual(ensemble.models[1].model, "b")
63+
self.assertEqual(ensemble.models[1].some_field, "value")
64+
3765
if __name__ == "__main__":
3866
unittest.main()

0 commit comments

Comments
 (0)