diff --git a/ollama/_client.py b/ollama/_client.py index 0a85a74..ef790b7 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -5,12 +5,13 @@ import sys import urllib.parse from hashlib import sha256 -from os import PathLike + from pathlib import Path from typing import ( Any, Callable, Dict, + Generic, List, Literal, Mapping, @@ -22,6 +23,8 @@ overload, ) +AnyCallable = Callable[..., Any] + import anyio from pydantic.json_schema import JsonSchemaValue @@ -69,12 +72,13 @@ ) T = TypeVar('T') +CT = TypeVar('CT', httpx.Client, httpx.AsyncClient) -class BaseClient: +class BaseClient(Generic[CT]): def __init__( self, - client, + client: Type[CT], host: Optional[str] = None, *, follow_redirects: bool = True, @@ -90,7 +94,7 @@ def __init__( `kwargs` are passed to the httpx client. """ - self._client = client( + self._client: CT = client( base_url=_parse_host(host or os.getenv('OLLAMA_HOST')), follow_redirects=follow_redirects, timeout=timeout, @@ -111,7 +115,7 @@ def __init__( CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download' -class Client(BaseClient): +class Client(BaseClient[httpx.Client]): def __init__(self, host: Optional[str] = None, **kwargs) -> None: super().__init__(httpx.Client, host, **kwargs) @@ -139,19 +143,10 @@ def _request( self, cls: Type[T], *args, - stream: Literal[True] = True, + stream: Literal[True], **kwargs, ) -> Iterator[T]: ... - @overload - def _request( - self, - cls: Type[T], - *args, - stream: bool = False, - **kwargs, - ) -> Union[T, Iterator[T]]: ... - def _request( self, cls: Type[T], @@ -189,7 +184,7 @@ def generate( system: str = '', template: str = '', context: Optional[Sequence[int]] = None, - stream: Literal[False] = False, + stream: Literal[False], think: Optional[bool] = None, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, @@ -272,7 +267,7 @@ def chat( model: str = '', messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, *, - tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None, stream: Literal[False] = False, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, @@ -286,8 +281,8 @@ def chat( model: str = '', messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, *, - tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, - stream: Literal[True] = True, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None, + stream: Literal[True], think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, @@ -299,7 +294,7 @@ def chat( model: str = '', messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, *, - tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None, stream: bool = False, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, @@ -414,7 +409,7 @@ def pull( model: str, *, insecure: bool = False, - stream: Literal[True] = True, + stream: Literal[True], ) -> Iterator[ProgressResponse]: ... def pull( @@ -447,7 +442,7 @@ def push( model: str, *, insecure: bool = False, - stream: Literal[False] = False, + stream: Literal[False], ) -> ProgressResponse: ... @overload @@ -497,7 +492,7 @@ def create( parameters: Optional[Union[Mapping[str, Any], Options]] = None, messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, *, - stream: Literal[False] = False, + stream: Literal[False], ) -> ProgressResponse: ... @overload @@ -623,7 +618,7 @@ def ps(self) -> ProcessResponse: ) -class AsyncClient(BaseClient): +class AsyncClient(BaseClient[httpx.AsyncClient]): def __init__(self, host: Optional[str] = None, **kwargs) -> None: super().__init__(httpx.AsyncClient, host, **kwargs) @@ -783,7 +778,7 @@ async def chat( model: str = '', messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, *, - tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None, stream: Literal[False] = False, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, @@ -797,7 +792,7 @@ async def chat( model: str = '', messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, *, - tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None, stream: Literal[True] = True, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, @@ -810,7 +805,7 @@ async def chat( model: str = '', messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, *, - tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None, stream: bool = False, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, @@ -1155,21 +1150,11 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message] ) -def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]: +def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, AnyCallable]]] = None) -> Iterator[Tool]: for unprocessed_tool in tools or []: yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool) -def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]: - if isinstance(s, (str, Path)): - try: - if (p := Path(s)).exists(): - return p - except Exception: - ... - return None - - def _parse_host(host: Optional[str]) -> str: """ >>> _parse_host(None) diff --git a/ollama/_utils.py b/ollama/_utils.py index 15f1cc0..c1b7549 100644 --- a/ollama/_utils.py +++ b/ollama/_utils.py @@ -3,12 +3,14 @@ import inspect import re from collections import defaultdict -from typing import Callable, Union +from typing import Any, Callable, Union import pydantic from ollama._types import Tool +AnyCallable = Callable[..., Any] + def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]: parsed_docstring = defaultdict(str) @@ -53,7 +55,7 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]: return parsed_docstring -def convert_function_to_tool(func: Callable) -> Tool: +def convert_function_to_tool(func: AnyCallable) -> Tool: doc_string_hash = str(hash(inspect.getdoc(func))) parsed_docstring = _parse_docstring(inspect.getdoc(func)) schema = type(