Skip to content

Commit 558e9c4

Browse files
committed
Add Fixup mechanics for generators, with a new fixup for tool calls carrying multi-modal content in OpenAI
1 parent 1cd42d3 commit 558e9c4

File tree

5 files changed

+170
-13
lines changed

5 files changed

+170
-13
lines changed

rigging/generator/base.py

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import abc
34
import inspect
45
import typing as t
6+
from dataclasses import dataclass, field
57

68
from loguru import logger
79
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator
@@ -21,16 +23,59 @@
2123

2224
CallableT = t.TypeVar("CallableT", bound=t.Callable[..., t.Any])
2325

26+
T = t.TypeVar("T")
27+
2428
# Global provider map
2529

2630

2731
@t.runtime_checkable
2832
class LazyGenerator(t.Protocol):
29-
def __call__(self) -> type[Generator]: ...
33+
def __call__(self) -> type[Generator]:
34+
...
3035

3136

3237
g_providers: dict[str, type[Generator] | LazyGenerator] = {}
3338

39+
# Fixups
40+
41+
42+
class Fixup(abc.ABC):
43+
"""
44+
Base class for fixups that apply to message sequences to correct errors.
45+
"""
46+
47+
@abc.abstractmethod
48+
def can_fix(self, exception: Exception) -> bool:
49+
"""
50+
Check if the fixup can resolve the given exception if made active.
51+
52+
Args:
53+
exception: The exception to be checked.
54+
55+
Returns:
56+
Whether the fixup can handle the exception.
57+
"""
58+
...
59+
60+
@abc.abstractmethod
61+
def fix(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
62+
"""
63+
Process a sequence of messages to fix them.
64+
65+
Args:
66+
messages: The messages to be fixed.
67+
68+
Returns:
69+
The fixed messages.
70+
"""
71+
...
72+
73+
74+
@dataclass
75+
class Fixups:
76+
available: list[Fixup] = field(default_factory=list)
77+
active: list[Fixup] = field(default_factory=list)
78+
3479

3580
# TODO: We also would like to support N-style
3681
# parallel generation eventually -> need to
@@ -251,6 +296,8 @@ class Generator(BaseModel):
251296
_watch_callbacks: list[WatchCallbacks] = []
252297
_wrap: t.Callable[[CallableT], CallableT] | None = None
253298

299+
_fixups: Fixups = Fixups()
300+
254301
def to_identifier(self, params: GenerateParams | None = None) -> str:
255302
"""
256303
Converts the generator instance back into a rigging identifier string.
@@ -323,6 +370,38 @@ def wrap(self, func: t.Callable[[CallableT], CallableT] | None) -> Self:
323370
self._wrap = func # type: ignore [assignment]
324371
return self
325372

373+
def _check_fixups(self, error: Exception) -> bool:
374+
"""
375+
Check if any available fixup can handle this error.
376+
377+
Args:
378+
error: The error to be checked.
379+
380+
Returns:
381+
Whether a fixup was able to handle the error.
382+
"""
383+
for fixup in self._fixups.available[:]:
384+
if fixup.can_fix(error):
385+
self._fixups.active.append(fixup)
386+
self._fixups.available.remove(fixup)
387+
return True
388+
return False
389+
390+
async def _apply_fixups(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
391+
"""
392+
Apply all active fixups to the messages.
393+
394+
Args:
395+
messages: The messages to be fixed.
396+
397+
Returns:
398+
The fixed messages.
399+
"""
400+
current_messages = messages
401+
for fixup in self._fixups.active:
402+
current_messages = fixup.fix(current_messages)
403+
return current_messages
404+
326405
async def generate_messages(
327406
self,
328407
messages: t.Sequence[t.Sequence[Message]],
@@ -381,14 +460,16 @@ def chat(
381460
self,
382461
messages: t.Sequence[MessageDict],
383462
params: GenerateParams | None = None,
384-
) -> ChatPipeline: ...
463+
) -> ChatPipeline:
464+
...
385465

386466
@t.overload
387467
def chat(
388468
self,
389469
messages: t.Sequence[Message] | MessageDict | Message | str | None = None,
390470
params: GenerateParams | None = None,
391-
) -> ChatPipeline: ...
471+
) -> ChatPipeline:
472+
...
392473

393474
def chat(
394475
self,
@@ -457,15 +538,17 @@ def chat(
457538
generator: Generator,
458539
messages: t.Sequence[MessageDict],
459540
params: GenerateParams | None = None,
460-
) -> ChatPipeline: ...
541+
) -> ChatPipeline:
542+
...
461543

462544

463545
@t.overload
464546
def chat(
465547
generator: Generator,
466548
messages: t.Sequence[Message] | MessageDict | Message | str | None = None,
467549
params: GenerateParams | None = None,
468-
) -> ChatPipeline: ...
550+
) -> ChatPipeline:
551+
...
469552

470553

471554
def chat(

rigging/generator/litellm_.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from loguru import logger
1010

1111
from rigging.generator.base import (
12+
Fixup,
13+
Fixups,
1214
GeneratedMessage,
1315
GeneratedText,
1416
GenerateParams,
@@ -23,6 +25,42 @@
2325
# fix it to prevent confusion
2426
litellm.drop_params = True
2527

28+
# Prevent the small debug statements
29+
# from being printed to the console
30+
litellm.suppress_debug_info = True
31+
32+
33+
class OpenAIToolsWithImageURLsFixup(Fixup):
34+
# As of writing, openai doesn't support multi-part messages
35+
# associated with the `tool` role. This is complicated by
36+
# the fact that we need to resolve the tool call(s) in the
37+
# following messages. To get around this, we'll resolve the tool
38+
# call with empty content, and duplicate the multi-part data
39+
# into a user message immediately following it. We also need
40+
to take care of multiple tool calls next to each other and ensure
41+
# we don't add the user message in between them.
42+
43+
def can_fix(self, exception: Exception) -> bool:
44+
return (
45+
"Image URLs are only allowed for messages with role 'user', but this message with role 'tool' contains an image URL."
46+
in str(exception)
47+
)
48+
49+
def fix(self, items: t.Sequence[Message]) -> t.Sequence[Message]:
50+
updated_messages: list[Message] = []
51+
append_queue: list[Message] = []
52+
for message in items:
53+
if message.role == "tool" and isinstance(message.all_content, list):
54+
updated_messages.append(message.model_copy(deep=True, update={"all_content": "See next message"}))
55+
append_queue.append(message.model_copy(deep=True, update={"role": "user"}))
56+
else:
57+
updated_messages.extend(append_queue)
58+
append_queue = []
59+
updated_messages.append(message)
60+
61+
updated_messages.extend(append_queue)
62+
return updated_messages
63+
2664

2765
class LiteLLMGenerator(Generator):
2866
"""
@@ -65,6 +103,8 @@ class LiteLLMGenerator(Generator):
65103
_semaphore: asyncio.Semaphore | None = None
66104
_last_request_time: datetime.datetime | None = None
67105

106+
_fixups = Fixups(available=[OpenAIToolsWithImageURLsFixup()])
107+
68108
@property
69109
def semaphore(self) -> asyncio.Semaphore:
70110
if self._semaphore is None:
@@ -155,12 +195,20 @@ async def _generate_message(self, messages: t.Sequence[Message], params: Generat
155195
if self._wrap is not None:
156196
acompletion = self._wrap(acompletion)
157197

158-
response = await acompletion(
159-
model=self.model,
160-
messages=[message.to_openai_spec() for message in messages],
161-
api_key=self.api_key,
162-
**self.params.merge_with(params).to_dict(),
163-
)
198+
# Prepare messages for specific providers
199+
messages = await self._apply_fixups(messages)
200+
201+
try:
202+
response = await acompletion(
203+
model=self.model,
204+
messages=[message.to_openai_spec() for message in messages],
205+
api_key=self.api_key,
206+
**self.params.merge_with(params).to_dict(),
207+
)
208+
except Exception as e:
209+
if self._check_fixups(e):
210+
return await self._generate_message(messages, params)
211+
raise
164212

165213
self._last_request_time = datetime.datetime.now()
166214
return self._parse_model_response(response)

rigging/message.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from rigging.model import Model, ModelT # noqa: TCH001
3131
from rigging.parsing import try_parse_many
3232
from rigging.tool.api import ToolCall
33+
from rigging.util import truncate_string
3334

3435
Role = t.Literal["system", "user", "assistant", "tool"]
3536
"""The role of a message. Can be 'system', 'user', 'assistant', or 'tool'."""
@@ -84,6 +85,9 @@ class ContentText(BaseModel):
8485
text: str
8586
"""The text content."""
8687

88+
def __str__(self) -> str:
89+
return self.text
90+
8791

8892
class ContentImageUrl(BaseModel):
8993
"""An image URL content part of a message."""
@@ -116,6 +120,9 @@ def from_file(cls, file: Path | str, *, mimetype: str | None = None) -> ContentI
116120

117121
return cls(image_url=cls.ImageUrl(url=url))
118122

123+
def __str__(self) -> str:
124+
return f"<ContentImageUrl '{truncate_string(self.image_url.url, 50)}'>"
125+
119126

120127
Content = t.Union[ContentText, ContentImageUrl]
121128
"""The types of content that can be included in a message."""
@@ -181,9 +188,15 @@ def __init__(
181188
)
182189

183190
def __str__(self) -> str:
184-
formatted = f"[{self.role}]: {self.content}"
191+
formatted = f"[{self.role}]:"
192+
if isinstance(self.all_content, list):
193+
formatted += "\n |- " + "\n |- ".join(str(content) for content in self.all_content)
194+
else:
195+
formatted += f" {self.content}"
196+
185197
for tool_call in self.tool_calls or []:
186198
formatted += f"\n |- {tool_call}"
199+
187200
return formatted
188201

189202
def __len__(self) -> int:

rigging/tool/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class ToolCall(BaseModel):
8080
function: FunctionCall
8181

8282
def __str__(self) -> str:
83-
return f"{self.function.name}({self.function.arguments})"
83+
return f"<ToolCall {self.function.name}({self.function.arguments})>"
8484

8585

8686
class ApiTool:

rigging/util.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,16 @@ def get_qualified_name(obj: t.Callable[..., t.Any]) -> str:
148148

149149
# Fallback
150150
return obj.__class__.__qualname__
151+
152+
153+
# Formatting
154+
155+
156+
def truncate_string(content: str, max_length: int, *, sep: str = "...") -> str:
157+
"""Return a string at most max_length characters long."""
158+
if len(content) <= max_length:
159+
return content
160+
161+
remaining = max_length - len(sep)
162+
middle = remaining // 2
163+
return content[:middle] + sep + content[-middle:]

0 commit comments

Comments
 (0)