Skip to content

Commit cf3ab80

Browse files
authored
Merge pull request #40 from ollama/mxyng/fix-parse-modelfile
fix parse modelfile
2 parents e201181 + 8e5d431 commit cf3ab80

File tree

2 files changed

+117
-14
lines changed

2 files changed

+117
-14
lines changed

ollama/_client.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,16 @@ def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
259259
out = io.StringIO()
260260
for line in io.StringIO(modelfile):
261261
command, _, args = line.partition(' ')
262-
if command.upper() in ['FROM', 'ADAPTER']:
263-
path = Path(args.strip()).expanduser()
264-
path = path if path.is_absolute() else base / path
265-
if path.exists():
266-
args = f'@{self._create_blob(path)}'
262+
if command.upper() not in ['FROM', 'ADAPTER']:
263+
print(line, end='', file=out)
264+
continue
265+
266+
path = Path(args.strip()).expanduser()
267+
path = path if path.is_absolute() else base / path
268+
if path.exists():
269+
args = f'@{self._create_blob(path)}\n'
270+
print(command, args, end='', file=out)
267271

268-
print(command, args, file=out)
269272
return out.getvalue()
270273

271274
def _create_blob(self, path: Union[str, Path]) -> str:
@@ -527,13 +530,16 @@ async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) ->
527530
out = io.StringIO()
528531
for line in io.StringIO(modelfile):
529532
command, _, args = line.partition(' ')
530-
if command.upper() in ['FROM', 'ADAPTER']:
531-
path = Path(args).expanduser()
532-
path = path if path.is_absolute() else base / path
533-
if path.exists():
534-
args = f'@{await self._create_blob(path)}'
533+
if command.upper() not in ['FROM', 'ADAPTER']:
534+
print(line, end='', file=out)
535+
continue
536+
537+
path = Path(args.strip()).expanduser()
538+
path = path if path.is_absolute() else base / path
539+
if path.exists():
540+
args = f'@{await self._create_blob(path)}\n'
541+
print(command, args, end='', file=out)
535542

536-
print(command, args, file=out)
537543
return out.getvalue()
538544

539545
async def _create_blob(self, path: Union[str, Path]) -> str:

tests/test_client.py

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,13 +416,61 @@ def test_client_create_modelfile(httpserver: HTTPServer):
416416
assert isinstance(response, dict)
417417

418418

419+
def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
420+
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
421+
httpserver.expect_ordered_request(
422+
'/api/create',
423+
method='POST',
424+
json={
425+
'name': 'dummy',
426+
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
427+
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
428+
{{.Prompt}} [/INST]"""
429+
SYSTEM """
430+
Use
431+
multiline
432+
strings.
433+
"""
434+
PARAMETER stop [INST]
435+
PARAMETER stop [/INST]
436+
PARAMETER stop <<SYS>>
437+
PARAMETER stop <</SYS>>''',
438+
'stream': False,
439+
},
440+
).respond_with_json({})
441+
442+
client = Client(httpserver.url_for('/'))
443+
444+
with tempfile.NamedTemporaryFile() as blob:
445+
response = client.create(
446+
'dummy',
447+
modelfile='\n'.join(
448+
[
449+
f'FROM {blob.name}',
450+
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
451+
'{{.Prompt}} [/INST]"""',
452+
'SYSTEM """',
453+
'Use',
454+
'multiline',
455+
'strings.',
456+
'"""',
457+
'PARAMETER stop [INST]',
458+
'PARAMETER stop [/INST]',
459+
'PARAMETER stop <<SYS>>',
460+
'PARAMETER stop <</SYS>>',
461+
]
462+
),
463+
)
464+
assert isinstance(response, dict)
465+
466+
419467
def test_client_create_from_library(httpserver: HTTPServer):
420468
httpserver.expect_ordered_request(
421469
'/api/create',
422470
method='POST',
423471
json={
424472
'name': 'dummy',
425-
'modelfile': 'FROM llama2\n',
473+
'modelfile': 'FROM llama2',
426474
'stream': False,
427475
},
428476
).respond_with_json({})
@@ -820,14 +868,63 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer):
820868
assert isinstance(response, dict)
821869

822870

871+
@pytest.mark.asyncio
872+
async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
873+
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
874+
httpserver.expect_ordered_request(
875+
'/api/create',
876+
method='POST',
877+
json={
878+
'name': 'dummy',
879+
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
880+
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
881+
{{.Prompt}} [/INST]"""
882+
SYSTEM """
883+
Use
884+
multiline
885+
strings.
886+
"""
887+
PARAMETER stop [INST]
888+
PARAMETER stop [/INST]
889+
PARAMETER stop <<SYS>>
890+
PARAMETER stop <</SYS>>''',
891+
'stream': False,
892+
},
893+
).respond_with_json({})
894+
895+
client = AsyncClient(httpserver.url_for('/'))
896+
897+
with tempfile.NamedTemporaryFile() as blob:
898+
response = await client.create(
899+
'dummy',
900+
modelfile='\n'.join(
901+
[
902+
f'FROM {blob.name}',
903+
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
904+
'{{.Prompt}} [/INST]"""',
905+
'SYSTEM """',
906+
'Use',
907+
'multiline',
908+
'strings.',
909+
'"""',
910+
'PARAMETER stop [INST]',
911+
'PARAMETER stop [/INST]',
912+
'PARAMETER stop <<SYS>>',
913+
'PARAMETER stop <</SYS>>',
914+
]
915+
),
916+
)
917+
assert isinstance(response, dict)
918+
919+
823920
@pytest.mark.asyncio
824921
async def test_async_client_create_from_library(httpserver: HTTPServer):
825922
httpserver.expect_ordered_request(
826923
'/api/create',
827924
method='POST',
828925
json={
829926
'name': 'dummy',
830-
'modelfile': 'FROM llama2\n',
927+
'modelfile': 'FROM llama2',
831928
'stream': False,
832929
},
833930
).respond_with_json({})

0 commit comments

Comments
 (0)