|
1 | 1 | import numpy as np |
2 | 2 |
|
3 | 3 | from ._model import Tokenizer, Engine, Model, Chat |
| 4 | +from ._remote import RemoteEngine |
4 | 5 |
|
5 | 6 | class MockEngine(Engine): |
6 | 7 | def __init__(self, tokenizer, byte_patterns, compute_log_probs): |
@@ -54,16 +55,23 @@ def _get_next_tokens(self, byte_string): |
54 | 55 | yield i |
55 | 56 |
|
56 | 57 | class Mock(Model): |
57 | | - def __init__(self, byte_patterns=[], echo=True, compute_log_probs=False): |
| 58 | + def __init__(self, byte_patterns=[], echo=True, compute_log_probs=False, **kwargs): |
58 | 59 | '''Build a new Mock model object that represents a model in a given state.''' |
59 | | - tokenizer = Tokenizer( |
60 | | - # our tokens are all bytes and all lowercase letter pairs |
61 | | - [b"<s>"] + [bytes([i,j]) for i in range(ord('a'), ord('z')) for j in range(ord('a'), ord('z'))] + [bytes([i]) for i in range(256)], |
62 | | - 0, |
63 | | - 0 |
64 | | - ) |
| 60 | + |
| 61 | + if isinstance(byte_patterns, str) and byte_patterns.startswith("http"): |
| 62 | + engine = RemoteEngine(byte_patterns, **kwargs) |
| 63 | + else: |
| 64 | + tokenizer = Tokenizer( |
| 65 | + # our tokens are all bytes and all lowercase letter pairs |
| 66 | + [b"<s>"] + [bytes([i,j]) for i in range(ord('a'), ord('z')) for j in range(ord('a'), ord('z'))] + [bytes([i]) for i in range(256)], |
| 67 | + 0, |
| 68 | + 0 |
| 69 | + ) |
| 70 | + engine = MockEngine(tokenizer, byte_patterns, compute_log_probs) |
| 71 | + |
| 72 | + |
65 | 73 | super().__init__( |
66 | | - MockEngine(tokenizer, byte_patterns, compute_log_probs), |
| 74 | + engine, |
67 | 75 | echo=echo |
68 | 76 | ) |
69 | 77 |
|
|
0 commit comments