|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import abc |
3 | 4 | import inspect |
4 | 5 | import typing as t |
| 6 | +from dataclasses import dataclass, field |
5 | 7 |
|
6 | 8 | from loguru import logger |
7 | 9 | from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator |
|
21 | 23 |
|
22 | 24 | CallableT = t.TypeVar("CallableT", bound=t.Callable[..., t.Any]) |
23 | 25 |
|
| 26 | +T = t.TypeVar("T") |
| 27 | + |
24 | 28 | # Global provider map |
25 | 29 |
|
26 | 30 |
|
27 | 31 | @t.runtime_checkable |
28 | 32 | class LazyGenerator(t.Protocol): |
29 | | - def __call__(self) -> type[Generator]: ... |
| 33 | + def __call__(self) -> type[Generator]: |
| 34 | + ... |
30 | 35 |
|
31 | 36 |
|
32 | 37 | g_providers: dict[str, type[Generator] | LazyGenerator] = {} |
33 | 38 |
|
| 39 | +# Fixups |
| 40 | + |
| 41 | + |
| 42 | +class Fixup(abc.ABC): |
| 43 | + """ |
| 44 | + Base class for fixups that apply on message sequences to correct errors. |
| 45 | + """ |
| 46 | + |
| 47 | + @abc.abstractmethod |
| 48 | + def can_fix(self, exception: Exception) -> bool: |
| 49 | + """ |
| 50 | + Check if the fixup can resolve the given exception if made active. |
| 51 | +
|
| 52 | + Args: |
| 53 | + exception: The exception to be checked. |
| 54 | +
|
| 55 | + Returns: |
| 56 | + Whether the fixup can handle the exception. |
| 57 | + """ |
| 58 | + ... |
| 59 | + |
| 60 | + @abc.abstractmethod |
| 61 | + def fix(self, messages: t.Sequence[Message]) -> t.Sequence[Message]: |
| 62 | + """ |
| 63 | + Process a sequence of messages to fix them. |
| 64 | +
|
| 65 | + Args: |
| 66 | + messages: The messages to be fixed. |
| 67 | +
|
| 68 | + Returns: |
| 69 | + The fixed messages. |
| 70 | + """ |
| 71 | + ... |
| 72 | + |
| 73 | + |
| 74 | +@dataclass |
| 75 | +class Fixups: |
| 76 | + available: list[Fixup] = field(default_factory=list) |
| 77 | + active: list[Fixup] = field(default_factory=list) |
| 78 | + |
34 | 79 |
|
35 | 80 | # TODO: We also would like to support N-style |
36 | 81 | # parallel generation eventually -> need to |
@@ -251,6 +296,8 @@ class Generator(BaseModel): |
251 | 296 | _watch_callbacks: list[WatchCallbacks] = [] |
252 | 297 | _wrap: t.Callable[[CallableT], CallableT] | None = None |
253 | 298 |
|
| 299 | + _fixups: Fixups = Fixups() |
| 300 | + |
254 | 301 | def to_identifier(self, params: GenerateParams | None = None) -> str: |
255 | 302 | """ |
256 | 303 | Converts the generator instance back into a rigging identifier string. |
@@ -323,6 +370,38 @@ def wrap(self, func: t.Callable[[CallableT], CallableT] | None) -> Self: |
323 | 370 | self._wrap = func # type: ignore [assignment] |
324 | 371 | return self |
325 | 372 |
|
| 373 | + def _check_fixups(self, error: Exception) -> bool: |
| 374 | + """ |
| 375 | + Check if any fixer can handle this error. |
| 376 | +
|
| 377 | + Args: |
| 378 | + error: The error to be checked. |
| 379 | +
|
| 380 | + Returns: |
| 381 | + Whether a fixer was able to handle the error. |
| 382 | + """ |
| 383 | + for fixup in self._fixups.available[:]: |
| 384 | + if fixup.can_fix(error): |
| 385 | + self._fixups.active.append(fixup) |
| 386 | + self._fixups.available.remove(fixup) |
| 387 | + return True |
| 388 | + return False |
| 389 | + |
| 390 | + async def _apply_fixups(self, messages: t.Sequence[Message]) -> t.Sequence[Message]: |
| 391 | + """ |
| 392 | + Apply all active fixups to the messages. |
| 393 | +
|
| 394 | + Args: |
| 395 | + messages: The messages to be fixed. |
| 396 | +
|
| 397 | + Returns: |
| 398 | + The fixed messages. |
| 399 | + """ |
| 400 | + current_messages = messages |
| 401 | + for fixup in self._fixups.active: |
| 402 | + current_messages = fixup.fix(current_messages) |
| 403 | + return current_messages |
| 404 | + |
326 | 405 | async def generate_messages( |
327 | 406 | self, |
328 | 407 | messages: t.Sequence[t.Sequence[Message]], |
@@ -381,14 +460,16 @@ def chat( |
381 | 460 | self, |
382 | 461 | messages: t.Sequence[MessageDict], |
383 | 462 | params: GenerateParams | None = None, |
384 | | - ) -> ChatPipeline: ... |
| 463 | + ) -> ChatPipeline: |
| 464 | + ... |
385 | 465 |
|
386 | 466 | @t.overload |
387 | 467 | def chat( |
388 | 468 | self, |
389 | 469 | messages: t.Sequence[Message] | MessageDict | Message | str | None = None, |
390 | 470 | params: GenerateParams | None = None, |
391 | | - ) -> ChatPipeline: ... |
| 471 | + ) -> ChatPipeline: |
| 472 | + ... |
392 | 473 |
|
393 | 474 | def chat( |
394 | 475 | self, |
@@ -457,15 +538,17 @@ def chat( |
457 | 538 | generator: Generator, |
458 | 539 | messages: t.Sequence[MessageDict], |
459 | 540 | params: GenerateParams | None = None, |
460 | | -) -> ChatPipeline: ... |
| 541 | +) -> ChatPipeline: |
| 542 | + ... |
461 | 543 |
|
462 | 544 |
|
463 | 545 | @t.overload |
464 | 546 | def chat( |
465 | 547 | generator: Generator, |
466 | 548 | messages: t.Sequence[Message] | MessageDict | Message | str | None = None, |
467 | 549 | params: GenerateParams | None = None, |
468 | | -) -> ChatPipeline: ... |
| 550 | +) -> ChatPipeline: |
| 551 | + ... |
469 | 552 |
|
470 | 553 |
|
471 | 554 | def chat( |
|
0 commit comments