5
5
import sys
6
6
import urllib .parse
7
7
from hashlib import sha256
8
- from os import PathLike
8
+
9
9
from pathlib import Path
10
10
from typing import (
11
11
Any ,
12
12
Callable ,
13
13
Dict ,
14
+ Generic ,
14
15
List ,
15
16
Literal ,
16
17
Mapping ,
22
23
overload ,
23
24
)
24
25
26
+ AnyCallable = Callable [..., Any ]
27
+
25
28
import anyio
26
29
from pydantic .json_schema import JsonSchemaValue
27
30
69
72
)
70
73
71
74
T = TypeVar ('T' )
75
+ CT = TypeVar ('CT' , httpx .Client , httpx .AsyncClient )
72
76
73
77
74
- class BaseClient :
78
+ class BaseClient ( Generic [ CT ]) :
75
79
def __init__ (
76
80
self ,
77
- client ,
81
+ client : Type [ CT ] ,
78
82
host : Optional [str ] = None ,
79
83
* ,
80
84
follow_redirects : bool = True ,
@@ -90,7 +94,7 @@ def __init__(
90
94
`kwargs` are passed to the httpx client.
91
95
"""
92
96
93
- self ._client = client (
97
+ self ._client : CT = client (
94
98
base_url = _parse_host (host or os .getenv ('OLLAMA_HOST' )),
95
99
follow_redirects = follow_redirects ,
96
100
timeout = timeout ,
@@ -111,7 +115,7 @@ def __init__(
111
115
CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
112
116
113
117
114
- class Client (BaseClient ):
118
+ class Client (BaseClient [ httpx . Client ] ):
115
119
def __init__ (self , host : Optional [str ] = None , ** kwargs ) -> None :
116
120
super ().__init__ (httpx .Client , host , ** kwargs )
117
121
@@ -139,19 +143,10 @@ def _request(
139
143
self ,
140
144
cls : Type [T ],
141
145
* args ,
142
- stream : Literal [True ] = True ,
146
+ stream : Literal [True ],
143
147
** kwargs ,
144
148
) -> Iterator [T ]: ...
145
149
146
- @overload
147
- def _request (
148
- self ,
149
- cls : Type [T ],
150
- * args ,
151
- stream : bool = False ,
152
- ** kwargs ,
153
- ) -> Union [T , Iterator [T ]]: ...
154
-
155
150
def _request (
156
151
self ,
157
152
cls : Type [T ],
@@ -189,7 +184,7 @@ def generate(
189
184
system : str = '' ,
190
185
template : str = '' ,
191
186
context : Optional [Sequence [int ]] = None ,
192
- stream : Literal [False ] = False ,
187
+ stream : Literal [False ],
193
188
think : Optional [bool ] = None ,
194
189
raw : bool = False ,
195
190
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -272,7 +267,7 @@ def chat(
272
267
model : str = '' ,
273
268
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
274
269
* ,
275
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
270
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
276
271
stream : Literal [False ] = False ,
277
272
think : Optional [bool ] = None ,
278
273
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -286,8 +281,8 @@ def chat(
286
281
model : str = '' ,
287
282
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
288
283
* ,
289
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
290
- stream : Literal [True ] = True ,
284
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
285
+ stream : Literal [True ],
291
286
think : Optional [bool ] = None ,
292
287
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
293
288
options : Optional [Union [Mapping [str , Any ], Options ]] = None ,
@@ -299,7 +294,7 @@ def chat(
299
294
model : str = '' ,
300
295
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
301
296
* ,
302
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
297
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
303
298
stream : bool = False ,
304
299
think : Optional [bool ] = None ,
305
300
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -414,7 +409,7 @@ def pull(
414
409
model : str ,
415
410
* ,
416
411
insecure : bool = False ,
417
- stream : Literal [True ] = True ,
412
+ stream : Literal [True ],
418
413
) -> Iterator [ProgressResponse ]: ...
419
414
420
415
def pull (
@@ -447,7 +442,7 @@ def push(
447
442
model : str ,
448
443
* ,
449
444
insecure : bool = False ,
450
- stream : Literal [False ] = False ,
445
+ stream : Literal [False ],
451
446
) -> ProgressResponse : ...
452
447
453
448
@overload
@@ -497,7 +492,7 @@ def create(
497
492
parameters : Optional [Union [Mapping [str , Any ], Options ]] = None ,
498
493
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
499
494
* ,
500
- stream : Literal [False ] = False ,
495
+ stream : Literal [False ],
501
496
) -> ProgressResponse : ...
502
497
503
498
@overload
@@ -623,7 +618,7 @@ def ps(self) -> ProcessResponse:
623
618
)
624
619
625
620
626
- class AsyncClient (BaseClient ):
621
+ class AsyncClient (BaseClient [ httpx . AsyncClient ] ):
627
622
def __init__ (self , host : Optional [str ] = None , ** kwargs ) -> None :
628
623
super ().__init__ (httpx .AsyncClient , host , ** kwargs )
629
624
@@ -783,7 +778,7 @@ async def chat(
783
778
model : str = '' ,
784
779
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
785
780
* ,
786
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
781
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
787
782
stream : Literal [False ] = False ,
788
783
think : Optional [bool ] = None ,
789
784
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -797,7 +792,7 @@ async def chat(
797
792
model : str = '' ,
798
793
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
799
794
* ,
800
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
795
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
801
796
stream : Literal [True ] = True ,
802
797
think : Optional [bool ] = None ,
803
798
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -810,7 +805,7 @@ async def chat(
810
805
model : str = '' ,
811
806
messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
812
807
* ,
813
- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
808
+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
814
809
stream : bool = False ,
815
810
think : Optional [bool ] = None ,
816
811
format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -1155,21 +1150,11 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
1155
1150
)
1156
1151
1157
1152
1158
- def _copy_tools (tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ) -> Iterator [Tool ]:
1153
+ def _copy_tools (tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ) -> Iterator [Tool ]:
1159
1154
for unprocessed_tool in tools or []:
1160
1155
yield convert_function_to_tool (unprocessed_tool ) if callable (unprocessed_tool ) else Tool .model_validate (unprocessed_tool )
1161
1156
1162
1157
1163
- def _as_path (s : Optional [Union [str , PathLike ]]) -> Union [Path , None ]:
1164
- if isinstance (s , (str , Path )):
1165
- try :
1166
- if (p := Path (s )).exists ():
1167
- return p
1168
- except Exception :
1169
- ...
1170
- return None
1171
-
1172
-
1173
1158
def _parse_host (host : Optional [str ]) -> str :
1174
1159
"""
1175
1160
>>> _parse_host(None)
0 commit comments