 
 import os
 from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
 from typing import TYPE_CHECKING
 
 from loguru import logger
@@ -34,6 +35,7 @@ def __init__(
         embedding_base_url: str | None = None,
         service_config: ServiceConfig | None = None,
         parser: type[PydanticConfigParser] | None = None,
+        working_dir: str | None = None,
         config_path: str | None = None,
         enable_logo: bool = True,
         log_to_console: bool = True,
@@ -74,13 +76,23 @@ def __init__(
         self._update_section_config(kwargs, "memory_stores", **default_memory_store_config)
         if default_file_watcher_config:
             self._update_section_config(kwargs, "file_watchers", **default_file_watcher_config)
-        kwargs["enable_logo"] = enable_logo
-        kwargs["log_to_console"] = log_to_console
+
+        kwargs.update(
+            {
+                "enable_logo": enable_logo,
+                "log_to_console": log_to_console,
+                "working_dir": working_dir,
+            },
+        )
         logger.info(f"update with args: {input_args} kwargs: {kwargs}")
         service_config = parser.parse_args(*input_args, **kwargs)
 
         self.service_config: ServiceConfig = service_config
         init_logger(log_to_console=self.service_config.log_to_console)
+        logger.info(f"ReMe Config: {service_config.model_dump_json()}")
+
+        if self.service_config.working_dir:
+            Path(self.service_config.working_dir).mkdir(parents=True, exist_ok=True)
 
         if self.service_config.enable_logo:
             print_logo(service_config=self.service_config)
@@ -147,41 +159,58 @@ def _build_flows(self):
     async def start(self):
         """Start the service context by initializing all configured components."""
         for name, config in self.service_config.llms.items():
-            self.llms[name] = R.llms[config.backend](model_name=config.model_name, **config.model_extra)
+            if config.backend not in R.llms:
+                logger.warning(f"LLM backend {config.backend} is not supported.")
+            else:
+                self.llms[name] = R.llms[config.backend](model_name=config.model_name, **config.model_extra)
 
         for name, config in self.service_config.embedding_models.items():
-            self.embedding_models[name] = R.embedding_models[config.backend](
-                model_name=config.model_name,
-                **config.model_extra,
-            )
+            if config.backend not in R.embedding_models:
+                logger.warning(f"Embedding model backend {config.backend} is not supported.")
+            else:
+                self.embedding_models[name] = R.embedding_models[config.backend](
+                    model_name=config.model_name,
+                    **config.model_extra,
+                )
 
         for name, config in self.service_config.token_counters.items():
-            self.token_counters[name] = R.token_counters[config.backend](
-                model_name=config.model_name,
-                **config.model_extra,
-            )
+            if config.backend not in R.token_counters:
+                logger.warning(f"Token counter backend {config.backend} is not supported.")
+            else:
+                self.token_counters[name] = R.token_counters[config.backend](
+                    model_name=config.model_name,
+                    **config.model_extra,
+                )
 
         for name, config in self.service_config.vector_stores.items():
-            # Extract config dict and replace special fields with actual instances
-            config_dict = config.model_dump(exclude={"backend", "embedding_model"})
-            config_dict["embedding_model"] = self.embedding_models[config.embedding_model]
-            config_dict["thread_pool"] = self.thread_pool
-            self.vector_stores[name] = R.vector_stores[config.backend](**config_dict)
-            await self.vector_stores[name].create_collection(config.collection_name)
+            if config.backend not in R.vector_stores:
+                logger.warning(f"Vector store backend {config.backend} is not supported.")
+            else:
+                # Extract config dict and replace special fields with actual instances
+                config_dict = config.model_dump(exclude={"backend", "embedding_model"})
+                config_dict["embedding_model"] = self.embedding_models[config.embedding_model]
+                config_dict["thread_pool"] = self.thread_pool
+                self.vector_stores[name] = R.vector_stores[config.backend](**config_dict)
+                await self.vector_stores[name].create_collection(config.collection_name)
 
         for name, config in self.service_config.memory_stores.items():
-            # Extract config dict and replace embedding_model string with actual instance
-            config_dict = config.model_dump(exclude={"backend", "embedding_model"})
-            config_dict["embedding_model"] = self.embedding_models[config.embedding_model]
-            self.memory_stores[name] = R.memory_stores[config.backend](**config_dict)
-            await self.memory_stores[name].start()
+            if config.backend not in R.memory_stores:
+                logger.warning(f"Memory store backend {config.backend} is not supported.")
+            else:
+                # Extract config dict and replace embedding_model string with actual instance
+                config_dict = config.model_dump(exclude={"backend", "embedding_model"})
+                config_dict["embedding_model"] = self.embedding_models[config.embedding_model]
+                self.memory_stores[name] = R.memory_stores[config.backend](**config_dict)
+                await self.memory_stores[name].start()
 
         for name, config in self.service_config.file_watchers.items():
-            # Extract config dict and replace memory_store string with actual instance
-            config_dict = config.model_dump(exclude={"backend", "memory_store"})
-            config_dict["memory_store"] = self.memory_stores[config.memory_store]
-            self.file_watchers[name] = R.file_watchers[config.backend](**config_dict)
-            await self.file_watchers[name].start()
+            if config.backend not in R.file_watchers:
+                logger.warning(f"File watcher backend {config.backend} is not supported.")
+            else:
+                config_dict = config.model_dump(exclude={"backend", "memory_store"})
+                config_dict["memory_store"] = self.memory_stores[config.memory_store]
+                self.file_watchers[name] = R.file_watchers[config.backend](**config_dict)
+                await self.file_watchers[name].start()
 
         if self.service_config.mcp_servers:
             await self.prepare_mcp_servers()
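The diff above makes two behavioral changes at startup: each configured backend name is checked against the registry and skipped with a warning when it is unknown (instead of failing on the `R.llms[config.backend]` lookup), and an optional `working_dir` is created with `Path(...).mkdir(parents=True, exist_ok=True)`. Below is a minimal, self-contained sketch of the same warn-and-skip pattern; the registry contents, config values, and directory name are hypothetical stand-ins rather than the project's actual ones, and `loguru` is assumed to be installed as it is in this repo.

```python
from pathlib import Path

from loguru import logger

# Hypothetical registry of supported backends (stand-in for the project's `R.llms`).
SUPPORTED_LLM_BACKENDS = {
    "openai": lambda model_name, **extra: {"backend": "openai", "model": model_name, **extra},
}

# Hypothetical user configuration: one known backend, one unknown backend.
configured_llms = {
    "default": {"backend": "openai", "model_name": "gpt-4o-mini"},
    "broken": {"backend": "no_such_backend", "model_name": "irrelevant"},
}

llms = {}
for name, cfg in configured_llms.items():
    if cfg["backend"] not in SUPPORTED_LLM_BACKENDS:
        # Same behavior as the diff: log a warning and skip, rather than raising.
        logger.warning(f"LLM backend {cfg['backend']} is not supported.")
    else:
        llms[name] = SUPPORTED_LLM_BACKENDS[cfg["backend"]](model_name=cfg["model_name"])

# Same idea as the new working_dir handling: create the directory tree if configured.
working_dir = "./reme_runtime"  # hypothetical path
if working_dir:
    Path(working_dir).mkdir(parents=True, exist_ok=True)
```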