Skip to content

Commit ad3375c

Browse files
committed
Better protobuf serialization
1 parent 985ca77 commit ad3375c

File tree

4 files changed

+239
-111
lines changed

4 files changed

+239
-111
lines changed

guidance/_grammar.proto

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ syntax = "proto3";
22

33
package guidance;
44

5+
message Grammar {
6+
repeated StatelessFunction nodes = 1;
7+
}
8+
59
message Byte {
610
bytes byte = 1;
711
bool hidden = 2;
@@ -34,7 +38,7 @@ message Join {
3438
bool nullable = 1;
3539

3640
// Use a repeated field to store the list of values
37-
repeated StatelessFunction values = 2;
41+
repeated int32 values = 2;
3842

3943
string name = 3;
4044
bool hidden = 4;
@@ -47,7 +51,7 @@ message Select {
4751
bool nullable = 1;
4852

4953
// Use a repeated field to store the list of values
50-
repeated StatelessFunction values = 2;
54+
repeated int32 values = 2;
5155

5256
string name = 3;
5357
bool hidden = 4;
@@ -70,5 +74,6 @@ message StatelessFunction {
7074
Select select = 2;
7175
Byte byte = 3;
7276
ByteRange byte_range = 4;
77+
ModelVariable model_variable = 5;
7378
}
7479
}

guidance/_grammar.py

Lines changed: 102 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,65 @@ def gbnf_string(self):
149149
root_name = self._rec_gbnf_string(lines, used_names, names)
150150
lines.append("root ::= " + root_name)
151151
return "\n".join(lines)
152+
153+
def serialize(self):
    """Serialize the grammar rooted at this node to protobuf bytes.

    The grammar is flattened into a ``Grammar`` message whose ``nodes`` list
    holds one entry per reachable node. ``Join``/``Select`` entries refer to
    their children by position in that list, and this node (the root) is
    stored first, which is the entry ``deserialize`` returns.

    Returns:
        The ``Grammar`` message encoded with ``SerializeToString()``.
    """
    index_map = {}  # node -> index that parents store in their ``values``
    nodes = {}  # node -> its StatelessFunction proto, in traversal order
    self._rec_serialize(index_map, nodes)  # fills in both dictionaries

    # Indices are handed out inside ``_to_proto``: a parent numbers all of its
    # children before any grandchild is reached. ``nodes`` is filled
    # depth-first, so its insertion order can differ from the index order.
    # Emit the protos in index order so that every stored child index points
    # at the matching entry of the list.
    ordered = sorted(nodes, key=index_map.__getitem__)

    g = _grammar_pb2.Grammar()
    g.nodes.extend([nodes[node] for node in ordered])
    return g.SerializeToString()
160+
161+
def _rec_serialize(self, index_map, nodes):
    """Add this node, and every node reachable from it, to ``nodes``.

    Args:
        index_map: Dict mapping nodes to integer indices. It is handed to
            ``_to_proto``, which adds entries to it.
        nodes: Dict mapping each visited node to the ``StatelessFunction``
            proto that wraps it. Updated in place.

    Raises:
        Exception: If this node is not one of the serializable node types.
    """
    if self in nodes:
        return  # already serialized on an earlier visit

    payload = self._to_proto(index_map)
    wrapper = _grammar_pb2.StatelessFunction()

    # Pick the wrapper field that matches this node's type.
    field_for_type = (
        (Byte, "byte"),
        (ByteRange, "byte_range"),
        (Select, "select"),
        (Join, "join"),
        (ModelVariable, "model_variable"),
    )
    for node_type, field_name in field_for_type:
        if isinstance(self, node_type):
            getattr(wrapper, field_name).CopyFrom(payload)
            break
    else:
        raise Exception("Unknown node type")

    # Record this node before descending, so a node that is reachable from
    # itself is not serialized a second time.
    nodes[self] = wrapper
    for child in getattr(self, "values", ()):
        child._rec_serialize(index_map, nodes)
181+
182+
@classmethod
def deserialize(cls, serialized_grammar):
    """Rebuild a grammar from the bytes of an encoded ``Grammar`` message.

    Args:
        serialized_grammar: Bytes that parse as a ``Grammar`` protobuf
            message.

    Returns:
        The object built from the first entry of the message, which is the
        root node of the grammar.

    Raises:
        Exception: If an entry has none of the known node fields set.
    """
    grammar = _grammar_pb2.Grammar()
    grammar.ParseFromString(serialized_grammar)

    # (field name, node class) pairs, probed in this order for every entry
    decoders = (
        ("byte", Byte),
        ("byte_range", ByteRange),
        ("select", Select),
        ("join", Join),
        ("model_variable", ModelVariable),
    )

    # Pass 1: build one object per entry. Objects that have ``values`` still
    # hold integer indices at this point.
    objects = []
    for entry in grammar.nodes:
        for field_name, node_type in decoders:
            if entry.HasField(field_name):
                decoded = node_type._from_proto(getattr(entry, field_name))
                break
        else:
            raise Exception("Unknown node type")
        objects.append(decoded)

    # Pass 2: now that every object exists, swap each stored index for the
    # object it refers to. The list is updated in place.
    for obj in objects:
        if hasattr(obj, "values"):
            children = obj.values
            for position, index in enumerate(children):
                children[position] = objects[index]

    return objects[0]  # the first element is the root node of the grammar
152211

153212
class Terminal(StatelessFunction):
154213
def match_byte(self, byte):
@@ -193,7 +252,9 @@ def match_byte(self, byte):
193252
def nullable(self):
194253
return False
195254

196-
def _to_proto(self):
255+
def _to_proto(self, index_map):
256+
if self not in index_map:
257+
index_map[self] = len(index_map)
197258
data = _grammar_pb2.Byte()
198259
data.byte = self.byte
199260
data.hidden = self.hidden
@@ -202,15 +263,6 @@ def _to_proto(self):
202263
data.temperature = self.temperature
203264
return data
204265

205-
def serialize(self):
206-
return self._to_proto().SerializeToString()
207-
208-
@staticmethod
209-
def deserialize(data_bytes):
210-
data = _grammar_pb2.Byte()
211-
data.ParseFromString(data_bytes)
212-
return Byte._from_proto(data)
213-
214266
@staticmethod
215267
def _from_proto(data):
216268
out = Byte(data.byte)
@@ -258,17 +310,16 @@ def __repr__(self) -> str:
258310
def __len__(self):
259311
return 1
260312

261-
def _to_proto(self, index_map):
    """Build the ``ByteRange`` protobuf message for this node.

    Args:
        index_map: Dict mapping nodes to integer indices. This node is given
            the next free index if it does not have one yet.

    Returns:
        A ``_grammar_pb2.ByteRange`` message holding this node's fields.
    """
    # Register this node; an existing entry is left untouched.
    index_map.setdefault(self, len(index_map))

    # The message field is a string, so a missing capture name is stored as "".
    if self.capture_name is None:
        capture_name = ""
    else:
        capture_name = self.capture_name

    msg = _grammar_pb2.ByteRange()
    msg.byte_range = self.byte_range
    msg.hidden = self.hidden
    msg.commit_point = self.commit_point
    msg.capture_name = capture_name
    msg.temperature = self.temperature
    return msg
269-
270-
def serialize(self):
271-
return self._to_proto().SerializeToString()
272323

273324
@staticmethod
274325
def _from_proto(data):
@@ -278,12 +329,6 @@ def _from_proto(data):
278329
out.capture_name = None if data.capture_name == "" else data.capture_name
279330
out.temperature = data.temperature
280331
return out
281-
282-
@staticmethod
283-
def deserialize(data_bytes):
284-
data = _grammar_pb2.ByteRange()
285-
data.ParseFromString(data_bytes)
286-
return ByteRange._from_proto(data)
287332

288333
class Null():
289334
__slots__ = ("name", "hidden", "commit_point", "capture_name")
@@ -324,6 +369,24 @@ def __init__(self, name):
324369
self.capture_name = None
325370
self.nullable = False
326371

372+
def _to_proto(self, index_map):
    """Build the ``ModelVariable`` protobuf message for this node.

    Args:
        index_map: Dict mapping nodes to integer indices. This node is given
            the next free index if it does not have one yet.

    Returns:
        A ``_grammar_pb2.ModelVariable`` message holding this node's fields.
    """
    # Register this node; an existing entry is left untouched.
    index_map.setdefault(self, len(index_map))

    # The message field is a string, so a missing capture name is stored as "".
    if self.capture_name is None:
        capture_name = ""
    else:
        capture_name = self.capture_name

    msg = _grammar_pb2.ModelVariable()
    msg.hidden = self.hidden
    msg.name = self.name
    msg.commit_point = self.commit_point
    msg.capture_name = capture_name
    return msg
381+
382+
@staticmethod
def _from_proto(data):
    """Create a ``ModelVariable`` from its protobuf message.

    Args:
        data: A message exposing ``name``, ``hidden``, ``commit_point`` and
            ``capture_name``.

    Returns:
        A new ``ModelVariable`` carrying the values stored in ``data``.
    """
    restored = ModelVariable(data.name)
    restored.hidden = data.hidden
    restored.commit_point = data.commit_point
    # ``_to_proto`` writes "" when there is no capture name; map it back to None.
    if data.capture_name == "":
        restored.capture_name = None
    else:
        restored.capture_name = data.capture_name
    return restored
389+
327390
def replace_grammar_node(grammar, target, replacement):
328391
# Use a stack to keep track of the nodes to be visited
329392
stack = [grammar]
@@ -492,7 +555,7 @@ class Join(StatelessFunction):
492555

493556
def __init__(self, values, name=None, max_tokens=100000000) -> None:
494557
values = [string(v) if isinstance(v, (str, bytes)) else v for v in values] # wrap raw strings
495-
self.nullable = all(v.nullable for v in values)
558+
self.nullable = all(getattr(v, "nullable", False) for v in values)
496559
self.values = [v for v in values if not isinstance(v, Null)]
497560
self.name = name if name is not None else StatelessFunction._new_name()
498561
self.hidden = False
@@ -511,46 +574,26 @@ def __repr__(self, indent="", done=None):
511574
s += v.__repr__(indent, done)
512575
return s
513576

514-
def _to_proto(self, index_map):
    """Build the ``Join`` protobuf message for this node.

    Children are not embedded in the message. Each one is stored as its
    integer index from ``index_map``.

    Args:
        index_map: Dict mapping nodes to integer indices. This node and any
            child without an index are given the next free index, in that
            order.

    Returns:
        A ``_grammar_pb2.Join`` message holding this node's fields and the
        indices of its children.
    """
    msg = _grammar_pb2.Join()
    msg.nullable = self.nullable

    # Register this node first, then its children in order. ``setdefault``
    # leaves already-registered nodes at their existing index.
    index_map.setdefault(self, len(index_map))
    for child in self.values:
        msg.values.append(index_map.setdefault(child, len(index_map)))

    # The message field is a string, so a missing capture name is stored as "".
    if self.capture_name is None:
        capture_name = ""
    else:
        capture_name = self.capture_name

    msg.name = self.name
    msg.hidden = self.hidden
    msg.commit_point = self.commit_point
    msg.capture_name = capture_name
    msg.max_tokens = self.max_tokens
    return msg
534-
535-
def serialize(self):
536-
return self._to_proto().SerializeToString()
537592

538593
@staticmethod
539594
def _from_proto(data):
540-
values = []
541-
for v in data.values:
542-
if v.HasField("byte"):
543-
values.append(Byte._from_proto(v.byte))
544-
elif v.HasField("byte_range"):
545-
values.append(ByteRange._from_proto(v.byte_range))
546-
elif v.HasField("join"):
547-
values.append(Join._from_proto(v.join))
548-
elif v.HasField("select"):
549-
values.append(Select._from_proto(v.select))
550-
else:
551-
raise Exception("Unknown type of value")
552595
out = Join(
553-
values,
596+
data.values, # we put ints in that will be replaced later by the deserialize method
554597
name=data.name,
555598
max_tokens=data.max_tokens
556599
)
@@ -559,12 +602,6 @@ def _from_proto(data):
559602
out.commit_point = data.commit_point
560603
out.capture_name = None if data.capture_name == "" else data.capture_name
561604
return out
562-
563-
@staticmethod
564-
def deserialize(data_bytes):
565-
data = _grammar_pb2.Join()
566-
data.ParseFromString(data_bytes)
567-
return Join._from_proto(data)
568605

569606

570607
class Select(StatelessFunction):
@@ -585,7 +622,7 @@ def values(self):
585622
@values.setter
586623
def values(self, vals):
587624
self._values = [string(v) if isinstance(v, (str, bytes)) else v for v in vals]
588-
self.nullable = any(v.nullable for v in self._values)
625+
self.nullable = any(getattr(v, "nullable", False) for v in self._values)
589626
self._values = [v for v in self._values if not isinstance(v, Null)]
590627

591628
def __repr__(self, indent="", done=None):
@@ -599,47 +636,28 @@ def __repr__(self, indent="", done=None):
599636
s += v.__repr__(indent, done)
600637
return s
601638

602-
def _to_proto(self, index_map):
    """Build the ``Select`` protobuf message for this node.

    Children are not embedded in the message. Each one is stored as its
    integer index from ``index_map``.

    Args:
        index_map: Dict mapping nodes to integer indices. This node and any
            child without an index are given the next free index, in that
            order.

    Returns:
        A ``_grammar_pb2.Select`` message holding this node's fields and the
        indices of its children.
    """
    msg = _grammar_pb2.Select()
    msg.nullable = self.nullable

    # Register this node first, then its children in order. ``setdefault``
    # leaves already-registered nodes at their existing index.
    index_map.setdefault(self, len(index_map))
    for child in self.values:
        msg.values.append(index_map.setdefault(child, len(index_map)))

    # The message field is a string, so a missing capture name is stored as "".
    if self.capture_name is None:
        capture_name = ""
    else:
        capture_name = self.capture_name

    msg.name = self.name
    msg.hidden = self.hidden
    msg.commit_point = self.commit_point
    msg.capture_name = capture_name
    msg.max_tokens = self.max_tokens
    msg.recursive = self.recursive

    return msg
623-
624-
def serialize(self):
625-
return self._to_proto().SerializeToString()
626656

627657
@staticmethod
628658
def _from_proto(data):
629-
values = []
630-
for v in data.values:
631-
if v.HasField("byte"):
632-
values.append(Byte._from_proto(v.byte))
633-
elif v.HasField("byte_range"):
634-
values.append(ByteRange._from_proto(v.byte_range))
635-
elif v.HasField("join"):
636-
values.append(Join._from_proto(v.join))
637-
elif v.HasField("select"):
638-
values.append(Select._from_proto(v.select))
639-
else:
640-
raise Exception("Unknown type of value")
641-
out = Join(
642-
values,
659+
out = Select(
660+
data.values, # we put ints in that will be replaced later by the deserialize method
643661
name=data.name,
644662
max_tokens=data.max_tokens
645663
)
@@ -649,12 +667,6 @@ def _from_proto(data):
649667
out.capture_name = None if data.capture_name == "" else data.capture_name
650668
out.recursive = data.recursive
651669
return out
652-
653-
@staticmethod
654-
def deserialize(data_bytes):
655-
data = _grammar_pb2.Join()
656-
data.ParseFromString(data_bytes)
657-
return Join._from_proto(data)
658670

659671
def string(value):
660672
if isinstance(value, str):

0 commit comments

Comments
 (0)