@@ -194,6 +194,7 @@ class AICondition(Condition):
194194 model : BaseChatModel
195195 system_template : str
196196 inquiry_template : str
197+ retries : int = field (default = 3 )
197198 func : None = field (init = False , default = None )
198199 _rationale : str | None = field (init = False )
199200
@@ -221,18 +222,33 @@ def __call__(self, *args: Fact) -> bool:
221222
222223 system_msg = LiteralFormatter ().vformat (system_msg , [], values )
223224
224- # Invoke the LLM and get the result
225- result : BooleanDecision = self .chain .invoke ({"system_msg" : system_msg , "inquiry" : self .inquiry_template })
226- object .__setattr__ (self , "_rationale" , result .justification )
225+ # Retry the LLM invocation until it succeeds or the max retries is reached
226+ result : BooleanDecision
227+ for attempt in range (self .retries ):
228+ try :
229+ result = self .chain .invoke ({"system_msg" : system_msg , "inquiry" : self .inquiry_template })
230+ object .__setattr__ (self , "_rationale" , result .justification )
227231
228- if result .invalid_inquiry or result .result is None :
229- raise AIDecisionError (result .justification )
232+ if not (result .result is None or result .invalid_inquiry ):
233+ break # Successful result, exit retry loop
234+ else :
235+ logger .debug ("Retrying AI condition (attempt %s), reason: %s" , attempt + 1 , result .justification )
236+
237+ except Exception as e :
238+ if attempt == self .retries - 1 :
239+ raise # Raise the last exception if max retries reached
240+ logger .debug ("Retrying AI condition (attempt %s), reason: %s" , attempt + 1 , e )
241+
242+ if result .result is None or result .invalid_inquiry :
243+ reason = "invalid inquiry" if result .invalid_inquiry else result .justification
244+ msg = f"Failed after { self .retries } attempts; reason: { reason } "
245+ raise AIDecisionError (msg )
230246
231247 return not result .result if self .inverted else result .result
232248
233249
234250# TODO: Investigate how best to register tools for specific consitions
235- def ai_condition (model : BaseChatModel , inquiry : str ) -> AICondition :
251+ def ai_condition (model : BaseChatModel , inquiry : str , retries : int = 3 ) -> AICondition :
236252 # TODO: Optimize by precompiling regex and storing translation table globally
237253 # Find and referenced facts and replace braces with angle brackets
238254 facts = tuple (re .findall (r"\{([^}]+)\}" , inquiry ))
@@ -265,7 +281,9 @@ def ai_condition(model: BaseChatModel, inquiry: str) -> AICondition:
265281 prompt_template = ChatPromptTemplate .from_messages ([("system" , "{system_msg}" ), ("user" , user )])
266282 structured_model = model .with_structured_output (BooleanDecision )
267283 chain = prompt_template | structured_model
268- return AICondition (chain = chain , model = model , system_template = system , inquiry_template = inquiry , facts = facts )
284+ return AICondition (
285+ chain = chain , model = model , system_template = system , inquiry_template = inquiry , facts = facts , retries = retries
286+ )
269287
270288
271289@lru_cache (maxsize = 1 )
@@ -285,7 +303,7 @@ def _detect_default_model() -> BaseChatModel:
285303 raise ImportError (msg )
286304
287305
288- def condition (func : ConditionCallable | str , model : BaseChatModel | None = None ) -> Condition :
306+ def condition (func : ConditionCallable | str , retries : int = 3 , model : BaseChatModel | None = None ) -> Condition :
289307 """
290308 Creates a Condition object from a lambda or function. It performs limited static analysis of the code to ensure
291309 proper usage and discover the facts/attributes accessed by the condition. This allows the rule engine to track
@@ -332,7 +350,7 @@ def is_user_adult(user: User) -> bool:
332350 # AI condition assumed
333351 if not model :
334352 model = _detect_default_model ()
335- return ai_condition (model , func )
353+ return ai_condition (model , func , retries )
336354
337355
338356# TODO: Create a convenience function for creating OnFactChanged conditions
0 commit comments