Skip to content

Commit 57c026c

Browse files
committed
refactor: Fix type annotation errors and remove unused code
1 parent b0f6b99 commit 57c026c

File tree

2 files changed

+25
-40
lines changed

2 files changed

+25
-40
lines changed

ollama/_client.py

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import sys
66
import urllib.parse
77
from hashlib import sha256
8-
from os import PathLike
98
from pathlib import Path
109
from typing import (
1110
Any,
1211
Callable,
1312
Dict,
13+
Generic,
1414
List,
1515
Literal,
1616
Mapping,
@@ -69,12 +69,14 @@
6969
)
7070

7171
T = TypeVar('T')
72+
CT = TypeVar('CT', httpx.Client, httpx.AsyncClient)
73+
AnyCallable = Callable[..., Any]
7274

7375

74-
class BaseClient:
76+
class BaseClient(Generic[CT]):
7577
def __init__(
7678
self,
77-
client,
79+
client: Type[CT],
7880
host: Optional[str] = None,
7981
*,
8082
follow_redirects: bool = True,
@@ -90,7 +92,7 @@ def __init__(
9092
`kwargs` are passed to the httpx client.
9193
"""
9294

93-
self._client = client(
95+
self._client: CT = client(
9496
base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
9597
follow_redirects=follow_redirects,
9698
timeout=timeout,
@@ -111,7 +113,7 @@ def __init__(
111113
CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
112114

113115

114-
class Client(BaseClient):
116+
class Client(BaseClient[httpx.Client]):
115117
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
116118
super().__init__(httpx.Client, host, **kwargs)
117119

@@ -139,19 +141,10 @@ def _request(
139141
self,
140142
cls: Type[T],
141143
*args,
142-
stream: Literal[True] = True,
144+
stream: Literal[True],
143145
**kwargs,
144146
) -> Iterator[T]: ...
145147

146-
@overload
147-
def _request(
148-
self,
149-
cls: Type[T],
150-
*args,
151-
stream: bool = False,
152-
**kwargs,
153-
) -> Union[T, Iterator[T]]: ...
154-
155148
def _request(
156149
self,
157150
cls: Type[T],
@@ -189,7 +182,7 @@ def generate(
189182
system: str = '',
190183
template: str = '',
191184
context: Optional[Sequence[int]] = None,
192-
stream: Literal[False] = False,
185+
stream: Literal[False],
193186
think: Optional[bool] = None,
194187
raw: bool = False,
195188
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -272,7 +265,7 @@ def chat(
272265
model: str = '',
273266
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
274267
*,
275-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
268+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
276269
stream: Literal[False] = False,
277270
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
278271
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -286,8 +279,8 @@ def chat(
286279
model: str = '',
287280
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
288281
*,
289-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
290-
stream: Literal[True] = True,
282+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
283+
stream: Literal[True],
291284
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
292285
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
293286
options: Optional[Union[Mapping[str, Any], Options]] = None,
@@ -299,7 +292,7 @@ def chat(
299292
model: str = '',
300293
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
301294
*,
302-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
295+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
303296
stream: bool = False,
304297
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
305298
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -414,7 +407,7 @@ def pull(
414407
model: str,
415408
*,
416409
insecure: bool = False,
417-
stream: Literal[True] = True,
410+
stream: Literal[True],
418411
) -> Iterator[ProgressResponse]: ...
419412

420413
def pull(
@@ -447,7 +440,7 @@ def push(
447440
model: str,
448441
*,
449442
insecure: bool = False,
450-
stream: Literal[False] = False,
443+
stream: Literal[False],
451444
) -> ProgressResponse: ...
452445

453446
@overload
@@ -497,7 +490,7 @@ def create(
497490
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
498491
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
499492
*,
500-
stream: Literal[False] = False,
493+
stream: Literal[False],
501494
) -> ProgressResponse: ...
502495

503496
@overload
@@ -623,7 +616,7 @@ def ps(self) -> ProcessResponse:
623616
)
624617

625618

626-
class AsyncClient(BaseClient):
619+
class AsyncClient(BaseClient[httpx.AsyncClient]):
627620
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
628621
super().__init__(httpx.AsyncClient, host, **kwargs)
629622

@@ -783,7 +776,7 @@ async def chat(
783776
model: str = '',
784777
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
785778
*,
786-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
779+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
787780
stream: Literal[False] = False,
788781
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
789782
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -797,7 +790,7 @@ async def chat(
797790
model: str = '',
798791
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
799792
*,
800-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
793+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
801794
stream: Literal[True] = True,
802795
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
803796
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -810,7 +803,7 @@ async def chat(
810803
model: str = '',
811804
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
812805
*,
813-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
806+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
814807
stream: bool = False,
815808
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
816809
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -1155,21 +1148,11 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
11551148
)
11561149

11571150

1158-
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]:
1151+
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None) -> Iterator[Tool]:
11591152
for unprocessed_tool in tools or []:
11601153
yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)
11611154

11621155

1163-
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
1164-
if isinstance(s, (str, Path)):
1165-
try:
1166-
if (p := Path(s)).exists():
1167-
return p
1168-
except Exception:
1169-
...
1170-
return None
1171-
1172-
11731156
def _parse_host(host: Optional[str]) -> str:
11741157
"""
11751158
>>> _parse_host(None)

ollama/_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import inspect
44
import re
55
from collections import defaultdict
6-
from typing import Callable, Union
6+
from typing import Any, Callable, Union
77

88
import pydantic
99

1010
from ollama._types import Tool
1111

12+
AnyCallable = Callable[..., Any]
13+
1214

1315
def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
1416
parsed_docstring = defaultdict(str)
@@ -53,7 +55,7 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
5355
return parsed_docstring
5456

5557

56-
def convert_function_to_tool(func: Callable) -> Tool:
58+
def convert_function_to_tool(func: AnyCallable) -> Tool:
5759
doc_string_hash = str(hash(inspect.getdoc(func)))
5860
parsed_docstring = _parse_docstring(inspect.getdoc(func))
5961
schema = type(

0 commit comments

Comments
 (0)