Skip to content

Commit 4312c06

Browse files
committed
Add remote support to Mock
1 parent 12c721d commit 4312c06

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

guidance/models/_mock.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from ._model import Tokenizer, Engine, Model, Chat
4+
from ._remote import RemoteEngine
45

56
class MockEngine(Engine):
67
def __init__(self, tokenizer, byte_patterns, compute_log_probs):
@@ -54,16 +55,23 @@ def _get_next_tokens(self, byte_string):
5455
yield i
5556

5657
class Mock(Model):
57-
def __init__(self, byte_patterns=[], echo=True, compute_log_probs=False):
58+
def __init__(self, byte_patterns=[], echo=True, compute_log_probs=False, **kwargs):
5859
'''Build a new Mock model object that represents a model in a given state.'''
59-
tokenizer = Tokenizer(
60-
# our tokens are all bytes and all lowercase letter pairs
61-
[b"<s>"] + [bytes([i,j]) for i in range(ord('a'), ord('z')) for j in range(ord('a'), ord('z'))] + [bytes([i]) for i in range(256)],
62-
0,
63-
0
64-
)
60+
61+
if isinstance(byte_patterns, str) and byte_patterns.startswith("http"):
62+
engine = RemoteEngine(byte_patterns, **kwargs)
63+
else:
64+
tokenizer = Tokenizer(
65+
# our tokens are all bytes and all lowercase letter pairs
66+
[b"<s>"] + [bytes([i,j]) for i in range(ord('a'), ord('z')) for j in range(ord('a'), ord('z'))] + [bytes([i]) for i in range(256)],
67+
0,
68+
0
69+
)
70+
engine = MockEngine(tokenizer, byte_patterns, compute_log_probs)
71+
72+
6573
super().__init__(
66-
MockEngine(tokenizer, byte_patterns, compute_log_probs),
74+
engine,
6775
echo=echo
6876
)
6977

tests/test_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def running_server():
1818
yield p
1919
p.terminate()
2020

21-
def test_remote_llama_cpp_gen(running_server):
21+
def test_remote_mock_gen(running_server):
2222
from guidance import models, gen
2323

2424
m = models.Mock("http://localhost:8392", api_key="SDFSDF")

0 commit comments

Comments
 (0)