Skip to content

Commit 4dec73e

Browse files
committed
add unit tests
1 parent 4f9fb88 commit 4dec73e

File tree

3 files changed

+28
-111
lines changed

3 files changed

+28
-111
lines changed

ollama/_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,6 @@ def create(
526526
527527
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
528528
"""
529-
#if from_ == None and files == None:
530-
# raise RequestError('neither ''from'' or ''files'' was specified')
531-
532529
return self._request(
533530
ProgressResponse,
534531
'POST',
@@ -541,6 +538,7 @@ def create(
541538
files=files,
542539
adapters=adapters,
543540
license=license,
541+
template=template,
544542
system=system,
545543
parameters=parameters,
546544
messages=messages,

tests/test_client.py

Lines changed: 23 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -536,51 +536,6 @@ def generate():
536536
assert part['status'] == next(it)
537537

538538

539-
def test_client_create_path(httpserver: HTTPServer):
540-
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
541-
httpserver.expect_ordered_request(
542-
'/api/create',
543-
method='POST',
544-
json={
545-
'model': 'dummy',
546-
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
547-
'stream': False,
548-
},
549-
).respond_with_json({'status': 'success'})
550-
551-
client = Client(httpserver.url_for('/'))
552-
553-
with tempfile.NamedTemporaryFile() as modelfile:
554-
with tempfile.NamedTemporaryFile() as blob:
555-
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
556-
modelfile.flush()
557-
558-
response = client.create('dummy', path=modelfile.name)
559-
assert response['status'] == 'success'
560-
561-
562-
def test_client_create_path_relative(httpserver: HTTPServer):
563-
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
564-
httpserver.expect_ordered_request(
565-
'/api/create',
566-
method='POST',
567-
json={
568-
'model': 'dummy',
569-
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
570-
'stream': False,
571-
},
572-
).respond_with_json({'status': 'success'})
573-
574-
client = Client(httpserver.url_for('/'))
575-
576-
with tempfile.NamedTemporaryFile() as modelfile:
577-
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
578-
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
579-
modelfile.flush()
580-
581-
response = client.create('dummy', path=modelfile.name)
582-
assert response['status'] == 'success'
583-
584539

585540
@pytest.fixture
586541
def userhomedir():
@@ -591,67 +546,38 @@ def userhomedir():
591546
os.environ['HOME'] = home
592547

593548

594-
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
595-
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
596-
httpserver.expect_ordered_request(
597-
'/api/create',
598-
method='POST',
599-
json={
600-
'model': 'dummy',
601-
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
602-
'stream': False,
603-
},
604-
).respond_with_json({'status': 'success'})
605-
606-
client = Client(httpserver.url_for('/'))
607-
608-
with tempfile.NamedTemporaryFile() as modelfile:
609-
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
610-
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
611-
modelfile.flush()
612-
613-
response = client.create('dummy', path=modelfile.name)
614-
assert response['status'] == 'success'
615-
616-
617-
def test_client_create_modelfile(httpserver: HTTPServer):
618-
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
549+
def test_client_create_with_blob(httpserver: HTTPServer):
619550
httpserver.expect_ordered_request(
620551
'/api/create',
621552
method='POST',
622553
json={
623554
'model': 'dummy',
624-
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
555+
'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
625556
'stream': False,
626557
},
627558
).respond_with_json({'status': 'success'})
628559

629560
client = Client(httpserver.url_for('/'))
630561

631562
with tempfile.NamedTemporaryFile() as blob:
632-
response = client.create('dummy', modelfile=f'FROM {blob.name}')
563+
response = client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
633564
assert response['status'] == 'success'
634565

635566

636-
def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
637-
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
567+
def test_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
638568
httpserver.expect_ordered_request(
639569
'/api/create',
640570
method='POST',
641571
json={
642572
'model': 'dummy',
643-
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
644-
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
645-
{{.Prompt}} [/INST]"""
646-
SYSTEM """
647-
Use
648-
multiline
649-
strings.
650-
"""
651-
PARAMETER stop [INST]
652-
PARAMETER stop [/INST]
653-
PARAMETER stop <<SYS>>
654-
PARAMETER stop <</SYS>>''',
573+
'quantize': 'q4_k_m',
574+
'from': 'mymodel',
575+
'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
576+
'template': '[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
577+
'license': 'this is my license',
578+
'system': '\nUse\nmultiline\nstrings.\n',
579+
'parameters': {'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
580+
'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
655581
'stream': False,
656582
},
657583
).respond_with_json({'status': 'success'})
@@ -661,22 +587,15 @@ def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
661587
with tempfile.NamedTemporaryFile() as blob:
662588
response = client.create(
663589
'dummy',
664-
modelfile='\n'.join(
665-
[
666-
f'FROM {blob.name}',
667-
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
668-
'{{.Prompt}} [/INST]"""',
669-
'SYSTEM """',
670-
'Use',
671-
'multiline',
672-
'strings.',
673-
'"""',
674-
'PARAMETER stop [INST]',
675-
'PARAMETER stop [/INST]',
676-
'PARAMETER stop <<SYS>>',
677-
'PARAMETER stop <</SYS>>',
678-
]
679-
),
590+
quantize='q4_k_m',
591+
from_='mymodel',
592+
adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
593+
template='[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
594+
license='this is my license',
595+
system='\nUse\nmultiline\nstrings.\n',
596+
parameters={'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
597+
messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
598+
stream=False,
680599
)
681600
assert response['status'] == 'success'
682601

@@ -687,14 +606,14 @@ def test_client_create_from_library(httpserver: HTTPServer):
687606
method='POST',
688607
json={
689608
'model': 'dummy',
690-
'modelfile': 'FROM llama2',
609+
'from': 'llama2',
691610
'stream': False,
692611
},
693612
).respond_with_json({'status': 'success'})
694613

695614
client = Client(httpserver.url_for('/'))
696615

697-
response = client.create('dummy', modelfile='FROM llama2')
616+
response = client.create('dummy', from_='llama2')
698617
assert response['status'] == 'success'
699618

700619

tests/test_type_serialization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ def test_create_request_serialization():
6868
system="test system",
6969
parameters={"param1": "value1"}
7070
)
71-
71+
7272
serialized = request.model_dump()
73-
assert serialized["from"] == "base-model"
74-
assert "from_" not in serialized
73+
assert serialized["from"] == "base-model"
74+
assert "from_" not in serialized
7575
assert serialized["quantize"] == "q4_0"
7676
assert serialized["files"] == {"file1": "content1"}
7777
assert serialized["adapters"] == {"adapter1": "content1"}
@@ -89,7 +89,7 @@ def test_create_request_serialization_exclude_none_true():
8989
quantize=None
9090
)
9191
serialized = request.model_dump(exclude_none=True)
92-
assert serialized == {"model": "test-model"}
92+
assert serialized == {"model": "test-model"}
9393
assert "from" not in serialized
9494
assert "from_" not in serialized
9595
assert "quantize" not in serialized

0 commit comments

Comments
 (0)