@@ -177,7 +177,7 @@ def test_chat_restart() -> None:
177177 [
178178 Message ("user" , "Other Stuff" ),
179179 ],
180- generator = get_generator ("gpt-3.5 " ),
180+ generator = get_generator ("base " ),
181181 )
182182
183183 assert len (chat .restart ()) == 2
@@ -196,7 +196,7 @@ def test_chat_continue() -> None:
196196 Message ("user" , "Hello" ),
197197 Message ("assistant" , "Hi there!" ),
198198 ],
199- generator = get_generator ("gpt-3.5 " ),
199+ generator = get_generator ("base " ),
200200 )
201201
202202 continued = chat .continue_ ([Message ("user" , "How are you?" )]).chat
@@ -213,7 +213,7 @@ def test_chat_to_message_dicts() -> None:
213213 Message ("user" , "Hello" ),
214214 Message ("assistant" , "Hi there!" ),
215215 ],
216- generator = get_generator ("gpt-3.5 " ),
216+ generator = get_generator ("base " ),
217217 )
218218
219219 assert len (chat .message_dicts ) == 2
@@ -227,7 +227,7 @@ def test_chat_to_conversation() -> None:
227227 Message ("user" , "Hello" ),
228228 Message ("assistant" , "Hi there!" ),
229229 ],
230- generator = get_generator ("gpt-3.5 " ),
230+ generator = get_generator ("base " ),
231231 )
232232
233233 assert "[user]: Hello" in chat .conversation
@@ -254,7 +254,7 @@ def test_chat_properties() -> None:
254254
255255
256256def test_chat_pipeline_continue () -> None :
257- pipeline = ChatPipeline (get_generator ("gpt-3.5 " ), [])
257+ pipeline = ChatPipeline (get_generator ("base " ), [])
258258 continued = pipeline .fork ([Message ("user" , "Hello" )])
259259
260260 assert continued != pipeline
@@ -263,7 +263,7 @@ def test_chat_pipeline_continue() -> None:
263263
264264
265265def test_chat_pipeline_add () -> None :
266- pipeline = ChatPipeline (get_generator ("gpt-3.5 " ), [Message ("user" , "Hello" )])
266+ pipeline = ChatPipeline (get_generator ("base " ), [Message ("user" , "Hello" )])
267267 added = pipeline .add (Message ("user" , "There" ))
268268
269269 assert added == pipeline
@@ -282,7 +282,7 @@ def test_chat_continue_maintains_parsed_models() -> None:
282282 Message ("user" , "<person name='John'>30</person>" ),
283283 Message ("assistant" , "<address><street>123 Main St</street><city>Anytown</city></address>" ),
284284 ],
285- generator = get_generator ("gpt-3.5 " ),
285+ generator = get_generator ("base " ),
286286 )
287287
288288 chat .all [0 ].parse (Person )
@@ -296,14 +296,14 @@ def test_chat_continue_maintains_parsed_models() -> None:
296296
297297
298298def test_chat_pipeline_meta () -> None :
299- pipeline = ChatPipeline (get_generator ("gpt-3.5 " ), [Message ("user" , "Hello" )])
299+ pipeline = ChatPipeline (get_generator ("base " ), [Message ("user" , "Hello" )])
300300 with_meta = pipeline .meta (key = "value" )
301301 assert with_meta == pipeline
302302 assert with_meta .metadata == {"key" : "value" }
303303
304304
305305def test_chat_pipeline_with () -> None :
306- pipeline = ChatPipeline (get_generator ("gpt-3.5 " ), [Message ("user" , "Hello" )])
306+ pipeline = ChatPipeline (get_generator ("base " ), [Message ("user" , "Hello" )])
307307 with_pipeline = pipeline .with_ (GenerateParams (max_tokens = 123 ))
308308 assert with_pipeline == pipeline
309309 assert with_pipeline .params is not None
@@ -368,3 +368,110 @@ def test_message_dedent() -> None:
368368 assert lines [0 ] == "Tabbed content"
369369 assert lines [1 ] == "Line 2"
370370 assert lines [2 ] == ""
371+
372+
373+ def test_chat_pipeline_add_merge_strategy_default () -> None :
374+ """Test the default merge strategy (only-user-role) behavior."""
375+ pipeline = ChatPipeline (get_generator ("base" ), [Message ("user" , "Hello" )])
376+
377+ # Test merging user messages (should merge)
378+ pipeline .add (Message ("user" , "There" ))
379+ assert len (pipeline .chat ) == 1
380+ assert pipeline .chat .all [0 ].content == "Hello\n There"
381+
382+ # Test adding assistant message after merged user messages
383+ pipeline .add (Message ("assistant" , "Hi there!" ))
384+ assert len (pipeline .chat ) == 2
385+ assert pipeline .chat .all [1 ].content == "Hi there!"
386+
387+ # Test that assistant messages don't merge by default
388+ pipeline .add (Message ("assistant" , "How are you?" ))
389+ assert len (pipeline .chat ) == 3
390+ assert pipeline .chat .all [2 ].content == "How are you?"
391+
392+
393+ def test_chat_pipeline_add_merge_strategy_none () -> None :
394+ """Test that merge_strategy='none' prevents any message merging."""
395+ pipeline = ChatPipeline (get_generator ("base" ), [Message ("user" , "Hello" )])
396+
397+ # Test that user messages don't merge
398+ pipeline .add (Message ("user" , "There" ), merge_strategy = "none" )
399+ assert len (pipeline .chat ) == 2
400+ assert pipeline .chat .all [0 ].content == "Hello"
401+ assert pipeline .chat .all [1 ].content == "There"
402+
403+ # Test that assistant messages also don't merge
404+ pipeline .add ([Message ("assistant" , "Hi!" ), Message ("assistant" , "How are you?" )], merge_strategy = "none" )
405+ assert len (pipeline .chat ) == 4
406+ assert pipeline .chat .all [2 ].content == "Hi!"
407+ assert pipeline .chat .all [3 ].content == "How are you?"
408+
409+
410+ def test_chat_pipeline_add_merge_strategy_all () -> None :
411+ """Test that merge_strategy='all' merges consecutive messages of the same role."""
412+ pipeline = ChatPipeline (get_generator ("base" ), [Message ("user" , "Hello" )])
413+
414+ # Test merging user messages
415+ pipeline .add (Message ("user" , "There" ), merge_strategy = "all" )
416+ assert len (pipeline .chat ) == 1
417+ assert pipeline .chat .all [0 ].content == "Hello\n There"
418+
419+ # Test merging assistant messages
420+ pipeline .add (Message ("assistant" , "Hi!" ))
421+ pipeline .add (Message ("assistant" , "How are you?" ), merge_strategy = "all" )
422+ assert len (pipeline .chat ) == 2
423+ assert pipeline .chat .all [1 ].content == "Hi!\n How are you?"
424+
425+
426+ def test_chat_pipeline_add_merge_strategy_multiple_messages () -> None :
427+ """Test merge behavior with multiple messages in a single add call."""
428+ pipeline = ChatPipeline (get_generator ("base" ), [Message ("user" , "Hello" )])
429+
430+ # Test adding multiple messages with user first (should merge first message only)
431+ pipeline .add (
432+ [Message ("user" , "There" ), Message ("assistant" , "Hi!" ), Message ("user" , "Another" )], merge_strategy = "all"
433+ )
434+
435+ assert len (pipeline .chat ) == 3
436+ assert pipeline .chat .all [0 ].content == "Hello\n There"
437+ assert pipeline .chat .all [1 ].content == "Hi!"
438+ assert pipeline .chat .all [2 ].content == "Another"
439+
440+
441+ def test_chat_pipeline_add_merge_strategy_empty_chat () -> None :
442+ """Test merge behavior when starting with an empty chat."""
443+ pipeline = ChatPipeline (get_generator ("base" ), [])
444+
445+ # Test that first message is added normally regardless of strategy
446+ pipeline .add (Message ("user" , "Hello" ), merge_strategy = "all" )
447+ assert len (pipeline .chat ) == 1
448+ assert pipeline .chat .all [0 ].content == "Hello"
449+
450+ pipeline .add (Message ("user" , "There" ), merge_strategy = "all" )
451+ assert len (pipeline .chat ) == 1
452+ assert pipeline .chat .all [0 ].content == "Hello\n There"
453+
454+
455+ def test_chat_pipeline_add_merge_strategy_string_input () -> None :
456+ """Test merge behavior with string input."""
457+ pipeline = ChatPipeline (get_generator ("base" ), [Message ("user" , "Hello" )])
458+
459+ # Test merging string input (converts to user message)
460+ pipeline .add ("There" , merge_strategy = "all" )
461+ assert len (pipeline .chat ) == 1
462+ assert pipeline .chat .all [0 ].content == "Hello\n There"
463+
464+
465+ def test_chat_pipeline_add_merge_strategy_message_dict () -> None :
466+ """Test merge behavior with MessageDict input."""
467+ pipeline = ChatPipeline (get_generator ("base" ), [Message ("user" , "Hello" )])
468+
469+ # Test merging MessageDict input
470+ pipeline .add ({"role" : "user" , "content" : "There" }, merge_strategy = "all" )
471+ assert len (pipeline .chat ) == 1
472+ assert pipeline .chat .all [0 ].content == "Hello\n There"
473+
474+ # Test non-merging with different role
475+ pipeline .add ({"role" : "assistant" , "content" : "Hi!" }, merge_strategy = "all" )
476+ assert len (pipeline .chat ) == 2
477+ assert pipeline .chat .all [1 ].content == "Hi!"
0 commit comments