11import abc
2+ import functools
23import inspect
34import typing as t
4- from dataclasses import dataclass , field
55from functools import lru_cache
66
77from loguru import logger
@@ -42,7 +42,7 @@ class Fixup(abc.ABC):
4242 """
4343
4444 @abc .abstractmethod
45- def can_fix (self , exception : Exception ) -> bool :
45+ def can_fix (self , exception : Exception ) -> bool | t . Literal [ "once" ] :
4646 """
4747 Check if the fixup can resolve the given exception if made active.
4848
@@ -68,10 +68,62 @@ def fix(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
6868 ...
6969
7070
71- @dataclass
72- class Fixups :
73- available : list [Fixup ] = field (default_factory = list )
74- active : list [Fixup ] = field (default_factory = list )
71+ FixupCompatibleFunc = t .Callable [
72+ t .Concatenate [t .Any , t .Sequence [Message ], P ],
73+ t .Awaitable [R ],
74+ ]
75+
76+
77+ def with_fixups (
78+ * fixups : Fixup ,
79+ ) -> t .Callable [[FixupCompatibleFunc [P , R ]], FixupCompatibleFunc [P , R ]]:
80+ """
81+ Decorator that adds fixup retry logic with persistent state.
82+
83+ Args:
84+ fixups: Sequence of fixups to try
85+ """
86+ available_fixups : list [Fixup ] = list (fixups )
87+ active_fixups : list [Fixup ] = []
88+ once_fixups : list [Fixup ] = []
89+
90+ def decorator (func : FixupCompatibleFunc [P , R ]) -> FixupCompatibleFunc [P , R ]:
91+ @functools .wraps (func )
92+ async def wrapper (
93+ self : t .Any ,
94+ messages : t .Sequence [Message ],
95+ * args : P .args ,
96+ ** kwargs : P .kwargs ,
97+ ) -> R :
98+ nonlocal available_fixups , active_fixups
99+
100+ for fixup in [* active_fixups , * once_fixups ]:
101+ messages = fixup .fix (messages )
102+
103+ try :
104+ result = await func (self , messages , * args , ** kwargs )
105+ available_fixups = [* available_fixups , * once_fixups ]
106+ once_fixups .clear ()
107+ except Exception as e :
108+ for fixup in list (available_fixups ):
109+ if (can_fix := fixup .can_fix (e )) is False :
110+ continue
111+
112+ if can_fix == "once" :
113+ once_fixups .append (fixup )
114+ else :
115+ active_fixups .append (fixup )
116+ available_fixups .remove (fixup )
117+
118+ return await wrapper (self , messages , * args , ** kwargs )
119+
120+ raise
121+
122+ return result
123+
124+ return wrapper # type: ignore[return-value]
125+
126+ return decorator
75127
76128
77129# TODO: We also would like to support N-style
@@ -305,8 +357,6 @@ class Generator(BaseModel):
305357 _watch_callbacks : list ["WatchChatCallback | WatchCompletionCallback" ] = []
306358 _wrap : t .Callable [[CallableT ], CallableT ] | None = None
307359
308- _fixups : Fixups = Fixups ()
309-
310360 def to_identifier (self , params : GenerateParams | None = None ) -> str :
311361 """
312362 Converts the generator instance back into a rigging identifier string.
@@ -393,38 +443,6 @@ async def supports_function_calling(self) -> bool | None:
393443 """
394444 return None
395445
396- def _check_fixups (self , error : Exception ) -> bool :
397- """
398- Check if any fixer can handle this error.
399-
400- Args:
401- error: The error to be checked.
402-
403- Returns:
404- Whether a fixer was able to handle the error.
405- """
406- for fixup in self ._fixups .available [:]:
407- if fixup .can_fix (error ):
408- self ._fixups .active .append (fixup )
409- self ._fixups .available .remove (fixup )
410- return True
411- return False
412-
413- async def _apply_fixups (self , messages : t .Sequence [Message ]) -> t .Sequence [Message ]:
414- """
415- Apply all active fixups to the messages.
416-
417- Args:
418- messages: The messages to be fixed.
419-
420- Returns:
421- The fixed messages.
422- """
423- current_messages = messages
424- for fixup in self ._fixups .active :
425- current_messages = fixup .fix (current_messages )
426- return current_messages
427-
428446 async def generate_messages (
429447 self ,
430448 messages : t .Sequence [t .Sequence [Message ]],
0 commit comments