Skip to content

Commit 3604002

Browse files
committed
refactor: Fix type annotation errors and remove unused code
1 parent b0f6b99 commit 3604002

File tree

2 files changed

+27
-40
lines changed

2 files changed

+27
-40
lines changed

ollama/_client.py

Lines changed: 23 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
import sys
66
import urllib.parse
77
from hashlib import sha256
8-
from os import PathLike
8+
99
from pathlib import Path
1010
from typing import (
1111
Any,
1212
Callable,
1313
Dict,
14+
Generic,
1415
List,
1516
Literal,
1617
Mapping,
@@ -22,6 +23,8 @@
2223
overload,
2324
)
2425

26+
AnyCallable = Callable[..., Any]
27+
2528
import anyio
2629
from pydantic.json_schema import JsonSchemaValue
2730

@@ -69,12 +72,13 @@
6972
)
7073

7174
T = TypeVar('T')
75+
CT = TypeVar('CT', httpx.Client, httpx.AsyncClient)
7276

7377

74-
class BaseClient:
78+
class BaseClient(Generic[CT]):
7579
def __init__(
7680
self,
77-
client,
81+
client: Type[CT],
7882
host: Optional[str] = None,
7983
*,
8084
follow_redirects: bool = True,
@@ -90,7 +94,7 @@ def __init__(
9094
`kwargs` are passed to the httpx client.
9195
"""
9296

93-
self._client = client(
97+
self._client: CT = client(
9498
base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
9599
follow_redirects=follow_redirects,
96100
timeout=timeout,
@@ -111,7 +115,7 @@ def __init__(
111115
CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
112116

113117

114-
class Client(BaseClient):
118+
class Client(BaseClient[httpx.Client]):
115119
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
116120
super().__init__(httpx.Client, host, **kwargs)
117121

@@ -139,19 +143,10 @@ def _request(
139143
self,
140144
cls: Type[T],
141145
*args,
142-
stream: Literal[True] = True,
146+
stream: Literal[True],
143147
**kwargs,
144148
) -> Iterator[T]: ...
145149

146-
@overload
147-
def _request(
148-
self,
149-
cls: Type[T],
150-
*args,
151-
stream: bool = False,
152-
**kwargs,
153-
) -> Union[T, Iterator[T]]: ...
154-
155150
def _request(
156151
self,
157152
cls: Type[T],
@@ -189,7 +184,7 @@ def generate(
189184
system: str = '',
190185
template: str = '',
191186
context: Optional[Sequence[int]] = None,
192-
stream: Literal[False] = False,
187+
stream: Literal[False],
193188
think: Optional[bool] = None,
194189
raw: bool = False,
195190
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -272,7 +267,7 @@ def chat(
272267
model: str = '',
273268
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
274269
*,
275-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
270+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
276271
stream: Literal[False] = False,
277272
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
278273
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -286,8 +281,8 @@ def chat(
286281
model: str = '',
287282
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
288283
*,
289-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
290-
stream: Literal[True] = True,
284+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
285+
stream: Literal[True],
291286
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
292287
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
293288
options: Optional[Union[Mapping[str, Any], Options]] = None,
@@ -299,7 +294,7 @@ def chat(
299294
model: str = '',
300295
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
301296
*,
302-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
297+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
303298
stream: bool = False,
304299
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
305300
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -414,7 +409,7 @@ def pull(
414409
model: str,
415410
*,
416411
insecure: bool = False,
417-
stream: Literal[True] = True,
412+
stream: Literal[True],
418413
) -> Iterator[ProgressResponse]: ...
419414

420415
def pull(
@@ -447,7 +442,7 @@ def push(
447442
model: str,
448443
*,
449444
insecure: bool = False,
450-
stream: Literal[False] = False,
445+
stream: Literal[False],
451446
) -> ProgressResponse: ...
452447

453448
@overload
@@ -497,7 +492,7 @@ def create(
497492
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
498493
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
499494
*,
500-
stream: Literal[False] = False,
495+
stream: Literal[False],
501496
) -> ProgressResponse: ...
502497

503498
@overload
@@ -623,7 +618,7 @@ def ps(self) -> ProcessResponse:
623618
)
624619

625620

626-
class AsyncClient(BaseClient):
621+
class AsyncClient(BaseClient[httpx.AsyncClient]):
627622
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
628623
super().__init__(httpx.AsyncClient, host, **kwargs)
629624

@@ -783,7 +778,7 @@ async def chat(
783778
model: str = '',
784779
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
785780
*,
786-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
781+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
787782
stream: Literal[False] = False,
788783
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
789784
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -797,7 +792,7 @@ async def chat(
797792
model: str = '',
798793
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
799794
*,
800-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
795+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
801796
stream: Literal[True] = True,
802797
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
803798
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -810,7 +805,7 @@ async def chat(
810805
model: str = '',
811806
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
812807
*,
813-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
808+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None,
814809
stream: bool = False,
815810
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
816811
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
@@ -1155,21 +1150,11 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
11551150
)
11561151

11571152

1158-
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]:
1153+
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None) -> Iterator[Tool]:
11591154
for unprocessed_tool in tools or []:
11601155
yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)
11611156

11621157

1163-
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
1164-
if isinstance(s, (str, Path)):
1165-
try:
1166-
if (p := Path(s)).exists():
1167-
return p
1168-
except Exception:
1169-
...
1170-
return None
1171-
1172-
11731158
def _parse_host(host: Optional[str]) -> str:
11741159
"""
11751160
>>> _parse_host(None)

ollama/_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import inspect
44
import re
55
from collections import defaultdict
6-
from typing import Callable, Union
6+
from typing import Any, Callable, Union
77

88
import pydantic
99

1010
from ollama._types import Tool
1111

12+
AnyCallable = Callable[..., Any]
13+
1214

1315
def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
1416
parsed_docstring = defaultdict(str)
@@ -53,7 +55,7 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
5355
return parsed_docstring
5456

5557

56-
def convert_function_to_tool(func: Callable) -> Tool:
58+
def convert_function_to_tool(func: AnyCallable) -> Tool:
5759
doc_string_hash = str(hash(inspect.getdoc(func)))
5860
parsed_docstring = _parse_docstring(inspect.getdoc(func))
5961
schema = type(

0 commit comments

Comments
 (0)