Skip to content

Commit 963d6a0

Browse files
segfly and Copilot authored
Fixed lambda parser tracking for looping cases (#47)
- Fixed lambda parser tracking for looping cases
- Added lambda cache limit and pruning
- Improved bad test case

Signed-off-by: Nicholas Pace <segfly@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 4faf63d commit 963d6a0

File tree

5 files changed

+148
-49
lines changed

5 files changed

+148
-49
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ requires = ["poetry-core"]
2020
build-backend = "poetry.core.masonry.api"
2121

2222
[tool.poetry]
23-
version = "1.1.3" # Update manually, or use plugin
23+
version = "1.1.4" # Update manually, or use plugin
2424
packages = [{ include = "vulcan_core", from="src" }]
2525
requires-poetry = "~2.1.1"
2626
classifiers = [

src/vulcan_core/ast_utils.py

Lines changed: 137 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -3,18 +3,21 @@
33

44
import ast
55
import inspect
6+
import logging
67
import re
78
import textwrap
8-
import threading
99
from ast import Attribute, Module, Name, NodeTransformer, NodeVisitor
10+
from collections import OrderedDict
1011
from collections.abc import Callable
1112
from dataclasses import dataclass, field
1213
from functools import cached_property
1314
from types import MappingProxyType
14-
from typing import Any, TypeAliasType, get_type_hints
15+
from typing import Any, ClassVar, TypeAliasType, get_type_hints
1516

1617
from vulcan_core.models import Fact, HasSource
1718

19+
logger = logging.getLogger(__name__)
20+
1821

1922
class ASTProcessingError(RuntimeError):
2023
"""Internal error encountered while processing AST."""
@@ -84,23 +87,74 @@ def visit_Attribute(self, node: Attribute): # noqa: N802
8487
return node
8588

8689

87-
# Global index to cache and track lambda function positions within the same source lines.
88-
# Tuple format: (source code, last processed index)
89-
# TODO: Consider if a redesign is possible to have a single ASTProcessor handle the entire source line, perhaps eagerly
90-
# processing all lambdas found in the line before the corresponding `condition` call.
91-
_lambda_index_lock = threading.Lock()
92-
lambda_index: dict[Any, tuple[str, int | None]] = {}
90+
@dataclass(slots=True)
91+
class LambdaSource:
92+
"""Index entry for tracking the parsing position of lambda functions in source lines.
93+
94+
Attributes:
95+
source (str): The source code string containing lambda functions
96+
count (int): The number of lambda functions found in the source string.
97+
pos (int): The current parsing position within the source string.
98+
"""
99+
100+
source: str
101+
count: int
102+
pos: int = field(default=0)
103+
in_use: bool = field(default=True)
93104

94105

95106
@dataclass
96107
class ASTProcessor[T: Callable]:
108+
"""
109+
This class extracts source code from functions or lambda expressions, parses them into
110+
Abstract Syntax Trees (AST), and performs various validations and transformations.
111+
112+
The processor validates that:
113+
- Functions have proper type hints for parameters and return types
114+
- All parameters are subclasses of Fact
115+
- No nested attribute access (e.g., X.y.z) is used
116+
- No async functions are processed
117+
- Lambda expressions do not contain parameters
118+
- No duplicate parameter types in function signatures
119+
120+
For lambda expressions, it automatically transforms attribute access patterns
121+
(e.g., ClassName.attribute) into parameterized functions for easier execution.
122+
123+
Note: This class is not thread-safe and should not be used concurrently across multiple threads.
124+
125+
Type Parameters:
126+
T: The type signature the processor is working with, this varies based on a condition or action being processed.
127+
128+
Attributes:
129+
func: The callable to process, a lambda or a function
130+
decorator: The decorator type that initiated the processing (e.g., `condition` or `action`)
131+
return_type: Expected return type for the callable
132+
source: Extracted source code of func (set during post-init)
133+
tree: Parsed AST of the source code (set during post-init)
134+
facts: Tuple of fact strings discovered in the callable (set during post-init)
135+
136+
Properties:
137+
is_lambda: True if the callable is a lambda expression
138+
139+
Raises:
140+
OSError: When source code cannot be extracted
141+
ScopeAccessError: When accessing undefined classes or using nested attributes
142+
CallableSignatureError: When function signature doesn't meet requirements
143+
NotAFactError: When parameter types are not Fact subclasses
144+
ASTProcessingError: When AST processing encounters internal errors
145+
"""
146+
97147
func: T
98148
decorator: Callable
99149
return_type: type | TypeAliasType
100150
source: str = field(init=False)
101151
tree: Module = field(init=False)
102152
facts: tuple[str, ...] = field(init=False)
103153

154+
# Class-level tracking of lambdas across parsing calls to handle multiple lambdas on the same line
155+
_lambda_cache: ClassVar[OrderedDict[str, LambdaSource]] = OrderedDict()
156+
_MAX_LAMBDA_CACHE_SIZE: ClassVar[int] = 1024
157+
104158
@cached_property
105159
def is_lambda(self) -> bool:
106160
return isinstance(self.func, type(lambda: None)) and self.func.__name__ == "<lambda>"
@@ -113,30 +167,34 @@ def __post_init__(self):
113167
try:
114168
if self.is_lambda:
115169
# As of Python 3.12, there is no way to determine to which lambda self.func refers in an
116-
# expression containing multiple lambdas. Therefore we use a global dict to track the index of each
170+
# expression containing multiple lambdas. Therefore we use a dict to track the index of each
117171
# lambda function encountered, as the order will correspond to the order of ASTProcessor
118172
# invocations for that line. An additional benefit is that we can also use this as a cache to
119173
# avoid re-reading the source code for lambda functions sharing the same line.
120-
#
121-
# The key for the index is a hash of the stack trace plus line number, which will be
122-
# unique for each call of a list of lambdas on the same line.
123-
frames = inspect.stack()[1:] # Exclude current frame
124-
key = "".join(f"{f.filename}:{f.lineno}" for f in frames)
125-
126-
# Use a lock to ensure thread safety when accessing the global lambda index
127-
with _lambda_index_lock:
128-
index = lambda_index.get(key)
129-
if index is None or index[1] is None:
130-
self.source = self._get_lambda_source()
131-
index = (self.source, 0)
132-
lambda_index[key] = index
133-
else:
134-
self.source = index[0]
135-
index = (self.source, index[1] + 1)
136-
lambda_index[key] = index
174+
source_line = f"{self.func.__code__.co_filename}:{self.func.__code__.co_firstlineno}"
175+
lambda_src = self._lambda_cache.get(source_line)
176+
177+
if lambda_src is None:
178+
self.source = self._get_lambda_source()
179+
lambda_count = self._count_lambdas(self.source)
180+
lambda_src = LambdaSource(self.source, lambda_count)
181+
self._lambda_cache[source_line] = lambda_src
182+
self._trim_lambda_cache()
183+
else:
184+
self.source = lambda_src.source
185+
lambda_src.pos += 1
186+
187+
# Reset the position if it exceeds the count of lambda expressions
188+
if lambda_src.pos >= lambda_src.count:
189+
lambda_src.pos = 0
137190

138191
# Normalize the lambda source and extract the next lambda expression from the last index
139-
self.source = self._normalize_lambda_source(self.source, index[1])
192+
self.source = self._normalize_lambda_source(self.source, lambda_src.pos)
193+
194+
# Once the last lambda in the source has been processed, mark the cache entry as no longer in use
195+
if lambda_src.pos >= lambda_src.count - 1:
196+
lambda_src.in_use = False
197+
140198
else:
141199
self.source = textwrap.dedent(inspect.getsource(self.func))
142200
except OSError as e:
@@ -180,19 +238,48 @@ def __post_init__(self):
180238

181239
self.facts = tuple(facts)
182240

241+
def _trim_lambda_cache(self) -> None:
242+
"""Clean up lambda cache by removing oldest unused entries when cache size exceeds limit."""
243+
if len(self._lambda_cache) <= self._MAX_LAMBDA_CACHE_SIZE:
244+
return
245+
246+
# Calculate how many entries to remove (excess + 20% buffer to avoid thrashing)
247+
excess_count = len(self._lambda_cache) - self._MAX_LAMBDA_CACHE_SIZE
248+
buffer_count = int(self._MAX_LAMBDA_CACHE_SIZE * 0.2)
249+
target_count = excess_count + buffer_count
250+
251+
# Find and remove unused entries
252+
removed_count = 0
253+
for key in list(self._lambda_cache):
254+
if removed_count >= target_count:
255+
break
256+
if not self._lambda_cache[key].in_use:
257+
del self._lambda_cache[key]
258+
removed_count += 1
259+
260+
def _count_lambdas(self, source: str) -> int:
261+
"""Count lambda expressions in source code using AST parsing."""
262+
tree = ast.parse(source)
263+
264+
class LambdaCounter(ast.NodeVisitor):
265+
def __init__(self):
266+
self.count = 0
267+
268+
def visit_Lambda(self, node): # noqa: N802 - Case sensitive for AST
269+
self.count += 1
270+
self.generic_visit(node)
271+
272+
counter = LambdaCounter()
273+
counter.visit(tree)
274+
return counter.count
275+
183276
def _get_lambda_source(self) -> str:
184277
"""Get single and multiline lambda source using AST parsing of the source file."""
185-
try:
186-
# Get caller frame to find the source file
187-
frame = inspect.currentframe()
188-
while frame and frame.f_code.co_name != self.decorator.__name__:
189-
frame = frame.f_back
190-
191-
if not frame or not frame.f_back:
192-
return textwrap.dedent(inspect.getsource(self.func))
278+
source = None
193279

194-
caller_frame = frame.f_back
195-
filename = caller_frame.f_code.co_filename
280+
try:
281+
# Get the source file and line number
282+
filename = self.func.__code__.co_filename
196283
lambda_lineno = self.func.__code__.co_firstlineno
197284

198285
# Read the source file
@@ -235,20 +322,25 @@ def visit_Lambda(self, node): # noqa: N802 - Case sensitive for AST
235322
end_line = i
236323
break
237324

238-
return "\n".join(lines[start_line : end_line + 1])
325+
source = "\n".join(lines[start_line : end_line + 1])
239326

240327
except (OSError, SyntaxError, AttributeError):
241-
pass
242-
243-
# Fallback to regular inspect.getsource
244-
return textwrap.dedent(inspect.getsource(self.func))
328+
logger.exception("Failed to extract lambda source, attempting fallback.")
329+
source = inspect.getsource(self.func).strip()
245330

246-
def _normalize_lambda_source(self, source: str, index: int) -> str:
247-
"""Extracts just the lambda expression from source code."""
331+
if source is None or source == "":
332+
msg = "Could not extract lambda source code"
333+
raise ASTProcessingError(msg)
248334

249-
# Remove line endings and extra whitespace
335+
# Normalize the source: convert line breaks to spaces, collapse whitespace, and dedent
250336
source = re.sub(r"\r\n|\r|\n", " ", source)
251337
source = re.sub(r"\s+", " ", source)
338+
source = textwrap.dedent(source)
339+
340+
return source
341+
342+
def _normalize_lambda_source(self, source: str, index: int) -> str:
343+
"""Extracts just the lambda expression from source code."""
252344

253345
# Find the Nth lambda occurrence using generator expression
254346
positions = [i for i in range(len(source) - 5) if source[i : i + 6] == "lambda"]

tests/core/fixtures/rule_loading.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,9 @@ class Foo(Fact):
99

1010

1111
def load_simple_rule(engine: RuleEngine):
12+
# This rule tests for repeated parsing of the same lambda expression, plus potential errors with naive parsing.
1213
engine.rule(
1314
name="test_rule",
14-
when=condition(lambda: Foo.baz),
15+
when=condition(lambda: Foo.baz and "lambda:" != None), # Keep this comment to test parser counting: lambda:
1516
then=action(partial(Foo, bol=False)),
1617
)

tests/core/test_conditions.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,7 @@ def test_invert_condition(foo_instance: Foo):
125125
assert inverted.facts == cond.facts
126126
assert inverted(foo_instance) == (not cond(foo_instance))
127127

128+
128129
# https://github.com/latchfield/vulcan-core/issues/30
129130
def test_short_circuit_condition(foo_instance: Foo):
130131
true_condition = condition(lambda: True)
@@ -143,6 +144,7 @@ def test_short_circuit_condition(foo_instance: Foo):
143144
with pytest.raises(AssertionError):
144145
cond3()
145146

147+
146148
# https://github.com/latchfield/vulcan-core/issues/28
147149
def test_mixed_conditions(foo_instance: Foo, bar_instance: Bar):
148150
mycond = condition(lambda: Foo.baz)
@@ -151,6 +153,7 @@ def test_mixed_conditions(foo_instance: Foo, bar_instance: Bar):
151153
result = compound_cond(foo_instance, bar_instance)
152154
assert result is False
153155

156+
154157
# https://github.com/latchfield/vulcan-core/issues/28
155158
def test_multiple_lambdas(foo_instance: Foo, bar_instance: Bar):
156159
compound_cond1 = condition(lambda: Foo.baz) & condition(lambda: Bar.biz)
@@ -161,6 +164,7 @@ def test_multiple_lambdas(foo_instance: Foo, bar_instance: Bar):
161164
assert result1 is False
162165
assert result2 is True
163166

167+
164168
# https://github.com/latchfield/vulcan-core/issues/28
165169
def test_mixed_conditions_decorator(foo_instance: Foo, bar_instance: Bar):
166170
@condition
@@ -180,6 +184,7 @@ def test_non_boolean_question(custom_model: BaseChatModel, fact_a_instance: Fact
180184
with pytest.raises(AIDecisionError):
181185
cond(fact_a_instance)
182186

187+
183188
# https://github.com/latchfield/vulcan-core/issues/32
184189
@pytest.mark.integration
185190
def test_literal_placeholder_interpretation(fact_a_instance: FactA):

tests/core/test_engine.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -77,9 +77,10 @@ def test_simple_rule(engine: RuleEngine):
7777

7878

7979
# https://github.com/latchfield/vulcan-core/issues/44
80+
# Updated for https://github.com/latchfield/vulcan-core/issues/46
8081
def test_lambda_reparsing(engine: RuleEngine):
81-
load_simple_rule(engine)
82-
load_simple_rule(engine)
82+
for _ in range(2):
83+
load_simple_rule(engine)
8384

8485

8586
def test_same_fact_multiple_attributes_lambda(engine: RuleEngine):

0 commit comments

Comments
 (0)