@@ -48,11 +48,18 @@ def groq(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
48
48
return GroqModel ('llama-3.1-70b-versatile' , http_client = http_client )
49
49
50
50
51
+ def ollama (http_client : httpx .AsyncClient , _tmp_path : Path ) -> Model :
52
+ from pydantic_ai .models .ollama import OllamaModel
53
+
54
+ return OllamaModel ('qwen:0.5b' , http_client = http_client )
55
+
56
+
51
57
params = [
52
58
pytest .param (openai , id = 'openai' ),
53
59
pytest .param (gemini , id = 'gemini' ),
54
60
pytest .param (vertexai , id = 'vertexai' ),
55
61
pytest .param (groq , id = 'groq' ),
62
+ pytest .param (ollama , id = 'ollama' ),
56
63
]
57
64
GetModel = Callable [[httpx .AsyncClient , Path ], Model ]
58
65
@@ -83,14 +90,18 @@ async def test_stream(http_client: httpx.AsyncClient, tmp_path: Path, get_model:
83
90
assert 'paris' in data .lower ()
84
91
print ('Stream cost:' , result .cost ())
85
92
cost = result .cost ()
86
- assert cost .total_tokens is not None and cost .total_tokens > 0
93
+ if get_model .__name__ != 'ollama' :
94
+ assert cost .total_tokens is not None and cost .total_tokens > 0
87
95
88
96
89
97
class MyModel (BaseModel ):
90
98
city : str
91
99
92
100
93
- @pytest .mark .parametrize ('get_model' , params )
101
+ structured_params = [p for p in params if p .id != 'ollama' ]
102
+
103
+
104
+ @pytest .mark .parametrize ('get_model' , structured_params )
94
105
async def test_structured (http_client : httpx .AsyncClient , tmp_path : Path , get_model : GetModel ):
95
106
agent = Agent (get_model (http_client , tmp_path ), result_type = MyModel )
96
107
result = await agent .run ('What is the capital of the UK?' )
0 commit comments