Skip to content

Commit 4622eb5

Browse files
committed
tests
1 parent 0b1dea3 commit 4622eb5

File tree

3 files changed

+661
-42
lines changed

3 files changed

+661
-42
lines changed

pydantic_ai_slim/pydantic_ai/ui/event_stream.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ async def handle_event(self, event: SourceEvent) -> AsyncIterator[EventT]:
232232
yield e
233233
case FunctionToolCallEvent():
234234
async for e in self.handle_function_tool_call(event):
235-
yield e # TODO (DouweM): coverage
235+
yield e
236236
case FunctionToolResultEvent():
237237
async for e in self.handle_function_tool_result(event):
238238
yield e
@@ -266,7 +266,7 @@ async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT
266266
case BuiltinToolReturnPart():
267267
async for e in self.handle_builtin_tool_return(part):
268268
yield e
269-
case FilePart(): # TODO (DouweM): coverage
269+
case FilePart(): # pragma: no branch
270270
async for e in self.handle_file(part):
271271
yield e
272272

@@ -365,7 +365,7 @@ async def handle_thinking_delta(self, delta: ThinkingPartDelta) -> AsyncIterator
365365
Yields:
366366
Protocol-specific events.
367367
"""
368-
return # TODO (DouweM): coverage
368+
return # pragma: no cover
369369
yield # Make this an async generator
370370

371371
async def handle_thinking_end(

tests/test_ui.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,11 @@ async def stream_function(
223223
yield 'some text'
224224
yield {5: DeltaThinkingPart(content='More thinking')}
225225

226-
agent = Agent(model=FunctionModel(stream_function=stream_function), deps_type=UIDeps)
226+
agent = Agent(model=FunctionModel(stream_function=stream_function))
227227

228228
request = UIRequest(messages=[ModelRequest.user_text_prompt('Tell me about Hello World')])
229-
deps = UIDeps(state=UIState())
230-
231229
adapter = UIAdapter(agent, request)
232-
events = [event async for event in adapter.run_stream(deps=deps)]
230+
events = [event async for event in adapter.run_stream()]
233231

234232
assert events == snapshot(
235233
[
@@ -290,13 +288,11 @@ async def stream_function(
290288
}
291289
yield 'A "Hello, World!" program is usually a simple computer program that emits (or displays) to the screen (often the console) a message similar to "Hello, World!". '
292290

293-
agent = Agent(model=FunctionModel(stream_function=stream_function), deps_type=UIDeps)
291+
agent = Agent(model=FunctionModel(stream_function=stream_function))
294292

295293
request = UIRequest(messages=[ModelRequest.user_text_prompt('Tell me about Hello World')])
296-
deps = UIDeps(state=UIState())
297-
298294
adapter = UIAdapter(agent, request)
299-
events = [event async for event in adapter.run_stream(deps=deps)]
295+
events = [event async for event in adapter.run_stream()]
300296

301297
assert events == snapshot(
302298
[
@@ -337,7 +333,7 @@ async def stream_function(
337333
else:
338334
yield 'A "Hello, World!" program is usually a simple computer program that emits (or displays) to the screen (often the console) a message similar to "Hello, World!". '
339335

340-
agent = Agent(model=FunctionModel(stream_function=stream_function), deps_type=UIDeps)
336+
agent = Agent(model=FunctionModel(stream_function=stream_function))
341337

342338
@agent.tool_plain
343339
async def web_search(query: str) -> dict[str, list[dict[str, str]]]:
@@ -351,10 +347,8 @@ async def web_search(query: str) -> dict[str, list[dict[str, str]]]:
351347
}
352348

353349
request = UIRequest(messages=[ModelRequest.user_text_prompt('Tell me about Hello World')])
354-
deps = UIDeps(state=UIState())
355-
356350
adapter = UIAdapter(agent, request)
357-
events = [event async for event in adapter.run_stream(deps=deps)]
351+
events = [event async for event in adapter.run_stream()]
358352

359353
assert events == snapshot(
360354
[
@@ -379,15 +373,13 @@ async def web_search(query: str) -> dict[str, list[dict[str, str]]]:
379373
)
380374

381375

382-
async def test_run_stream_file():
383-
agent = Agent(model=TestModel(), deps_type=UIDeps)
384-
376+
async def test_event_stream_file():
385377
async def event_generator():
386378
yield PartStartEvent(index=0, part=FilePart(content=BinaryImage(data=b'fake', media_type='image/png')))
387379

388-
request = UIRequest(messages=[ModelRequest.user_text_prompt('Generate an image')])
389-
adapter = UIAdapter(agent, request)
390-
events = [event async for event in adapter.process_stream(event_generator())]
380+
request = UIRequest(messages=[ModelRequest.user_text_prompt('Hello')])
381+
event_stream = UIEventStream(request=request)
382+
events = [event async for event in event_stream.handle_stream(event_generator())]
391383

392384
assert events == snapshot(
393385
[
@@ -458,13 +450,11 @@ def web_search(query: str) -> dict[str, list[dict[str, str]]]:
458450
]
459451
}
460452

461-
agent = Agent(model=FunctionModel(stream_function=stream_function), deps_type=UIDeps, output_type=web_search)
453+
agent = Agent(model=FunctionModel(stream_function=stream_function), output_type=web_search)
462454

463455
request = UIRequest(messages=[ModelRequest.user_text_prompt('Tell me about Hello World')])
464-
deps = UIDeps(state=UIState())
465-
466456
adapter = UIAdapter(agent, request)
467-
events = [event async for event in adapter.run_stream(deps=deps)]
457+
events = [event async for event in adapter.run_stream()]
468458

469459
assert events == snapshot(
470460
[
@@ -494,13 +484,11 @@ async def stream_function(
494484
)
495485
}
496486

497-
agent = Agent(model=FunctionModel(stream_function=stream_function), deps_type=UIDeps)
487+
agent = Agent(model=FunctionModel(stream_function=stream_function))
498488

499489
request = UIRequest(messages=[ModelRequest.user_text_prompt('Tell me about Hello World')])
500-
deps = UIDeps(state=UIState())
501-
502490
adapter = UIAdapter(agent, request)
503-
events = [event async for event in adapter.run_stream(deps=deps)]
491+
events = [event async for event in adapter.run_stream()]
504492

505493
assert events == snapshot(
506494
[
@@ -524,17 +512,15 @@ async def stream_function(
524512

525513

526514
async def test_run_stream_request_error():
527-
agent = Agent(model=TestModel(), deps_type=UIDeps)
515+
agent = Agent(model=TestModel())
528516

529517
@agent.tool_plain
530518
async def tool(query: str) -> str:
531519
raise ValueError('Unknown tool')
532520

533521
request = UIRequest(messages=[ModelRequest.user_text_prompt('Hello')])
534-
deps = UIDeps(state=UIState())
535-
536522
adapter = UIAdapter(agent, request)
537-
events = [event async for event in adapter.run_stream(deps=deps)]
523+
events = [event async for event in adapter.run_stream()]
538524

539525
assert events == snapshot(
540526
[
@@ -553,16 +539,15 @@ async def tool(query: str) -> str:
553539

554540

555541
async def test_run_stream_on_complete_error():
556-
agent = Agent(model=TestModel(), deps_type=UIDeps)
542+
agent = Agent(model=TestModel())
557543

558544
request = UIRequest(messages=[ModelRequest.user_text_prompt('Hello')])
559-
deps = UIDeps(state=UIState())
560545

561546
def raise_error(run_result: AgentRunResult[Any]) -> None:
562547
raise ValueError('Faulty on_complete')
563548

564549
adapter = UIAdapter(agent, request)
565-
events = [event async for event in adapter.run_stream(deps=deps, on_complete=raise_error)]
550+
events = [event async for event in adapter.run_stream(on_complete=raise_error)]
566551

567552
assert events == snapshot(
568553
[
@@ -583,16 +568,15 @@ def raise_error(run_result: AgentRunResult[Any]) -> None:
583568

584569

585570
async def test_run_stream_on_complete():
586-
agent = Agent(model=TestModel(), deps_type=UIDeps)
571+
agent = Agent(model=TestModel())
587572

588573
request = UIRequest(messages=[ModelRequest.user_text_prompt('Hello')])
589-
deps = UIDeps(state=UIState())
590574

591575
async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[str]:
592576
yield '<custom>'
593577

594578
adapter = UIAdapter(agent, request)
595-
events = [event async for event in adapter.run_stream(deps=deps, on_complete=on_complete)]
579+
events = [event async for event in adapter.run_stream(on_complete=on_complete)]
596580

597581
assert events == snapshot(
598582
[
@@ -616,7 +600,6 @@ async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[str]:
616600
@pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed')
617601
async def test_adapter_dispatch_request():
618602
agent = Agent(model=TestModel())
619-
620603
request = UIRequest(messages=[ModelRequest.user_text_prompt('Hello')])
621604

622605
async def receive() -> dict[str, Any]:

0 commit comments

Comments
 (0)