Skip to content

Commit 68576b1

Browse files
committed
Fix message merging behavior for tool calls. Make merge behavior in ChatPipeline.add() more explicit. Add a sanity check for provider_specific_fields.
1 parent 7d910e2 commit 68576b1

File tree

3 files changed

+137
-19
lines changed

3 files changed

+137
-19
lines changed

rigging/chat.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -586,28 +586,39 @@ async def log(chats: list[Chat]) -> None:
586586
return self
587587

588588
def add(
589-
self, messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str
589+
self,
590+
messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str,
591+
*,
592+
merge_strategy: t.Literal["only-user-role", "all", "none"] = "only-user-role",
590593
) -> ChatPipeline:
591594
"""
592595
Appends new message(s) to the internal chat before generation.
593596
594597
Note:
595-
If the last message in the chat is the same role as the first new message,
596-
the content will be appended. instead of a new message being created.
598+
`merge_strategy` configures behavior when the last message in the chat
599+
is the same role as the first incoming message. This is useful for appending content
600+
automatically to avoid duplicate messages of the same role. For backwards compatibility,
601+
the default behavior is currently set to `only-user-role`. It can be set to `none` to disable
602+
any merging behavior, which may become the default in the future.
597603
598604
Args:
599605
messages: The messages to be added to the chat. It can be a single message or a sequence of messages.
606+
merge_strategy: The strategy to use for merging message content when the roles match.
607+
- "only-user-role": Only merge content of the last existing message and the first incoming message if both have the "user" role.
608+
- "all": Merge content of the last existing message and the first incoming message if their roles match.
609+
- "none": Keep messages independent and do not merge any content.
600610
601611
Returns:
602612
The updated ChatPipeline object.
603613
"""
604614
message_list = Message.fit_as_list(messages)
605-
# If the last message is the same role as the first new message, append to it
606-
if self.chat.all and self.chat.all[-1].role == message_list[0].role:
607-
self.chat.all[-1].content += "\n" + message_list[0].content
608-
message_list = message_list[1:]
609-
else:
610-
self.chat.generated += message_list
615+
616+
if merge_strategy != "none" and self.chat.all and self.chat.all[-1].role == message_list[0].role:
617+
if merge_strategy == "all" or (merge_strategy == "only-user-role" and self.chat.all[-1].role == "user"):
618+
self.chat.all[-1].content += "\n" + message_list[0].content
619+
message_list = message_list[1:]
620+
621+
self.chat.generated += message_list
611622
return self
612623

613624
def fork(

rigging/generator/litellm_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _parse_model_response(self, response: litellm.types.utils.ModelResponse) ->
121121
extra = {"response_id": response.id}
122122
if hasattr(response, "provider"):
123123
extra["provider"] = response.provider
124-
if choice.message.provider_specific_fields is not None:
124+
if hasattr(choice.message, "provider_specific_fields") and choice.message.provider_specific_fields is not None:
125125
extra.update(choice.message.provider_specific_fields)
126126

127127
return GeneratedMessage(

tests/test_chat.py

Lines changed: 116 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def test_chat_restart() -> None:
177177
[
178178
Message("user", "Other Stuff"),
179179
],
180-
generator=get_generator("gpt-3.5"),
180+
generator=get_generator("base"),
181181
)
182182

183183
assert len(chat.restart()) == 2
@@ -196,7 +196,7 @@ def test_chat_continue() -> None:
196196
Message("user", "Hello"),
197197
Message("assistant", "Hi there!"),
198198
],
199-
generator=get_generator("gpt-3.5"),
199+
generator=get_generator("base"),
200200
)
201201

202202
continued = chat.continue_([Message("user", "How are you?")]).chat
@@ -213,7 +213,7 @@ def test_chat_to_message_dicts() -> None:
213213
Message("user", "Hello"),
214214
Message("assistant", "Hi there!"),
215215
],
216-
generator=get_generator("gpt-3.5"),
216+
generator=get_generator("base"),
217217
)
218218

219219
assert len(chat.message_dicts) == 2
@@ -227,7 +227,7 @@ def test_chat_to_conversation() -> None:
227227
Message("user", "Hello"),
228228
Message("assistant", "Hi there!"),
229229
],
230-
generator=get_generator("gpt-3.5"),
230+
generator=get_generator("base"),
231231
)
232232

233233
assert "[user]: Hello" in chat.conversation
@@ -254,7 +254,7 @@ def test_chat_properties() -> None:
254254

255255

256256
def test_chat_pipeline_continue() -> None:
257-
pipeline = ChatPipeline(get_generator("gpt-3.5"), [])
257+
pipeline = ChatPipeline(get_generator("base"), [])
258258
continued = pipeline.fork([Message("user", "Hello")])
259259

260260
assert continued != pipeline
@@ -263,7 +263,7 @@ def test_chat_pipeline_continue() -> None:
263263

264264

265265
def test_chat_pipeline_add() -> None:
266-
pipeline = ChatPipeline(get_generator("gpt-3.5"), [Message("user", "Hello")])
266+
pipeline = ChatPipeline(get_generator("base"), [Message("user", "Hello")])
267267
added = pipeline.add(Message("user", "There"))
268268

269269
assert added == pipeline
@@ -282,7 +282,7 @@ def test_chat_continue_maintains_parsed_models() -> None:
282282
Message("user", "<person name='John'>30</person>"),
283283
Message("assistant", "<address><street>123 Main St</street><city>Anytown</city></address>"),
284284
],
285-
generator=get_generator("gpt-3.5"),
285+
generator=get_generator("base"),
286286
)
287287

288288
chat.all[0].parse(Person)
@@ -296,14 +296,14 @@ def test_chat_continue_maintains_parsed_models() -> None:
296296

297297

298298
def test_chat_pipeline_meta() -> None:
299-
pipeline = ChatPipeline(get_generator("gpt-3.5"), [Message("user", "Hello")])
299+
pipeline = ChatPipeline(get_generator("base"), [Message("user", "Hello")])
300300
with_meta = pipeline.meta(key="value")
301301
assert with_meta == pipeline
302302
assert with_meta.metadata == {"key": "value"}
303303

304304

305305
def test_chat_pipeline_with() -> None:
306-
pipeline = ChatPipeline(get_generator("gpt-3.5"), [Message("user", "Hello")])
306+
pipeline = ChatPipeline(get_generator("base"), [Message("user", "Hello")])
307307
with_pipeline = pipeline.with_(GenerateParams(max_tokens=123))
308308
assert with_pipeline == pipeline
309309
assert with_pipeline.params is not None
@@ -368,3 +368,110 @@ def test_message_dedent() -> None:
368368
assert lines[0] == "Tabbed content"
369369
assert lines[1] == "Line 2"
370370
assert lines[2] == ""
371+
372+
373+
def test_chat_pipeline_add_merge_strategy_default() -> None:
374+
"""Test the default merge strategy (only-user-role) behavior."""
375+
pipeline = ChatPipeline(get_generator("base"), [Message("user", "Hello")])
376+
377+
# Test merging user messages (should merge)
378+
pipeline.add(Message("user", "There"))
379+
assert len(pipeline.chat) == 1
380+
assert pipeline.chat.all[0].content == "Hello\nThere"
381+
382+
# Test adding assistant message after merged user messages
383+
pipeline.add(Message("assistant", "Hi there!"))
384+
assert len(pipeline.chat) == 2
385+
assert pipeline.chat.all[1].content == "Hi there!"
386+
387+
# Test that assistant messages don't merge by default
388+
pipeline.add(Message("assistant", "How are you?"))
389+
assert len(pipeline.chat) == 3
390+
assert pipeline.chat.all[2].content == "How are you?"
391+
392+
393+
def test_chat_pipeline_add_merge_strategy_none() -> None:
394+
"""Test that merge_strategy='none' prevents any message merging."""
395+
pipeline = ChatPipeline(get_generator("base"), [Message("user", "Hello")])
396+
397+
# Test that user messages don't merge
398+
pipeline.add(Message("user", "There"), merge_strategy="none")
399+
assert len(pipeline.chat) == 2
400+
assert pipeline.chat.all[0].content == "Hello"
401+
assert pipeline.chat.all[1].content == "There"
402+
403+
# Test that assistant messages also don't merge
404+
pipeline.add([Message("assistant", "Hi!"), Message("assistant", "How are you?")], merge_strategy="none")
405+
assert len(pipeline.chat) == 4
406+
assert pipeline.chat.all[2].content == "Hi!"
407+
assert pipeline.chat.all[3].content == "How are you?"
408+
409+
410+
def test_chat_pipeline_add_merge_strategy_all() -> None:
411+
"""Test that merge_strategy='all' merges consecutive messages of the same role."""
412+
pipeline = ChatPipeline(get_generator("base"), [Message("user", "Hello")])
413+
414+
# Test merging user messages
415+
pipeline.add(Message("user", "There"), merge_strategy="all")
416+
assert len(pipeline.chat) == 1
417+
assert pipeline.chat.all[0].content == "Hello\nThere"
418+
419+
# Test merging assistant messages
420+
pipeline.add(Message("assistant", "Hi!"))
421+
pipeline.add(Message("assistant", "How are you?"), merge_strategy="all")
422+
assert len(pipeline.chat) == 2
423+
assert pipeline.chat.all[1].content == "Hi!\nHow are you?"
424+
425+
426+
def test_chat_pipeline_add_merge_strategy_multiple_messages() -> None:
427+
"""Test merge behavior with multiple messages in a single add call."""
428+
pipeline = ChatPipeline(get_generator("base"), [Message("user", "Hello")])
429+
430+
# Test adding multiple messages with user first (should merge first message only)
431+
pipeline.add(
432+
[Message("user", "There"), Message("assistant", "Hi!"), Message("user", "Another")], merge_strategy="all"
433+
)
434+
435+
assert len(pipeline.chat) == 3
436+
assert pipeline.chat.all[0].content == "Hello\nThere"
437+
assert pipeline.chat.all[1].content == "Hi!"
438+
assert pipeline.chat.all[2].content == "Another"
439+
440+
441+
def test_chat_pipeline_add_merge_strategy_empty_chat() -> None:
442+
"""Test merge behavior when starting with an empty chat."""
443+
pipeline = ChatPipeline(get_generator("base"), [])
444+
445+
# Test that first message is added normally regardless of strategy
446+
pipeline.add(Message("user", "Hello"), merge_strategy="all")
447+
assert len(pipeline.chat) == 1
448+
assert pipeline.chat.all[0].content == "Hello"
449+
450+
pipeline.add(Message("user", "There"), merge_strategy="all")
451+
assert len(pipeline.chat) == 1
452+
assert pipeline.chat.all[0].content == "Hello\nThere"
453+
454+
455+
def test_chat_pipeline_add_merge_strategy_string_input() -> None:
456+
"""Test merge behavior with string input."""
457+
pipeline = ChatPipeline(get_generator("base"), [Message("user", "Hello")])
458+
459+
# Test merging string input (converts to user message)
460+
pipeline.add("There", merge_strategy="all")
461+
assert len(pipeline.chat) == 1
462+
assert pipeline.chat.all[0].content == "Hello\nThere"
463+
464+
465+
def test_chat_pipeline_add_merge_strategy_message_dict() -> None:
466+
"""Test merge behavior with MessageDict input."""
467+
pipeline = ChatPipeline(get_generator("base"), [Message("user", "Hello")])
468+
469+
# Test merging MessageDict input
470+
pipeline.add({"role": "user", "content": "There"}, merge_strategy="all")
471+
assert len(pipeline.chat) == 1
472+
assert pipeline.chat.all[0].content == "Hello\nThere"
473+
474+
# Test non-merging with different role
475+
pipeline.add({"role": "assistant", "content": "Hi!"}, merge_strategy="all")
476+
assert len(pipeline.chat) == 2
477+
assert pipeline.chat.all[1].content == "Hi!"

0 commit comments

Comments
 (0)