|
| 1 | +import hashlib |
| 2 | +import json |
| 3 | + |
1 | 4 | from collections.abc import Generator |
| 5 | +from typing import ClassVar |
2 | 6 |
|
3 | 7 | import openai |
4 | 8 |
|
|
13 | 17 |
|
14 | 18 |
|
15 | 19 | class OpenAILLM(BaseLLM): |
16 | | - """OpenAI LLM class.""" |
| 20 | + """OpenAI LLM class with singleton pattern.""" |
| 21 | + |
| 22 | + _instances: ClassVar[dict] = {} # Class variable to store instances |
| 23 | + |
| 24 | + def __new__(cls, config: OpenAILLMConfig) -> "OpenAILLM": |
| 25 | + config_hash = cls._get_config_hash(config) |
| 26 | + |
| 27 | + if config_hash not in cls._instances: |
| 28 | + logger.info(f"Creating new OpenAI LLM instance for config hash: {config_hash}") |
| 29 | + instance = super().__new__(cls) |
| 30 | + cls._instances[config_hash] = instance |
| 31 | + else: |
| 32 | + logger.info(f"Reusing existing OpenAI LLM instance for config hash: {config_hash}") |
| 33 | + |
| 34 | + return cls._instances[config_hash] |
17 | 35 |
|
18 | 36 | def __init__(self, config: OpenAILLMConfig): |
| 37 | + # Avoid duplicate initialization |
| 38 | + if hasattr(self, "_initialized"): |
| 39 | + return |
| 40 | + |
19 | 41 | self.config = config |
20 | 42 | self.client = openai.Client(api_key=config.api_key, base_url=config.api_base) |
| 43 | + self._initialized = True |
| 44 | + logger.info("OpenAI LLM instance initialized") |
| 45 | + |
| 46 | + @classmethod |
| 47 | + def _get_config_hash(cls, config: OpenAILLMConfig) -> str: |
| 48 | + """Generate hash value of configuration""" |
| 49 | + config_dict = config.model_dump() |
| 50 | + config_str = json.dumps(config_dict, sort_keys=True) |
| 51 | + return hashlib.md5(config_str.encode()).hexdigest() |
| 52 | + |
| 53 | + @classmethod |
| 54 | + def clear_cache(cls): |
| 55 | + """Clear all cached instances""" |
| 56 | + cls._instances.clear() |
| 57 | + logger.info("OpenAI LLM instance cache cleared") |
21 | 58 |
|
22 | 59 | def generate(self, messages: MessageList) -> str: |
23 | 60 | """Generate a response from OpenAI LLM.""" |
@@ -71,15 +108,50 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, Non |
71 | 108 |
|
72 | 109 |
|
73 | 110 | class AzureLLM(BaseLLM): |
74 | | - """Azure OpenAI LLM class.""" |
| 111 | + """Azure OpenAI LLM class with singleton pattern.""" |
| 112 | + |
| 113 | + _instances: ClassVar[dict] = {} # Class variable to store instances |
| 114 | + |
| 115 | + def __new__(cls, config: AzureLLMConfig): |
| 116 | + # Generate hash value of config as cache key |
| 117 | + config_hash = cls._get_config_hash(config) |
| 118 | + |
| 119 | + if config_hash not in cls._instances: |
| 120 | + logger.info(f"Creating new Azure LLM instance for config hash: {config_hash}") |
| 121 | + instance = super().__new__(cls) |
| 122 | + cls._instances[config_hash] = instance |
| 123 | + else: |
| 124 | + logger.info(f"Reusing existing Azure LLM instance for config hash: {config_hash}") |
| 125 | + |
| 126 | + return cls._instances[config_hash] |
75 | 127 |
|
76 | 128 | def __init__(self, config: AzureLLMConfig): |
| 129 | + # Avoid duplicate initialization |
| 130 | + if hasattr(self, "_initialized"): |
| 131 | + return |
| 132 | + |
77 | 133 | self.config = config |
78 | 134 | self.client = openai.AzureOpenAI( |
79 | 135 | azure_endpoint=config.base_url, |
80 | 136 | api_version=config.api_version, |
81 | 137 | api_key=config.api_key, |
82 | 138 | ) |
| 139 | + self._initialized = True |
| 140 | + logger.info("Azure LLM instance initialized") |
| 141 | + |
| 142 | + @classmethod |
| 143 | + def _get_config_hash(cls, config: AzureLLMConfig) -> str: |
| 144 | + """Generate hash value of configuration""" |
| 145 | + # Convert config to dict and sort to ensure consistency |
| 146 | + config_dict = config.model_dump() |
| 147 | + config_str = json.dumps(config_dict, sort_keys=True) |
| 148 | + return hashlib.md5(config_str.encode()).hexdigest() |
| 149 | + |
| 150 | + @classmethod |
| 151 | + def clear_cache(cls): |
| 152 | + """Clear all cached instances""" |
| 153 | + cls._instances.clear() |
| 154 | + logger.info("Azure LLM instance cache cleared") |
83 | 155 |
|
84 | 156 | def generate(self, messages: MessageList) -> str: |
85 | 157 | """Generate a response from Azure OpenAI LLM.""" |
|
0 commit comments