Skip to content

Commit 8d0d0e4

Browse files
authored
client: add support for passing in Image type to generate (#408)
1 parent 0561f42 commit 8d0d0e4

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

ollama/_client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def generate(
190190
stream: Literal[False] = False,
191191
raw: bool = False,
192192
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
193-
images: Optional[Sequence[Union[str, bytes]]] = None,
193+
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
194194
options: Optional[Union[Mapping[str, Any], Options]] = None,
195195
keep_alive: Optional[Union[float, str]] = None,
196196
) -> GenerateResponse: ...
@@ -208,7 +208,7 @@ def generate(
208208
stream: Literal[True] = True,
209209
raw: bool = False,
210210
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
211-
images: Optional[Sequence[Union[str, bytes]]] = None,
211+
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
212212
options: Optional[Union[Mapping[str, Any], Options]] = None,
213213
keep_alive: Optional[Union[float, str]] = None,
214214
) -> Iterator[GenerateResponse]: ...
@@ -225,7 +225,7 @@ def generate(
225225
stream: bool = False,
226226
raw: Optional[bool] = None,
227227
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
228-
images: Optional[Sequence[Union[str, bytes]]] = None,
228+
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
229229
options: Optional[Union[Mapping[str, Any], Options]] = None,
230230
keep_alive: Optional[Union[float, str]] = None,
231231
) -> Union[GenerateResponse, Iterator[GenerateResponse]]:
@@ -694,7 +694,7 @@ async def generate(
694694
stream: Literal[False] = False,
695695
raw: bool = False,
696696
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
697-
images: Optional[Sequence[Union[str, bytes]]] = None,
697+
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
698698
options: Optional[Union[Mapping[str, Any], Options]] = None,
699699
keep_alive: Optional[Union[float, str]] = None,
700700
) -> GenerateResponse: ...
@@ -712,7 +712,7 @@ async def generate(
712712
stream: Literal[True] = True,
713713
raw: bool = False,
714714
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
715-
images: Optional[Sequence[Union[str, bytes]]] = None,
715+
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
716716
options: Optional[Union[Mapping[str, Any], Options]] = None,
717717
keep_alive: Optional[Union[float, str]] = None,
718718
) -> AsyncIterator[GenerateResponse]: ...
@@ -729,7 +729,7 @@ async def generate(
729729
stream: bool = False,
730730
raw: Optional[bool] = None,
731731
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
732-
images: Optional[Sequence[Union[str, bytes]]] = None,
732+
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
733733
options: Optional[Union[Mapping[str, Any], Options]] = None,
734734
keep_alive: Optional[Union[float, str]] = None,
735735
) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:

tests/test_client.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from werkzeug.wrappers import Request, Response
1212

1313
from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools
14+
from ollama._types import Image
1415

1516
PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'
1617
PNG_BYTES = base64.b64decode(PNG_BASE64)
@@ -286,6 +287,46 @@ def test_client_generate(httpserver: HTTPServer):
286287
assert response['response'] == 'Because it is.'
287288

288289

290+
def test_client_generate_with_image_type(httpserver: HTTPServer):
291+
httpserver.expect_ordered_request(
292+
'/api/generate',
293+
method='POST',
294+
json={
295+
'model': 'dummy',
296+
'prompt': 'What is in this image?',
297+
'stream': False,
298+
'images': [PNG_BASE64],
299+
},
300+
).respond_with_json(
301+
{
302+
'model': 'dummy',
303+
'response': 'A blue sky.',
304+
}
305+
)
306+
307+
client = Client(httpserver.url_for('/'))
308+
response = client.generate('dummy', 'What is in this image?', images=[Image(value=PNG_BASE64)])
309+
assert response['model'] == 'dummy'
310+
assert response['response'] == 'A blue sky.'
311+
312+
313+
def test_client_generate_with_invalid_image(httpserver: HTTPServer):
314+
httpserver.expect_ordered_request(
315+
'/api/generate',
316+
method='POST',
317+
json={
318+
'model': 'dummy',
319+
'prompt': 'What is in this image?',
320+
'stream': False,
321+
'images': ['invalid_base64'],
322+
},
323+
).respond_with_json({'error': 'Invalid image data'}, status=400)
324+
325+
client = Client(httpserver.url_for('/'))
326+
with pytest.raises(ValueError):
327+
client.generate('dummy', 'What is in this image?', images=[Image(value='invalid_base64')])
328+
329+
289330
def test_client_generate_stream(httpserver: HTTPServer):
290331
def stream_handler(_: Request):
291332
def generate():

0 commit comments

Comments
 (0)