5
5
import sys
6
6
import urllib .parse
7
7
from hashlib import sha256
8
- from os import PathLike
9
8
from pathlib import Path
10
9
from typing import (
11
10
Any ,
12
11
Callable ,
13
12
Dict ,
13
+ Generic ,
14
14
List ,
15
15
Literal ,
16
16
Mapping ,
69
69
)
70
70
71
71
T = TypeVar ('T' )
72
+ CT = TypeVar ('CT' , httpx .Client , httpx .AsyncClient )
73
+ AnyCallable = Callable [..., Any ]
72
74
73
75
74
- class BaseClient :
76
+ class BaseClient ( Generic [ CT ]) :
75
77
def __init__ (
76
78
self ,
77
- client ,
79
+ client : Type [ CT ] ,
78
80
host : Optional [str ] = None ,
79
81
* ,
80
82
follow_redirects : bool = True ,
@@ -90,7 +92,7 @@ def __init__(
90
92
`kwargs` are passed to the httpx client.
91
93
"""
92
94
93
- self ._client = client (
95
+ self ._client : CT = client (
94
96
base_url = _parse_host (host or os .getenv ('OLLAMA_HOST' )),
95
97
follow_redirects = follow_redirects ,
96
98
timeout = timeout ,
@@ -111,7 +113,7 @@ def __init__(
111
113
CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
112
114
113
115
114
- class Client (BaseClient ):
116
+ class Client (BaseClient [ httpx . Client ] ):
115
117
def __init__ (self , host : Optional [str ] = None , ** kwargs ) -> None :
116
118
super ().__init__ (httpx .Client , host , ** kwargs )
117
119
@@ -139,19 +141,10 @@ def _request(
139
141
self ,
140
142
cls : Type [T ],
141
143
* args ,
142
- stream : Literal [True ] = True ,
144
+ stream : Literal [True ],
143
145
** kwargs ,
144
146
) -> Iterator [T ]: ...
145
147
146
- @overload
147
- def _request (
148
- self ,
149
- cls : Type [T ],
150
- * args ,
151
- stream : bool = False ,
152
- ** kwargs ,
153
- ) -> Union [T , Iterator [T ]]: ...
154
-
155
148
def _request (
156
149
self ,
157
150
cls : Type [T ],
@@ -189,7 +182,7 @@ def generate(
189
182
system : str = '' ,
190
183
template : str = '' ,
191
184
context : Optional [Sequence [int ]] = None ,
192
- stream : Literal [False ] = False ,
185
+ stream : Literal [False ],
193
186
think : Optional [bool ] = None ,
194
187
raw : bool = False ,
195
188
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -272,7 +265,7 @@ def chat(
272
265
model : str = '' ,
273
266
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
274
267
* ,
275
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
268
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
276
269
stream : Literal [False ] = False ,
277
270
think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
278
271
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -286,8 +279,8 @@ def chat(
286
279
model : str = '' ,
287
280
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
288
281
* ,
289
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
290
- stream : Literal [True ] = True ,
282
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
283
+ stream : Literal [True ],
291
284
think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
292
285
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
293
286
options : Optional [Union [Mapping [str , Any ], Options ]] = None ,
@@ -299,7 +292,7 @@ def chat(
299
292
model : str = '' ,
300
293
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
301
294
* ,
302
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
295
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
303
296
stream : bool = False ,
304
297
think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
305
298
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -414,7 +407,7 @@ def pull(
414
407
model : str ,
415
408
* ,
416
409
insecure : bool = False ,
417
- stream : Literal [True ] = True ,
410
+ stream : Literal [True ],
418
411
) -> Iterator [ProgressResponse ]: ...
419
412
420
413
def pull (
@@ -447,7 +440,7 @@ def push(
447
440
model : str ,
448
441
* ,
449
442
insecure : bool = False ,
450
- stream : Literal [False ] = False ,
443
+ stream : Literal [False ],
451
444
) -> ProgressResponse : ...
452
445
453
446
@overload
@@ -497,7 +490,7 @@ def create(
497
490
parameters : Optional [Union [Mapping [str , Any ], Options ]] = None ,
498
491
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
499
492
* ,
500
- stream : Literal [False ] = False ,
493
+ stream : Literal [False ],
501
494
) -> ProgressResponse : ...
502
495
503
496
@overload
@@ -623,7 +616,7 @@ def ps(self) -> ProcessResponse:
623
616
)
624
617
625
618
626
- class AsyncClient (BaseClient ):
619
+ class AsyncClient (BaseClient [ httpx . AsyncClient ] ):
627
620
def __init__ (self , host : Optional [str ] = None , ** kwargs ) -> None :
628
621
super ().__init__ (httpx .AsyncClient , host , ** kwargs )
629
622
@@ -783,7 +776,7 @@ async def chat(
783
776
model : str = '' ,
784
777
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
785
778
* ,
786
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
779
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
787
780
stream : Literal [False ] = False ,
788
781
think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
789
782
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -797,7 +790,7 @@ async def chat(
797
790
model : str = '' ,
798
791
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
799
792
* ,
800
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
793
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
801
794
stream : Literal [True ] = True ,
802
795
think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
803
796
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -810,7 +803,7 @@ async def chat(
810
803
model : str = '' ,
811
804
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
812
805
* ,
813
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
806
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
814
807
stream : bool = False ,
815
808
think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
816
809
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -1155,21 +1148,11 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
1155
1148
)
1156
1149
1157
1150
1158
- def _copy_tools (tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ) -> Iterator [Tool ]:
1151
+ def _copy_tools (tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ) -> Iterator [Tool ]:
1159
1152
for unprocessed_tool in tools or []:
1160
1153
yield convert_function_to_tool (unprocessed_tool ) if callable (unprocessed_tool ) else Tool .model_validate (unprocessed_tool )
1161
1154
1162
1155
1163
- def _as_path (s : Optional [Union [str , PathLike ]]) -> Union [Path , None ]:
1164
- if isinstance (s , (str , Path )):
1165
- try :
1166
- if (p := Path (s )).exists ():
1167
- return p
1168
- except Exception :
1169
- ...
1170
- return None
1171
-
1172
-
1173
1156
def _parse_host (host : Optional [str ]) -> str :
1174
1157
"""
1175
1158
>>> _parse_host(None)
0 commit comments