33
44import ast
55import inspect
6+ import logging
67import re
78import textwrap
8- import threading
99from ast import Attribute , Module , Name , NodeTransformer , NodeVisitor
10+ from collections import OrderedDict
1011from collections .abc import Callable
1112from dataclasses import dataclass , field
1213from functools import cached_property
1314from types import MappingProxyType
14- from typing import Any , TypeAliasType , get_type_hints
15+ from typing import Any , ClassVar , TypeAliasType , get_type_hints
1516
1617from vulcan_core .models import Fact , HasSource
1718
19+ logger = logging .getLogger (__name__ )
20+
1821
1922class ASTProcessingError (RuntimeError ):
2023 """Internal error encountered while processing AST."""
@@ -84,23 +87,74 @@ def visit_Attribute(self, node: Attribute): # noqa: N802
8487 return node
8588
8689
87- # Global index to cache and track lambda function positions within the same source lines.
88- # Tuple format: (source code, last processed index)
89- # TODO: Consider if a redesign is possible to have a single ASTProcessor handle the entire source line, perhaps eagerly
90- # processing all lambdas found in the line before the corresponding `condition` call.
91- _lambda_index_lock = threading .Lock ()
92- lambda_index : dict [Any , tuple [str , int | None ]] = {}
90+ @dataclass (slots = True )
91+ class LambdaSource :
92+ """Index entry for tracking the parsing position of lambda functions in source lines.
93+
94+ Attributes:
95+ source (str): The source code string containing lambda functions
96+ count (int): The number of lambda functions found in the source string.
97+ pos (int): The current parsing position within the source string.
98+ """
99+
100+ source : str
101+ count : int
102+ pos : int = field (default = 0 )
103+ in_use : bool = field (default = True )
93104
94105
95106@dataclass
96107class ASTProcessor [T : Callable ]:
108+ """
109+ This class extracts source code from functions or lambda expressions, parses them into
110+ Abstract Syntax Trees (AST), and performs various validations and transformations.
111+
112+ The processor validates that:
113+ - Functions have proper type hints for parameters and return types
114+ - All parameters are subclasses of Fact
115+ - No nested attribute access (e.g., X.y.z) is used
116+ - No async functions are processed
117+ - Lambda expressions do not contain parameters
118+ - No duplicate parameter types in function signatures
119+
120+ For lambda expressions, it automatically transforms attribute access patterns
121+ (e.g., ClassName.attribute) into parameterized functions for easier execution.
122+
123+ Note: This class is not thread-safe and should not be used concurrently across multiple threads.
124+
125+ Type Parameters:
126+ T: The type signature the processor is working with, this varies based on a condition or action being processed.
127+
128+ Attributes:
129+ func: The callable to process, a lambda or a function
130+ decorator: The decorator type that initiated the processing (e.g., `condition` or `action`)
131+ return_type: Expected return type for the callable
132+ source: Extracted source code of func (set during post-init)
133+ tree: Parsed AST of the source code (set during post-init)
134+ facts: Tuple of fact strings discovered in the callable (set during post-init)
135+
136+ Properties:
137+ is_lambda: True if the callable is a lambda expression
138+
139+ Raises:
140+ OSError: When source code cannot be extracted
141+ ScopeAccessError: When accessing undefined classes or using nested attributes
142+ CallableSignatureError: When function signature doesn't meet requirements
143+ NotAFactError: When parameter types are not Fact subclasses
144+ ASTProcessingError: When AST processing encounters internal errors
145+ """
146+
97147 func : T
98148 decorator : Callable
99149 return_type : type | TypeAliasType
100150 source : str = field (init = False )
101151 tree : Module = field (init = False )
102152 facts : tuple [str , ...] = field (init = False )
103153
154+ # Class-level tracking of lambdas across parsing calls to handle multiple lambdas on the same line
155+ _lambda_cache : ClassVar [OrderedDict [str , LambdaSource ]] = OrderedDict ()
156+ _MAX_LAMBDA_CACHE_SIZE : ClassVar [int ] = 1024
157+
104158 @cached_property
105159 def is_lambda (self ) -> bool :
106160 return isinstance (self .func , type (lambda : None )) and self .func .__name__ == "<lambda>"
@@ -113,30 +167,34 @@ def __post_init__(self):
113167 try :
114168 if self .is_lambda :
115169 # As of Python 3.12, there is no way to determine to which lambda self.func refers in an
116- # expression containing multiple lambdas. Therefore we use a global dict to track the index of each
170+ # expression containing multiple lambdas. Therefore we use a dict to track the index of each
117171 # lambda function encountered, as the order will correspond to the order of ASTProcessor
118172 # invocations for that line. An additional benefit is that we can also use this as a cache to
119173 # avoid re-reading the source code for lambda functions sharing the same line.
120- #
121- # The key for the index is a hash of the stack trace plus line number, which will be
122- # unique for each call of a list of lambdas on the same line.
123- frames = inspect .stack ()[1 :] # Exclude current frame
124- key = "" .join (f"{ f .filename } :{ f .lineno } " for f in frames )
125-
126- # Use a lock to ensure thread safety when accessing the global lambda index
127- with _lambda_index_lock :
128- index = lambda_index .get (key )
129- if index is None or index [1 ] is None :
130- self .source = self ._get_lambda_source ()
131- index = (self .source , 0 )
132- lambda_index [key ] = index
133- else :
134- self .source = index [0 ]
135- index = (self .source , index [1 ] + 1 )
136- lambda_index [key ] = index
174+ source_line = f"{ self .func .__code__ .co_filename } :{ self .func .__code__ .co_firstlineno } "
175+ lambda_src = self ._lambda_cache .get (source_line )
176+
177+ if lambda_src is None :
178+ self .source = self ._get_lambda_source ()
179+ lambda_count = self ._count_lambdas (self .source )
180+ lambda_src = LambdaSource (self .source , lambda_count )
181+ self ._lambda_cache [source_line ] = lambda_src
182+ self ._trim_lambda_cache ()
183+ else :
184+ self .source = lambda_src .source
185+ lambda_src .pos += 1
186+
187+ # Reset the position if it exceeds the count of lambda expressions
188+ if lambda_src .pos >= lambda_src .count :
189+ lambda_src .pos = 0
137190
138191 # Normalize the lambda source and extract the next lambda expression from the last index
139- self .source = self ._normalize_lambda_source (self .source , index [1 ])
192+ self .source = self ._normalize_lambda_source (self .source , lambda_src .pos )
193+
194+ # If done processing lambdas in the source, mark as not processing anymore
195+ if lambda_src .pos >= lambda_src .count - 1 :
196+ lambda_src .in_use = False
197+
140198 else :
141199 self .source = textwrap .dedent (inspect .getsource (self .func ))
142200 except OSError as e :
@@ -180,19 +238,48 @@ def __post_init__(self):
180238
181239 self .facts = tuple (facts )
182240
241+ def _trim_lambda_cache (self ) -> None :
242+ """Clean up lambda cache by removing oldest unused entries when cache size exceeds limit."""
243+ if len (self ._lambda_cache ) <= self ._MAX_LAMBDA_CACHE_SIZE :
244+ return
245+
246+ # Calculate how many entries to remove (excess + 20% buffer to avoid thrashing)
247+ excess_count = len (self ._lambda_cache ) - self ._MAX_LAMBDA_CACHE_SIZE
248+ buffer_count = int (self ._MAX_LAMBDA_CACHE_SIZE * 0.2 )
249+ target_count = excess_count + buffer_count
250+
251+ # Find and remove unused entries
252+ removed_count = 0
253+ for key in list (self ._lambda_cache ):
254+ if removed_count >= target_count :
255+ break
256+ if not self ._lambda_cache [key ].in_use :
257+ del self ._lambda_cache [key ]
258+ removed_count += 1
259+
260+ def _count_lambdas (self , source : str ) -> int :
261+ """Count lambda expressions in source code using AST parsing."""
262+ tree = ast .parse (source )
263+
264+ class LambdaCounter (ast .NodeVisitor ):
265+ def __init__ (self ):
266+ self .count = 0
267+
268+ def visit_Lambda (self , node ): # noqa: N802 - Case sensitive for AST
269+ self .count += 1
270+ self .generic_visit (node )
271+
272+ counter = LambdaCounter ()
273+ counter .visit (tree )
274+ return counter .count
275+
183276 def _get_lambda_source (self ) -> str :
184277 """Get single and multiline lambda source using AST parsing of the source file."""
185- try :
186- # Get caller frame to find the source file
187- frame = inspect .currentframe ()
188- while frame and frame .f_code .co_name != self .decorator .__name__ :
189- frame = frame .f_back
190-
191- if not frame or not frame .f_back :
192- return textwrap .dedent (inspect .getsource (self .func ))
278+ source = None
193279
194- caller_frame = frame .f_back
195- filename = caller_frame .f_code .co_filename
280+ try :
281+ # Get the source file and line number
282+ filename = self .func .__code__ .co_filename
196283 lambda_lineno = self .func .__code__ .co_firstlineno
197284
198285 # Read the source file
@@ -235,20 +322,25 @@ def visit_Lambda(self, node): # noqa: N802 - Case sensitive for AST
235322 end_line = i
236323 break
237324
238- return "\n " .join (lines [start_line : end_line + 1 ])
325+ source = "\n " .join (lines [start_line : end_line + 1 ])
239326
240327 except (OSError , SyntaxError , AttributeError ):
241- pass
242-
243- # Fallback to regular inspect.getsource
244- return textwrap .dedent (inspect .getsource (self .func ))
328+ logger .exception ("Failed to extract lambda source, attempting fallback." )
329+ source = inspect .getsource (self .func ).strip ()
245330
246- def _normalize_lambda_source (self , source : str , index : int ) -> str :
247- """Extracts just the lambda expression from source code."""
331+ if source is None or source == "" :
332+ msg = "Could not extract lambda source code"
333+ raise ASTProcessingError (msg )
248334
249- # Remove line endings and extra whitespace
335+ # Normalize the source: convert line breaks to spaces, collapse whitespace, and dedent
250336 source = re .sub (r"\r\n|\r|\n" , " " , source )
251337 source = re .sub (r"\s+" , " " , source )
338+ source = textwrap .dedent (source )
339+
340+ return source
341+
342+ def _normalize_lambda_source (self , source : str , index : int ) -> str :
343+ """Extracts just the lambda expression from source code."""
252344
253345 # Find the Nth lambda occurrence using generator expression
254346 positions = [i for i in range (len (source ) - 5 ) if source [i : i + 6 ] == "lambda" ]
0 commit comments