Skip to content

Commit ce56f27

Browse files
Add type overloads to methods (#181)
* Add type overloads for chat() method in _client.py * Overloading * Fix Overload Overlap * Fix async chat * Lint * Reverse --------- Co-authored-by: Simon Ottenhaus <[email protected]>
1 parent 982d65f commit ce56f27

File tree

1 file changed

+213
-1
lines changed

1 file changed

+213
-1
lines changed

ollama/_client.py

Lines changed: 213 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from hashlib import sha256
1212
from base64 import b64encode, b64decode
1313

14-
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal
14+
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal, overload
1515

1616
import sys
1717

@@ -97,6 +97,38 @@ def _request_stream(
9797
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
9898
return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json()
9999

100+
@overload
101+
def generate(
102+
self,
103+
model: str = '',
104+
prompt: str = '',
105+
system: str = '',
106+
template: str = '',
107+
context: Optional[Sequence[int]] = None,
108+
stream: Literal[False] = False,
109+
raw: bool = False,
110+
format: Literal['', 'json'] = '',
111+
images: Optional[Sequence[AnyStr]] = None,
112+
options: Optional[Options] = None,
113+
keep_alive: Optional[Union[float, str]] = None,
114+
) -> Mapping[str, Any]: ...
115+
116+
@overload
117+
def generate(
118+
self,
119+
model: str = '',
120+
prompt: str = '',
121+
system: str = '',
122+
template: str = '',
123+
context: Optional[Sequence[int]] = None,
124+
stream: Literal[True] = True,
125+
raw: bool = False,
126+
format: Literal['', 'json'] = '',
127+
images: Optional[Sequence[AnyStr]] = None,
128+
options: Optional[Options] = None,
129+
keep_alive: Optional[Union[float, str]] = None,
130+
) -> Iterator[Mapping[str, Any]]: ...
131+
100132
def generate(
101133
self,
102134
model: str = '',
@@ -143,6 +175,28 @@ def generate(
143175
stream=stream,
144176
)
145177

178+
@overload
179+
def chat(
180+
self,
181+
model: str = '',
182+
messages: Optional[Sequence[Message]] = None,
183+
stream: Literal[False] = False,
184+
format: Literal['', 'json'] = '',
185+
options: Optional[Options] = None,
186+
keep_alive: Optional[Union[float, str]] = None,
187+
) -> Mapping[str, Any]: ...
188+
189+
@overload
190+
def chat(
191+
self,
192+
model: str = '',
193+
messages: Optional[Sequence[Message]] = None,
194+
stream: Literal[True] = True,
195+
format: Literal['', 'json'] = '',
196+
options: Optional[Options] = None,
197+
keep_alive: Optional[Union[float, str]] = None,
198+
) -> Iterator[Mapping[str, Any]]: ...
199+
146200
def chat(
147201
self,
148202
model: str = '',
@@ -209,6 +263,22 @@ def embeddings(
209263
},
210264
).json()
211265

266+
@overload
267+
def pull(
268+
self,
269+
model: str,
270+
insecure: bool = False,
271+
stream: Literal[False] = False,
272+
) -> Mapping[str, Any]: ...
273+
274+
@overload
275+
def pull(
276+
self,
277+
model: str,
278+
insecure: bool = False,
279+
stream: Literal[True] = True,
280+
) -> Iterator[Mapping[str, Any]]: ...
281+
212282
def pull(
213283
self,
214284
model: str,
@@ -231,6 +301,22 @@ def pull(
231301
stream=stream,
232302
)
233303

304+
@overload
305+
def push(
306+
self,
307+
model: str,
308+
insecure: bool = False,
309+
stream: Literal[False] = False,
310+
) -> Mapping[str, Any]: ...
311+
312+
@overload
313+
def push(
314+
self,
315+
model: str,
316+
insecure: bool = False,
317+
stream: Literal[True] = True,
318+
) -> Iterator[Mapping[str, Any]]: ...
319+
234320
def push(
235321
self,
236322
model: str,
@@ -253,6 +339,26 @@ def push(
253339
stream=stream,
254340
)
255341

342+
@overload
343+
def create(
344+
self,
345+
model: str,
346+
path: Optional[Union[str, PathLike]] = None,
347+
modelfile: Optional[str] = None,
348+
quantize: Optional[str] = None,
349+
stream: Literal[False] = False,
350+
) -> Mapping[str, Any]: ...
351+
352+
@overload
353+
def create(
354+
self,
355+
model: str,
356+
path: Optional[Union[str, PathLike]] = None,
357+
modelfile: Optional[str] = None,
358+
quantize: Optional[str] = None,
359+
stream: Literal[True] = True,
360+
) -> Iterator[Mapping[str, Any]]: ...
361+
256362
def create(
257363
self,
258364
model: str,
@@ -386,6 +492,38 @@ async def _request_stream(
386492
response = await self._request(*args, **kwargs)
387493
return response.json()
388494

495+
@overload
496+
async def generate(
497+
self,
498+
model: str = '',
499+
prompt: str = '',
500+
system: str = '',
501+
template: str = '',
502+
context: Optional[Sequence[int]] = None,
503+
stream: Literal[False] = False,
504+
raw: bool = False,
505+
format: Literal['', 'json'] = '',
506+
images: Optional[Sequence[AnyStr]] = None,
507+
options: Optional[Options] = None,
508+
keep_alive: Optional[Union[float, str]] = None,
509+
) -> Mapping[str, Any]: ...
510+
511+
@overload
512+
async def generate(
513+
self,
514+
model: str = '',
515+
prompt: str = '',
516+
system: str = '',
517+
template: str = '',
518+
context: Optional[Sequence[int]] = None,
519+
stream: Literal[True] = True,
520+
raw: bool = False,
521+
format: Literal['', 'json'] = '',
522+
images: Optional[Sequence[AnyStr]] = None,
523+
options: Optional[Options] = None,
524+
keep_alive: Optional[Union[float, str]] = None,
525+
) -> AsyncIterator[Mapping[str, Any]]: ...
526+
389527
async def generate(
390528
self,
391529
model: str = '',
@@ -431,6 +569,28 @@ async def generate(
431569
stream=stream,
432570
)
433571

572+
@overload
573+
async def chat(
574+
self,
575+
model: str = '',
576+
messages: Optional[Sequence[Message]] = None,
577+
stream: Literal[False] = False,
578+
format: Literal['', 'json'] = '',
579+
options: Optional[Options] = None,
580+
keep_alive: Optional[Union[float, str]] = None,
581+
) -> Mapping[str, Any]: ...
582+
583+
@overload
584+
async def chat(
585+
self,
586+
model: str = '',
587+
messages: Optional[Sequence[Message]] = None,
588+
stream: Literal[True] = True,
589+
format: Literal['', 'json'] = '',
590+
options: Optional[Options] = None,
591+
keep_alive: Optional[Union[float, str]] = None,
592+
) -> AsyncIterator[Mapping[str, Any]]: ...
593+
434594
async def chat(
435595
self,
436596
model: str = '',
@@ -498,6 +658,22 @@ async def embeddings(
498658

499659
return response.json()
500660

661+
@overload
662+
async def pull(
663+
self,
664+
model: str,
665+
insecure: bool = False,
666+
stream: Literal[False] = False,
667+
) -> Mapping[str, Any]: ...
668+
669+
@overload
670+
async def pull(
671+
self,
672+
model: str,
673+
insecure: bool = False,
674+
stream: Literal[True] = True,
675+
) -> AsyncIterator[Mapping[str, Any]]: ...
676+
501677
async def pull(
502678
self,
503679
model: str,
@@ -520,6 +696,22 @@ async def pull(
520696
stream=stream,
521697
)
522698

699+
@overload
700+
async def push(
701+
self,
702+
model: str,
703+
insecure: bool = False,
704+
stream: Literal[False] = False,
705+
) -> Mapping[str, Any]: ...
706+
707+
@overload
708+
async def push(
709+
self,
710+
model: str,
711+
insecure: bool = False,
712+
stream: Literal[True] = True,
713+
) -> AsyncIterator[Mapping[str, Any]]: ...
714+
523715
async def push(
524716
self,
525717
model: str,
@@ -542,6 +734,26 @@ async def push(
542734
stream=stream,
543735
)
544736

737+
@overload
738+
async def create(
739+
self,
740+
model: str,
741+
path: Optional[Union[str, PathLike]] = None,
742+
modelfile: Optional[str] = None,
743+
quantize: Optional[str] = None,
744+
stream: Literal[False] = False,
745+
) -> Mapping[str, Any]: ...
746+
747+
@overload
748+
async def create(
749+
self,
750+
model: str,
751+
path: Optional[Union[str, PathLike]] = None,
752+
modelfile: Optional[str] = None,
753+
quantize: Optional[str] = None,
754+
stream: Literal[True] = True,
755+
) -> AsyncIterator[Mapping[str, Any]]: ...
756+
545757
async def create(
546758
self,
547759
model: str,

0 commit comments

Comments
 (0)