Skip to content

Commit f5c8ee0

Browse files
committed
fix async client
1 parent a0388b2 commit f5c8ee0

File tree

2 files changed

+59
-161
lines changed


ollama/_client.py

Lines changed: 36 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -546,24 +546,6 @@ def create(
546546
stream=stream,
547547
)
548548

549-
def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
550-
base = Path.cwd() if base is None else base
551-
552-
out = io.StringIO()
553-
for line in io.StringIO(modelfile):
554-
command, _, args = line.partition(' ')
555-
if command.upper() not in ['FROM', 'ADAPTER']:
556-
print(line, end='', file=out)
557-
continue
558-
559-
path = Path(args.strip()).expanduser()
560-
path = path if path.is_absolute() else base / path
561-
if path.exists():
562-
args = f'@{self.create_blob(path)}\n'
563-
print(command, args, end='', file=out)
564-
565-
return out.getvalue()
566-
567549
def create_blob(self, path: Union[str, Path]) -> str:
568550
sha256sum = sha256()
569551
with open(path, 'rb') as r:
@@ -996,76 +978,77 @@ async def push(
996978
async def create(
997979
self,
998980
model: str,
999-
path: Optional[Union[str, PathLike]] = None,
1000-
modelfile: Optional[str] = None,
1001-
*,
1002981
quantize: Optional[str] = None,
1003-
stream: Literal[False] = False,
982+
from_: Optional[str] = None,
983+
files: Optional[dict[str, str]] = None,
984+
adapters: Optional[dict[str, str]] = None,
985+
template: Optional[str] = None,
986+
license: Optional[Union[str, list[str]]] = None,
987+
system: Optional[str] = None,
988+
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
989+
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
990+
*,
991+
stream: Literal[True] = True,
1004992
) -> ProgressResponse: ...
1005993

1006994
@overload
1007995
async def create(
1008996
self,
1009997
model: str,
1010-
path: Optional[Union[str, PathLike]] = None,
1011-
modelfile: Optional[str] = None,
1012-
*,
1013998
quantize: Optional[str] = None,
999+
from_: Optional[str] = None,
1000+
files: Optional[dict[str, str]] = None,
1001+
adapters: Optional[dict[str, str]] = None,
1002+
template: Optional[str] = None,
1003+
license: Optional[Union[str, list[str]]] = None,
1004+
system: Optional[str] = None,
1005+
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
1006+
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
1007+
*,
10141008
stream: Literal[True] = True,
10151009
) -> AsyncIterator[ProgressResponse]: ...
10161010

10171011
async def create(
10181012
self,
10191013
model: str,
1020-
path: Optional[Union[str, PathLike]] = None,
1021-
modelfile: Optional[str] = None,
1022-
*,
10231014
quantize: Optional[str] = None,
1015+
from_: Optional[str] = None,
1016+
files: Optional[dict[str, str]] = None,
1017+
adapters: Optional[dict[str, str]] = None,
1018+
template: Optional[str] = None,
1019+
license: Optional[Union[str, list[str]]] = None,
1020+
system: Optional[str] = None,
1021+
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
1022+
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
1023+
*,
10241024
stream: bool = False,
10251025
) -> Union[ProgressResponse, AsyncIterator[ProgressResponse]]:
10261026
"""
10271027
Raises `ResponseError` if the request could not be fulfilled.
10281028
10291029
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
10301030
"""
1031-
if (realpath := _as_path(path)) and realpath.exists():
1032-
modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
1033-
elif modelfile:
1034-
modelfile = await self._parse_modelfile(modelfile)
1035-
else:
1036-
raise RequestError('must provide either path or modelfile')
10371031

10381032
return await self._request(
10391033
ProgressResponse,
10401034
'POST',
10411035
'/api/create',
10421036
json=CreateRequest(
10431037
model=model,
1044-
modelfile=modelfile,
10451038
stream=stream,
10461039
quantize=quantize,
1040+
from_=from_,
1041+
files=files,
1042+
adapters=adapters,
1043+
license=license,
1044+
template=template,
1045+
system=system,
1046+
parameters=parameters,
1047+
messages=messages,
10471048
).model_dump(exclude_none=True),
10481049
stream=stream,
10491050
)
10501051

1051-
async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
1052-
base = Path.cwd() if base is None else base
1053-
1054-
out = io.StringIO()
1055-
for line in io.StringIO(modelfile):
1056-
command, _, args = line.partition(' ')
1057-
if command.upper() not in ['FROM', 'ADAPTER']:
1058-
print(line, end='', file=out)
1059-
continue
1060-
1061-
path = Path(args.strip()).expanduser()
1062-
path = path if path.is_absolute() else base / path
1063-
if path.exists():
1064-
args = f'@{await self.create_blob(path)}\n'
1065-
print(command, args, end='', file=out)
1066-
1067-
return out.getvalue()
1068-
10691052
async def create_blob(self, path: Union[str, Path]) -> str:
10701053
sha256sum = sha256()
10711054
with open(path, 'rb') as r:

tests/test_client.py

Lines changed: 23 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -933,117 +933,39 @@ def generate():
933933

934934

935935
@pytest.mark.asyncio
936-
async def test_async_client_create_path(httpserver: HTTPServer):
937-
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
938-
httpserver.expect_ordered_request(
939-
'/api/create',
940-
method='POST',
941-
json={
942-
'model': 'dummy',
943-
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
944-
'stream': False,
945-
},
946-
).respond_with_json({'status': 'success'})
947-
948-
client = AsyncClient(httpserver.url_for('/'))
949-
950-
with tempfile.NamedTemporaryFile() as modelfile:
951-
with tempfile.NamedTemporaryFile() as blob:
952-
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
953-
modelfile.flush()
954-
955-
response = await client.create('dummy', path=modelfile.name)
956-
assert response['status'] == 'success'
957-
958-
959-
@pytest.mark.asyncio
960-
async def test_async_client_create_path_relative(httpserver: HTTPServer):
961-
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
936+
async def test_async_client_create_with_blob(httpserver: HTTPServer):
962937
httpserver.expect_ordered_request(
963938
'/api/create',
964939
method='POST',
965940
json={
966941
'model': 'dummy',
967-
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
968-
'stream': False,
969-
},
970-
).respond_with_json({'status': 'success'})
971-
972-
client = AsyncClient(httpserver.url_for('/'))
973-
974-
with tempfile.NamedTemporaryFile() as modelfile:
975-
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
976-
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
977-
modelfile.flush()
978-
979-
response = await client.create('dummy', path=modelfile.name)
980-
assert response['status'] == 'success'
981-
982-
983-
@pytest.mark.asyncio
984-
async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
985-
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
986-
httpserver.expect_ordered_request(
987-
'/api/create',
988-
method='POST',
989-
json={
990-
'model': 'dummy',
991-
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
992-
'stream': False,
993-
},
994-
).respond_with_json({'status': 'success'})
995-
996-
client = AsyncClient(httpserver.url_for('/'))
997-
998-
with tempfile.NamedTemporaryFile() as modelfile:
999-
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
1000-
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
1001-
modelfile.flush()
1002-
1003-
response = await client.create('dummy', path=modelfile.name)
1004-
assert response['status'] == 'success'
1005-
1006-
1007-
@pytest.mark.asyncio
1008-
async def test_async_client_create_modelfile(httpserver: HTTPServer):
1009-
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
1010-
httpserver.expect_ordered_request(
1011-
'/api/create',
1012-
method='POST',
1013-
json={
1014-
'model': 'dummy',
1015-
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
942+
'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
1016943
'stream': False,
1017944
},
1018945
).respond_with_json({'status': 'success'})
1019946

1020947
client = AsyncClient(httpserver.url_for('/'))
1021948

1022949
with tempfile.NamedTemporaryFile() as blob:
1023-
response = await client.create('dummy', modelfile=f'FROM {blob.name}')
950+
response = await client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
1024951
assert response['status'] == 'success'
1025952

1026953

1027954
@pytest.mark.asyncio
1028-
async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
1029-
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
955+
async def test_async_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
1030956
httpserver.expect_ordered_request(
1031957
'/api/create',
1032958
method='POST',
1033959
json={
1034960
'model': 'dummy',
1035-
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
1036-
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
1037-
{{.Prompt}} [/INST]"""
1038-
SYSTEM """
1039-
Use
1040-
multiline
1041-
strings.
1042-
"""
1043-
PARAMETER stop [INST]
1044-
PARAMETER stop [/INST]
1045-
PARAMETER stop <<SYS>>
1046-
PARAMETER stop <</SYS>>''',
961+
'quantize': 'q4_k_m',
962+
'from': 'mymodel',
963+
'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
964+
'template': '[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
965+
'license': 'this is my license',
966+
'system': '\nUse\nmultiline\nstrings.\n',
967+
'parameters': {'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
968+
'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
1047969
'stream': False,
1048970
},
1049971
).respond_with_json({'status': 'success'})
@@ -1053,22 +975,15 @@ async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
1053975
with tempfile.NamedTemporaryFile() as blob:
1054976
response = await client.create(
1055977
'dummy',
1056-
modelfile='\n'.join(
1057-
[
1058-
f'FROM {blob.name}',
1059-
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
1060-
'{{.Prompt}} [/INST]"""',
1061-
'SYSTEM """',
1062-
'Use',
1063-
'multiline',
1064-
'strings.',
1065-
'"""',
1066-
'PARAMETER stop [INST]',
1067-
'PARAMETER stop [/INST]',
1068-
'PARAMETER stop <<SYS>>',
1069-
'PARAMETER stop <</SYS>>',
1070-
]
1071-
),
978+
quantize='q4_k_m',
979+
from_='mymodel',
980+
adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
981+
template='[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
982+
license='this is my license',
983+
system='\nUse\nmultiline\nstrings.\n',
984+
parameters={'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
985+
messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
986+
stream=False,
1072987
)
1073988
assert response['status'] == 'success'
1074989

@@ -1080,14 +995,14 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
1080995
method='POST',
1081996
json={
1082997
'model': 'dummy',
1083-
'modelfile': 'FROM llama2',
998+
'from': 'llama2',
1084999
'stream': False,
10851000
},
10861001
).respond_with_json({'status': 'success'})
10871002

10881003
client = AsyncClient(httpserver.url_for('/'))
10891004

1090-
response = await client.create('dummy', modelfile='FROM llama2')
1005+
response = await client.create('dummy', from_='llama2')
10911006
assert response['status'] == 'success'
10921007

10931008

0 commit comments

Comments (0)