|
8 | 8 | from abc import abstractmethod |
9 | 9 | from dataclasses import dataclass, field |
10 | 10 | from enum import Enum, auto |
| 11 | +from functools import lru_cache |
11 | 12 | from string import Formatter |
12 | 13 | from typing import TYPE_CHECKING |
13 | 14 |
|
14 | 15 | from langchain.prompts import ChatPromptTemplate |
15 | | -from langchain_openai import ChatOpenAI |
16 | 16 | from pydantic import BaseModel, Field |
17 | 17 |
|
18 | 18 | from vulcan_core.actions import ASTProcessor |
|
22 | 22 | from langchain_core.language_models import BaseChatModel |
23 | 23 | from langchain_core.runnables import RunnableSerializable |
24 | 24 |
|
| 25 | +import importlib.util |
| 26 | +import logging |
| 27 | + |
| 28 | +logger = logging.getLogger(__name__) |
| 29 | + |
25 | 30 |
|
26 | 31 | @dataclass(frozen=True, slots=True) |
27 | 32 | class Expression(DeclaresFacts): |
@@ -182,6 +187,7 @@ def get_field(self, field_name, args, kwargs): |
182 | 187 | @dataclass(frozen=True, slots=True) |
183 | 188 | class AICondition(Condition): |
184 | 189 | chain: RunnableSerializable |
| 190 | + model: BaseChatModel |
185 | 191 | system_template: str |
186 | 192 | inquiry_template: str |
187 | 193 | func: None = field(init=False, default=None) |
@@ -238,13 +244,23 @@ def ai_condition(model: BaseChatModel, inquiry: str) -> AICondition: |
238 | 244 | ) |
239 | 245 | structured_model = model.with_structured_output(BooleanDecision) |
240 | 246 | chain = prompt_template | structured_model |
241 | | - return AICondition(chain=chain, system_template=system, inquiry_template=inquiry, facts=facts) |
| 247 | + return AICondition(chain=chain, model=model, system_template=system, inquiry_template=inquiry, facts=facts) |
242 | 248 |
|
243 | 249 |
|
244 | | -default_model = ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=100) # type: ignore[call-arg] - pyright can't see the args for some reason |
| 250 | +@lru_cache(maxsize=1) |
| 251 | +def _detect_default_model() -> BaseChatModel: |
| 252 | + # TODO: Expand this to detect other providers |
| 253 | + if importlib.util.find_spec("langchain_openai"): |
| 254 | + from langchain_openai import ChatOpenAI |
| 255 | + |
| 256 | + logger.debug("Using OpenAI as the default LLM model provider.") |
| 257 | + return ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=100) # type: ignore[call-arg] - pyright can't see the args for some reason |
| 258 | + else: |
| 259 | + msg = "Unable to import a default LLM provider. Please install `vulcan_core` with the approriate extras package or specify your custom model explicitly." |
| 260 | + raise ImportError(msg) |
245 | 261 |
|
246 | 262 |
|
def condition(func: ConditionCallable | str, model: BaseChatModel | None = None) -> Condition:
    """
    Creates a Condition object from a lambda or function. It performs limited static analysis of the code to ensure
    proper usage and discover the facts/attributes accessed by the condition. This allows the rule engine to track
    which facts a condition depends on.

    Args:
        func: Either a callable returning a bool (a logic condition), or a natural-language
            inquiry string to be evaluated by an LLM (an AI condition).
        model: Optional chat model used for AI conditions. When omitted, a default
            provider is auto-detected. Ignored for callable conditions.

    Returns:
        Condition: A logic `Condition` for callables, or an `AICondition` for strings.

    Raises:
        ImportError: If `func` is a string, no `model` is given, and no default
            LLM provider package is installed.
    """

    if not isinstance(func, str):
        # Logic condition assumed; `model` is irrelevant for plain callables and is ignored.
        processed = ASTProcessor[ConditionCallable](func, condition, bool)
        return Condition(processed.facts, processed.func)
    else:
        # AI condition assumed. Use `is None` (not truthiness) so an unusual
        # falsy-but-valid model object is never silently replaced by the default.
        if model is None:
            model = _detect_default_model()
        return ai_condition(model, func)
291 | 311 |
|
292 | 312 |
|
293 | 313 | # TODO: Create a convenience function for creating OnFactChanged conditions |
|
0 commit comments