Skip to content

Commit 3772583

Browse files
committed
First working client server setup round trip
1 parent ad3375c commit 3772583

File tree

6 files changed

+293
-78
lines changed

6 files changed

+293
-78
lines changed

guidance/_grammar.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
import types
66
import re
7-
from . import _grammar_pb2
7+
from . import _serialization_pb2
88

99
tag_start = "{{G|"
1010
tag_end = "|G}}"
@@ -151,17 +151,25 @@ def gbnf_string(self):
151151
return "\n".join(lines)
152152

153153
def serialize(self):
154-
g = _grammar_pb2.Grammar()
154+
g = _serialization_pb2.Grammar()
155155
index_map = {}
156156
nodes = {}
157+
self._rec_create_index_map(index_map) # gives all the nodes an index
157158
self._rec_serialize(index_map, nodes) # nodes is filled in (as is index_map)
158159
g.nodes.extend(list(nodes.values()))
159160
return g.SerializeToString()
160161

162+
def _rec_create_index_map(self, index_map):
163+
if self not in index_map:
164+
index_map[self] = len(index_map)
165+
if hasattr(self, "values"):
166+
for value in self.values:
167+
value._rec_create_index_map(index_map)
168+
161169
def _rec_serialize(self, index_map, nodes):
162170
if self not in nodes:
163171
v = self._to_proto(index_map)
164-
node = _grammar_pb2.StatelessFunction()
172+
node = _serialization_pb2.StatelessFunction()
165173
if isinstance(self, Byte):
166174
node.byte.CopyFrom(v)
167175
elif isinstance(self, ByteRange):
@@ -181,7 +189,7 @@ def _rec_serialize(self, index_map, nodes):
181189

182190
@classmethod
183191
def deserialize(cls, serialized_grammar):
184-
g = _grammar_pb2.Grammar()
192+
g = _serialization_pb2.Grammar()
185193
g.ParseFromString(serialized_grammar)
186194

187195
# create the list of objects
@@ -253,9 +261,7 @@ def nullable(self):
253261
return False
254262

255263
def _to_proto(self, index_map):
256-
if self not in index_map:
257-
index_map[self] = len(index_map)
258-
data = _grammar_pb2.Byte()
264+
data = _serialization_pb2.Byte()
259265
data.byte = self.byte
260266
data.hidden = self.hidden
261267
data.commit_point = self.commit_point
@@ -311,9 +317,7 @@ def __len__(self):
311317
return 1
312318

313319
def _to_proto(self, index_map):
314-
if self not in index_map:
315-
index_map[self] = len(index_map)
316-
data = _grammar_pb2.ByteRange()
320+
data = _serialization_pb2.ByteRange()
317321
data.byte_range = self.byte_range
318322
data.hidden = self.hidden
319323
data.commit_point = self.commit_point
@@ -370,9 +374,7 @@ def __init__(self, name):
370374
self.nullable = False
371375

372376
def _to_proto(self, index_map):
373-
if self not in index_map:
374-
index_map[self] = len(index_map)
375-
data = _grammar_pb2.ModelVariable()
377+
data = _serialization_pb2.ModelVariable()
376378
data.hidden = self.hidden
377379
data.name = self.name
378380
data.commit_point = self.commit_point
@@ -575,13 +577,9 @@ def __repr__(self, indent="", done=None):
575577
return s
576578

577579
def _to_proto(self, index_map):
578-
data = _grammar_pb2.Join()
580+
data = _serialization_pb2.Join()
579581
data.nullable = self.nullable
580-
if self not in index_map:
581-
index_map[self] = len(index_map)
582582
for v in self.values:
583-
if v not in index_map:
584-
index_map[v] = len(index_map)
585583
data.values.append(index_map[v])
586584
data.name = self.name
587585
data.hidden = self.hidden
@@ -637,13 +635,9 @@ def __repr__(self, indent="", done=None):
637635
return s
638636

639637
def _to_proto(self, index_map):
640-
data = _grammar_pb2.Select()
638+
data = _serialization_pb2.Select()
641639
data.nullable = self.nullable
642-
if self not in index_map:
643-
index_map[self] = len(index_map)
644640
for v in self.values:
645-
if v not in index_map:
646-
index_map[v] = len(index_map)
647641
data.values.append(index_map[v])
648642
data.name = self.name
649643
data.hidden = self.hidden

guidance/_grammar.proto renamed to guidance/_serialization.proto

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,30 @@ message Grammar {
66
repeated StatelessFunction nodes = 1;
77
}
88

9+
message EngineCallResponse {
10+
bytes new_bytes = 1;
11+
bool is_generated = 2;
12+
float new_bytes_prob = 3;
13+
map<string, string> capture_groups = 4;
14+
map<string, float> capture_group_log_probs = 5;
15+
int32 new_token_count = 6;
16+
}
17+
918
message Byte {
1019
bytes byte = 1;
1120
bool hidden = 2;
1221
bool commit_point = 3;
1322
bool nullable = 4;
1423
string capture_name = 5;
15-
double temperature = 6;
24+
float temperature = 6;
1625
}
1726

1827
message ByteRange {
1928
bytes byte_range = 1;
2029
bool hidden = 3;
2130
bool commit_point = 4;
2231
string capture_name = 5;
23-
double temperature = 6;
32+
float temperature = 6;
2433
}
2534

2635
message Null {

0 commit comments

Comments
 (0)