Skip to content

Commit a3159b5

Browse files
committed
replace line by linee asserts with snapshots
1 parent 6dac474 commit a3159b5

File tree

1 file changed

+18
-66
lines changed

1 file changed

+18
-66
lines changed

tests/test_parts_manager.py

Lines changed: 18 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
TextPart,
1414
TextPartDelta,
1515
ThinkingPart,
16+
ThinkingPartDelta,
1617
ToolCallPart,
1718
ToolCallPartDelta,
1819
UnexpectedModelBehavior,
@@ -28,65 +29,35 @@ def test_handle_text_deltas(vendor_part_id: str | None):
2829
assert manager.get_parts() == []
2930

3031
events = list(manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello '))
31-
assert len(events) == 1
32-
event = events[0]
33-
assert event == snapshot(
34-
PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start')
35-
)
32+
assert events == snapshot([PartStartEvent(index=0, part=TextPart(content='hello '))])
3633
assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')])
3734

3835
events = list(manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world'))
39-
assert len(events) == 1, 'Test returned more than one event.'
40-
event = events[0]
41-
assert event == snapshot(
42-
PartDeltaEvent(
43-
index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta'
44-
)
45-
)
36+
assert events == snapshot([PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))])
4637
assert manager.get_parts() == snapshot([TextPart(content='hello world', part_kind='text')])
4738

4839

4940
def test_handle_dovetailed_text_deltas():
5041
manager = ModelResponsePartsManager()
5142

5243
events = list(manager.handle_text_delta(vendor_part_id='first', content='hello '))
53-
assert len(events) == 1, 'Test returned more than one event.'
54-
event = events[0]
55-
assert event == snapshot(
56-
PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start')
57-
)
44+
assert events == snapshot([PartStartEvent(index=0, part=TextPart(content='hello '))])
5845
assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')])
5946

6047
events = list(manager.handle_text_delta(vendor_part_id='second', content='goodbye '))
61-
assert len(events) == 1, 'Test returned more than one event.'
62-
event = events[0]
63-
assert event == snapshot(
64-
PartStartEvent(index=1, part=TextPart(content='goodbye ', part_kind='text'), event_kind='part_start')
65-
)
48+
assert events == snapshot([PartStartEvent(index=1, part=TextPart(content='goodbye '))])
6649
assert manager.get_parts() == snapshot(
6750
[TextPart(content='hello ', part_kind='text'), TextPart(content='goodbye ', part_kind='text')]
6851
)
6952

7053
events = list(manager.handle_text_delta(vendor_part_id='first', content='world'))
71-
assert len(events) == 1, 'Test returned more than one event.'
72-
event = events[0]
73-
assert event == snapshot(
74-
PartDeltaEvent(
75-
index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta'
76-
)
77-
)
54+
assert events == snapshot([PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))])
7855
assert manager.get_parts() == snapshot(
7956
[TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye ', part_kind='text')]
8057
)
8158

8259
events = list(manager.handle_text_delta(vendor_part_id='second', content='Samuel'))
83-
assert len(events) == 1, 'Test returned more than one event.'
84-
event = events[0]
85-
assert event == snapshot(
86-
PartDeltaEvent(
87-
index=1, delta=TextPartDelta(content_delta='Samuel', part_delta_kind='text'), event_kind='part_delta'
88-
)
89-
)
60+
assert events == snapshot([PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='Samuel'))])
9061
assert manager.get_parts() == snapshot(
9162
[TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye Samuel', part_kind='text')]
9263
)
@@ -307,11 +278,7 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non
307278
manager = ModelResponsePartsManager()
308279

309280
events = list(manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello '))
310-
assert len(events) == 1, 'Test returned more than one event.'
311-
event = events[0]
312-
assert event == snapshot(
313-
PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start')
314-
)
281+
assert events == snapshot([PartStartEvent(index=0, part=TextPart(content='hello '))])
315282
assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')])
316283

317284
event = manager.handle_tool_call_delta(
@@ -326,16 +293,8 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non
326293
)
327294

328295
events = list(manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='world'))
329-
assert len(events) == 1, 'Test returned more than one event.'
330-
event = events[0]
331296
if text_vendor_part_id is None:
332-
assert event == snapshot(
333-
PartStartEvent(
334-
index=2,
335-
part=TextPart(content='world', part_kind='text'),
336-
event_kind='part_start',
337-
)
338-
)
297+
assert events == snapshot([PartStartEvent(index=2, part=TextPart(content='world'))])
339298
assert manager.get_parts() == snapshot(
340299
[
341300
TextPart(content='hello ', part_kind='text'),
@@ -344,11 +303,7 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non
344303
]
345304
)
346305
else:
347-
assert event == snapshot(
348-
PartDeltaEvent(
349-
index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta'
350-
)
351-
)
306+
assert events == snapshot([PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))])
352307
assert manager.get_parts() == snapshot(
353308
[
354309
TextPart(content='hello world', part_kind='text'),
@@ -465,14 +420,12 @@ def test_handle_thinking_delta_no_vendor_id_with_existing_thinking_part():
465420
manager = ModelResponsePartsManager()
466421

467422
# Add a thinking part first
468-
event = next(manager.handle_thinking_delta(vendor_part_id='first', content='initial thought', signature=None))
469-
assert isinstance(event, PartStartEvent)
470-
assert event.index == 0
423+
events = list(manager.handle_thinking_delta(vendor_part_id='first', content='initial thought', signature=None))
424+
assert events == snapshot([PartStartEvent(index=0, part=ThinkingPart(content='initial thought'))])
471425

472426
# Now add another thinking delta with no vendor_part_id - should update the latest thinking part
473-
event = next(manager.handle_thinking_delta(vendor_part_id=None, content=' more', signature=None))
474-
assert isinstance(event, PartDeltaEvent)
475-
assert event.index == 0
427+
events = list(manager.handle_thinking_delta(vendor_part_id=None, content=' more', signature=None))
428+
assert events == snapshot([PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' more'))])
476429

477430
parts = manager.get_parts()
478431
assert parts == snapshot([ThinkingPart(content='initial thought more')])
@@ -494,9 +447,8 @@ def test_handle_thinking_delta_wrong_part_type():
494447
def test_handle_thinking_delta_new_part_with_vendor_id():
495448
manager = ModelResponsePartsManager()
496449

497-
event = next(manager.handle_thinking_delta(vendor_part_id='thinking', content='new thought', signature=None))
498-
assert isinstance(event, PartStartEvent)
499-
assert event.index == 0
450+
events = list(manager.handle_thinking_delta(vendor_part_id='thinking', content='new thought', signature=None))
451+
assert events == snapshot([PartStartEvent(index=0, part=ThinkingPart(content='new thought'))])
500452

501453
parts = manager.get_parts()
502454
assert parts == snapshot([ThinkingPart(content='new thought')])
@@ -563,9 +515,9 @@ def test_handle_thinking_delta_when_latest_is_not_thinking():
563515

564516
# Call handle_thinking_delta with vendor_part_id=None
565517
# Should create NEW ThinkingPart instead of trying to update TextPart
566-
event = next(manager.handle_thinking_delta(vendor_part_id=None, content='thinking'))
518+
events = list(manager.handle_thinking_delta(vendor_part_id=None, content='thinking'))
567519

568-
assert event == snapshot(PartStartEvent(index=1, part=ThinkingPart(content='thinking')))
520+
assert events == snapshot([PartStartEvent(index=1, part=ThinkingPart(content='thinking'))])
569521
assert manager.get_parts() == snapshot([TextPart(content='text'), ThinkingPart(content='thinking')])
570522

571523

0 commit comments

Comments
 (0)