Skip to content

Commit 558a91b

Browse files
authored
Bump to 1.1.0 - Added ability to choose model provider when declaring conditions (#6)
1 parent 9ba090d commit 558a91b

File tree

7 files changed

+319
-243
lines changed

7 files changed

+319
-243
lines changed

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
"editor.suggest.localityBonus": true,
9595
"editor.inlineSuggest.showToolbar": "onHover",
9696

97+
"chat.agent.enabled": true,
9798
"github.copilot.chat.generateTests.codeLens": true,
98-
"github.copilot.nextEditSuggestions.enabled": true,
99+
"github.copilot.nextEditSuggestions.enabled": false
99100
}

poetry.lock

Lines changed: 232 additions & 231 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ requires = ["poetry-core"]
2020
build-backend = "poetry.core.masonry.api"
2121

2222
[tool.poetry]
23-
version = "1.0.0" # Update manually, or use plugin
23+
version = "1.1.0" # Update manually, or use plugin
2424
packages = [{ include = "vulcan_core", from="src" }]
2525
requires-poetry = "~2.1.1"
2626
classifiers = [
@@ -84,6 +84,9 @@ pydantic = "~2.10.6"
8484
langchain = { version = "~0.3.20", optional = true }
8585
langchain-openai = { version = "~0.3.9", optional = true }
8686

87+
[tool.poetry.extras]
88+
openai = ["langchain", "langchain-openai"]
89+
8790
[project.optional-dependencies]
8891
openai = ["langchain", "langchain-openai"]
8992

src/vulcan_core/ast_utils.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,35 @@ def _extract_lambda_source(self) -> str:
152152
raise ASTProcessingError(msg)
153153

154154
# The source includes the entire line of code (e.g., assignment and condition() call)
155-
# We need to parse parentheses to extract just the lambda expression, handling any
156-
# nested parentheses in the lambda's body correctly
155+
# We need to extract just the lambda expression, handling nested structures correctly
157156
source = self.source[lambda_start:]
157+
158+
# Track depth of various brackets to ensure we don't split inside valid nested structures apart from trailing
159+
# arguments within the condition() call
158160
paren_level = 0
161+
bracket_level = 0
162+
brace_level = 0
163+
159164
for i, char in enumerate(source):
160165
if char == "(":
161166
paren_level += 1
162-
elif char == ")" and paren_level > 0:
163-
paren_level -= 1
164-
elif char == ")" and paren_level == 0:
167+
elif char == ")":
168+
if paren_level > 0:
169+
paren_level -= 1
170+
elif paren_level == 0: # End of expression in a function call
171+
return source[:i]
172+
elif char == "[":
173+
bracket_level += 1
174+
elif char == "]":
175+
if bracket_level > 0:
176+
bracket_level -= 1
177+
elif char == "{":
178+
brace_level += 1
179+
elif char == "}":
180+
if brace_level > 0:
181+
brace_level -= 1
182+
# Only consider comma as a separator when not inside any brackets
183+
elif char == "," and paren_level == 0 and bracket_level == 0 and brace_level == 0:
165184
return source[:i]
166185

167186
return source

src/vulcan_core/conditions.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from abc import abstractmethod
99
from dataclasses import dataclass, field
1010
from enum import Enum, auto
11+
from functools import lru_cache
1112
from string import Formatter
1213
from typing import TYPE_CHECKING
1314

1415
from langchain.prompts import ChatPromptTemplate
15-
from langchain_openai import ChatOpenAI
1616
from pydantic import BaseModel, Field
1717

1818
from vulcan_core.actions import ASTProcessor
@@ -22,6 +22,11 @@
2222
from langchain_core.language_models import BaseChatModel
2323
from langchain_core.runnables import RunnableSerializable
2424

25+
import importlib.util
26+
import logging
27+
28+
logger = logging.getLogger(__name__)
29+
2530

2631
@dataclass(frozen=True, slots=True)
2732
class Expression(DeclaresFacts):
@@ -182,6 +187,7 @@ def get_field(self, field_name, args, kwargs):
182187
@dataclass(frozen=True, slots=True)
183188
class AICondition(Condition):
184189
chain: RunnableSerializable
190+
model: BaseChatModel
185191
system_template: str
186192
inquiry_template: str
187193
func: None = field(init=False, default=None)
@@ -238,13 +244,23 @@ def ai_condition(model: BaseChatModel, inquiry: str) -> AICondition:
238244
)
239245
structured_model = model.with_structured_output(BooleanDecision)
240246
chain = prompt_template | structured_model
241-
return AICondition(chain=chain, system_template=system, inquiry_template=inquiry, facts=facts)
247+
return AICondition(chain=chain, model=model, system_template=system, inquiry_template=inquiry, facts=facts)
242248

243249

244-
default_model = ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=100) # type: ignore[call-arg] - pyright can't see the args for some reason
250+
@lru_cache(maxsize=1)
251+
def _detect_default_model() -> BaseChatModel:
252+
# TODO: Expand this to detect other providers
253+
if importlib.util.find_spec("langchain_openai"):
254+
from langchain_openai import ChatOpenAI
255+
256+
logger.debug("Using OpenAI as the default LLM model provider.")
257+
return ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=100) # type: ignore[call-arg] - pyright can't see the args for some reason
258+
else:
259+
msg = "Unable to import a default LLM provider. Please install `vulcan_core` with the appropriate extras package or specify your custom model explicitly."
260+
raise ImportError(msg)
245261

246262

247-
def condition(func: ConditionCallable | str) -> Condition:
263+
def condition(func: ConditionCallable | str, model: BaseChatModel | None = None) -> Condition:
248264
"""
249265
Creates a Condition object from a lambda or function. It performs limited static analysis of the code to ensure
250266
proper usage and discover the facts/attributes accessed by the condition. This allows the rule engine to track
@@ -284,10 +300,14 @@ def is_user_adult(user: User) -> bool:
284300
"""
285301

286302
if not isinstance(func, str):
303+
# Logic condition assumed, ignore kwargs
287304
processed = ASTProcessor[ConditionCallable](func, condition, bool)
288305
return Condition(processed.facts, processed.func)
289306
else:
290-
return ai_condition(default_model, func)
307+
# AI condition assumed
308+
if not model:
309+
model = _detect_default_model()
310+
return ai_condition(model, func)
291311

292312

293313
# TODO: Create a convenience function for creating OnFactChanged conditions

src/vulcan_core/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,6 @@ def __iter__(self) -> str:
258258

259259
def __len__(self) -> int:
260260
raise NotImplementedError
261+
262+
def __str__(self) -> str:
263+
return f"RetrieverAdapter(search_type={self.store.search_type})"

tests/core/test_conditions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# Copyright 2025 Latchfield Technologies http://latchfield.com
33

4+
from functools import partial
5+
from unittest.mock import Mock
6+
47
import pytest
58
from langchain_core.language_models import BaseChatModel
69
from langchain_openai import ChatOpenAI
@@ -108,8 +111,34 @@ def test_ai_simple_condition_true(model: BaseChatModel, fact_a_instance: FactA,
108111
assert cond(fact_a_instance, fact_b_instance) is True
109112

110113

114+
@pytest.mark.integration
111115
def test_ai_missing_fact(model: BaseChatModel):
112116
# TODO: Determine the difference between tool calls and non-tool calls
113117
# We shouldn't raise an exception if tools are being used
114118
with pytest.raises(MissingFactError):
115119
ai_condition(model, "Is the sky blue?")
120+
121+
122+
@pytest.mark.integration
123+
def test_aicondition_with_custom_model(model: BaseChatModel, fact_a_instance: FactA, fact_b_instance: FactB):
124+
cond = condition(f"Are {FactA.feature} and {FactB.feature} both on the same planet?", model=model)
125+
126+
assert set(cond.facts) == {"FactA.feature", "FactB.feature"}
127+
assert cond(fact_a_instance, fact_b_instance) is False
128+
129+
130+
def test_condition_with_custom_model(foo_instance: Foo, bar_instance: Bar):
131+
model = Mock()
132+
condition(lambda: Foo.baz and Bar.biz, model=model)
133+
134+
135+
def test_aicondition_model_override():
136+
model1 = Mock()
137+
model2 = Mock()
138+
139+
custom_condition = partial(condition, model=model1)
140+
141+
cond1 = custom_condition(f"Are {FactA.feature} and {FactB.feature} both on the same planet?")
142+
cond2 = custom_condition(f"Are {FactA.feature} and {FactB.feature} both on the same planet?", model=model2)
143+
144+
assert cond1.model != cond2.model # type: ignore

0 commit comments

Comments (0)