Commit 16a80fd

fix: Groq Content Parts, Fixups Refactor, and Gemini Caching (#103)
* Refactored fixups; corrected behavior for Groq content parts that are not pure text.
* Made compatibility flags a set.
1 parent 4befc52 commit 16a80fd

3 files changed: +114 −61

rigging/generator/base.py

Lines changed: 58 additions & 40 deletions
```diff
@@ -1,7 +1,7 @@
 import abc
+import functools
 import inspect
 import typing as t
-from dataclasses import dataclass, field
 from functools import lru_cache
 
 from loguru import logger
@@ -42,7 +42,7 @@ class Fixup(abc.ABC):
     """
 
     @abc.abstractmethod
-    def can_fix(self, exception: Exception) -> bool:
+    def can_fix(self, exception: Exception) -> bool | t.Literal["once"]:
         """
         Check if the fixup can resolve the given exception if made active.
 
@@ -68,10 +68,62 @@ def fix(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
         ...
 
 
-@dataclass
-class Fixups:
-    available: list[Fixup] = field(default_factory=list)
-    active: list[Fixup] = field(default_factory=list)
+FixupCompatibleFunc = t.Callable[
+    t.Concatenate[t.Any, t.Sequence[Message], P],
+    t.Awaitable[R],
+]
+
+
+def with_fixups(
+    *fixups: Fixup,
+) -> t.Callable[[FixupCompatibleFunc[P, R]], FixupCompatibleFunc[P, R]]:
+    """
+    Decorator that adds fixup retry logic with persistent state.
+
+    Args:
+        fixups: Sequence of fixups to try
+    """
+    available_fixups: list[Fixup] = list(fixups)
+    active_fixups: list[Fixup] = []
+    once_fixups: list[Fixup] = []
+
+    def decorator(func: FixupCompatibleFunc[P, R]) -> FixupCompatibleFunc[P, R]:
+        @functools.wraps(func)
+        async def wrapper(
+            self: t.Any,
+            messages: t.Sequence[Message],
+            *args: P.args,
+            **kwargs: P.kwargs,
+        ) -> R:
+            nonlocal available_fixups, active_fixups
+
+            for fixup in [*active_fixups, *once_fixups]:
+                messages = fixup.fix(messages)
+
+            try:
+                result = await func(self, messages, *args, **kwargs)
+                available_fixups = [*available_fixups, *once_fixups]
+                once_fixups.clear()
+            except Exception as e:
+                for fixup in list(available_fixups):
+                    if (can_fix := fixup.can_fix(e)) is False:
+                        continue
+
+                    if can_fix == "once":
+                        once_fixups.append(fixup)
+                    else:
+                        active_fixups.append(fixup)
+                    available_fixups.remove(fixup)
+
+                    return await wrapper(self, messages, *args, **kwargs)
+
+                raise
+
+            return result
+
+        return wrapper  # type: ignore[return-value]
+
+    return decorator
 
 
 # TODO: We also would like to support N-style
@@ -305,8 +357,6 @@ class Generator(BaseModel):
     _watch_callbacks: list["WatchChatCallback | WatchCompletionCallback"] = []
     _wrap: t.Callable[[CallableT], CallableT] | None = None
 
-    _fixups: Fixups = Fixups()
-
     def to_identifier(self, params: GenerateParams | None = None) -> str:
         """
         Converts the generator instance back into a rigging identifier string.
@@ -393,38 +443,6 @@ async def supports_function_calling(self) -> bool | None:
         """
         return None
 
-    def _check_fixups(self, error: Exception) -> bool:
-        """
-        Check if any fixer can handle this error.
-
-        Args:
-            error: The error to be checked.
-
-        Returns:
-            Whether a fixer was able to handle the error.
-        """
-        for fixup in self._fixups.available[:]:
-            if fixup.can_fix(error):
-                self._fixups.active.append(fixup)
-                self._fixups.available.remove(fixup)
-                return True
-        return False
-
-    async def _apply_fixups(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
-        """
-        Apply all active fixups to the messages.
-
-        Args:
-            messages: The messages to be fixed.
-
-        Returns:
-            The fixed messages.
-        """
-        current_messages = messages
-        for fixup in self._fixups.active:
-            current_messages = fixup.fix(current_messages)
-        return current_messages
-
     async def generate_messages(
         self,
         messages: t.Sequence[t.Sequence[Message]],
```
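
For context, here is a minimal, self-contained sketch of the "once" retry semantics the new decorator introduces. The `DisableCacheFixup` and `generate` names are hypothetical stand-ins (plain strings play the role of rigging's `Message` objects), and the manual try/except replays what the `with_fixups` wrapper does internally:

```python
import asyncio
import typing as t


class DisableCacheFixup:
    """Stand-in for a fixup whose can_fix() returns "once": it is applied
    for one retry, then re-armed after the next successful call."""

    def can_fix(self, exception: Exception) -> bool | t.Literal["once"]:
        return "once" if "Cached content is too small." in str(exception) else False

    def fix(self, messages: t.Sequence[str]) -> list[str]:
        return [m.removeprefix("[cached] ") for m in messages]


attempts = 0


async def generate(messages: t.Sequence[str]) -> str:
    """Fake provider call that rejects cached content with a 400-style error."""
    global attempts
    attempts += 1
    if any(m.startswith("[cached] ") for m in messages):
        raise RuntimeError("Cached content is too small.")
    return f"ok (attempt {attempts})"


async def main() -> None:
    fixup = DisableCacheFixup()
    messages = ["[cached] hello"]
    try:
        print(await generate(messages))
    except Exception as e:
        # The wrapper does this matching for us: a "once" fixup is applied
        # for the retry only, while a plain True would keep the fixup
        # active for every subsequent call.
        if fixup.can_fix(e) == "once":
            print(await generate(fixup.fix(messages)))


asyncio.run(main())  # prints "ok (attempt 2)"
```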

rigging/generator/litellm_.py

Lines changed: 40 additions & 20 deletions
```diff
@@ -10,13 +10,13 @@
 
 from rigging.generator.base import (
     Fixup,
-    Fixups,
     GeneratedMessage,
     GeneratedText,
     GenerateParams,
     Generator,
     trace_messages,
     trace_str,
+    with_fixups,
 )
 from rigging.message import ContentAudioInput, ContentImageUrl, ContentText, Message
 from rigging.tool.api import ApiFunctionDefinition, ApiToolDefinition
@@ -74,11 +74,40 @@ class CacheTooSmallFixup(Fixup):
     # are below a certain threshold can result in a 400
     # error from APIs (Vertex/Gemini).
 
+    def can_fix(self, exception: Exception) -> bool | t.Literal["once"]:
+        return "once" if "Cached content is too small." in str(exception) else False
+
+    def fix(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
+        return [message.cache(False) for message in messages]
+
+
+class GroqAssistantContentFixup(Fixup):
+    # Groq can complain if we try to send fully
+    # structured content parts when working with
+    # the assistant role.
+    #
+    # Compatibility flags are a poor workaround for the
+    # fact that we don't have direct control over the
+    # conversion to the OpenAI spec.
+
     def can_fix(self, exception: Exception) -> bool:
-        return "Cached content is too small." in str(exception)
+        return "Groq" in str(exception) and "content' : value must be a string" in str(exception)
+
+    def fix(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
+        updated_messages: list[Message] = []
+        for message in messages:
+            if message.role == "assistant":
+                message = message.clone()  # noqa: PLW2901
+                message._compability_flags.add("content_as_str")  # noqa: SLF001
+            updated_messages.append(message)
+        return updated_messages
 
-    def fix(self, items: t.Sequence[Message]) -> t.Sequence[Message]:
-        return [message.cache(False) for message in items]
+
+g_fixups = [
+    OpenAIToolsWithImageURLsFixup(),
+    CacheTooSmallFixup(),
+    GroqAssistantContentFixup(),
+]
 
 
 class LiteLLMGenerator(Generator):
@@ -123,8 +152,6 @@ class LiteLLMGenerator(Generator):
     _last_request_time: datetime.datetime | None = None
     _supports_function_calling: bool | None = None
 
-    _fixups = Fixups(available=[OpenAIToolsWithImageURLsFixup(), CacheTooSmallFixup()])
-
     @property
     def semaphore(self) -> asyncio.Semaphore:
         if self._semaphore is None:
@@ -299,6 +326,7 @@ def _parse_text_completion_response(
             extra={"response_id": response.id},
         )
 
+    @with_fixups(*g_fixups)
     async def _generate_message(
         self,
         messages: t.Sequence[Message],
@@ -313,20 +341,12 @@ async def _generate_message(
         if self._wrap is not None:
             acompletion = self._wrap(acompletion)
 
-        # Prepare messages for specific providers
-        messages = await self._apply_fixups(messages)
-
-        try:
-            response = await acompletion(
-                model=self.model,
-                messages=[message.to_openai_spec() for message in messages],
-                api_key=self.api_key,
-                **self.params.merge_with(params).to_dict(),
-            )
-        except Exception as e:
-            if self._check_fixups(e):
-                return await self._generate_message(messages, params)
-            raise
+        response = await acompletion(
+            model=self.model,
+            messages=[message.to_openai_spec() for message in messages],
+            api_key=self.api_key,
+            **self.params.merge_with(params).to_dict(),
+        )
 
         self._last_request_time = datetime.datetime.now(tz=datetime.timezone.utc)
         return self._parse_model_response(response)
```
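
A quick illustration of the new Groq fixup in isolation — a hedged sketch, not repository code: it assumes these imports resolve at this commit and that `Message` accepts positional role/content arguments, and the exception text is fabricated to contain the two substrings `can_fix` matches on:

```python
from rigging.generator.litellm_ import GroqAssistantContentFixup
from rigging.message import Message

fixup = GroqAssistantContentFixup()

# Fabricated error text containing both matched substrings.
error = RuntimeError(
    "GroqException - 'messages.1' : for role 'assistant' the "
    "'content' : value must be a string",
)
assert fixup.can_fix(error) is True  # persistent: stays active once triggered

messages = [
    Message("user", "Describe the image."),
    Message("assistant", "It shows a cat."),
]
fixed = fixup.fix(messages)

# Only the assistant message is cloned and flagged; the user message
# passes through untouched.
assert "content_as_str" in fixed[1]._compability_flags
assert not fixed[0]._compability_flags
```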

rigging/message.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -336,6 +336,8 @@ def save(self, path: Path | str) -> None:
 """The types of content that can be included in a message."""
 ContentTypes = (ContentText, ContentImageUrl, ContentAudioInput)
 
+CompatibilityFlag = t.Literal["content_as_str"]
+
 
 class Message(BaseModel):
     """
@@ -364,6 +366,8 @@ class Message(BaseModel):
     tool_call_id: str | None = Field(None)
     """Associated call id if this message is a response to a tool call."""
 
+    _compability_flags: set[CompatibilityFlag] = set()
+
     def __init__(
         self,
         role: Role,
@@ -546,7 +550,7 @@ def to_openai_spec(self) -> dict[str, t.Any]:
                 isinstance(current, dict)
                 and current.get("type") == "text"
                 and next_.get("type") == "text"
-                and not current.get("text", "").endswith("\n")
+                and not str(current.get("text", "")).endswith("\n")
             ):
                 current["text"] += "\n"
 
@@ -556,6 +560,17 @@ def to_openai_spec(self) -> dict[str, t.Any]:
             if isinstance(part, dict) and part.get("type") == "input_audio":
                 part.get("input_audio", {}).pop("transcript", None)
 
+        # If enabled, we need to convert our content to a flat
+        # string for API compatibility. Groq is an example of an API
+        # which will complain for some roles if we send a list of content parts.
+
+        if "content_as_str" in self._compability_flags:
+            obj["content"] = "".join(
+                part["text"]
+                for part in obj["content"]
+                if isinstance(part, dict) and part.get("type") == "text"
+            )
+
         return obj
 
     # TODO: In general the add/remove/sync_part methods are
```
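
A short, hedged sketch of the flag's effect on serialization, assuming `Message` accepts a list of `ContentText` parts and that parts serialize to `{"type": "text", ...}` dicts:

```python
from rigging.message import ContentText, Message

message = Message("assistant", [ContentText(text="Hello"), ContentText(text="world")])
message._compability_flags.add("content_as_str")

spec = message.to_openai_spec()

# With the flag set, "content" collapses from a list of parts to a flat
# string; non-text parts would simply be dropped by the join. The
# newline-separator pass above runs first, so adjacent text parts should
# come out as "Hello\nworld" rather than "Helloworld".
assert isinstance(spec["content"], str)
print(spec["content"])
```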
