@@ -149,6 +149,65 @@ def gbnf_string(self):
149149 root_name = self ._rec_gbnf_string (lines , used_names , names )
150150 lines .append ("root ::= " + root_name )
151151 return "\n " .join (lines )
152+
def serialize(self):
    """Serialize the grammar reachable from this node into protobuf bytes.

    A `_grammar_pb2.Grammar` message is filled with one entry per distinct
    node collected by `_rec_serialize` and then encoded with
    `SerializeToString()`.

    Join and Select nodes refer to their children by `index_map` index, and
    `deserialize` resolves those references by position in the node list
    (treating position 0 as the root). The nodes are therefore written in
    index order. Writing them in traversal order is not equivalent: a parent
    hands out indices to all of its direct children before any grandchild is
    visited, whereas the traversal reaches the first child's descendants
    before the second child.

    Returns:
        bytes: the encoded `Grammar` message.
    """
    g = _grammar_pb2.Grammar()
    index_map = {}
    nodes = {}
    self._rec_serialize(index_map, nodes)  # fills in both nodes and index_map

    # Every serialized node registered itself in index_map from _to_proto, so
    # sorting by that index makes list position i hold the node with index i.
    ordered = sorted(nodes, key=index_map.__getitem__)
    g.nodes.extend([nodes[fn] for fn in ordered])
    return g.SerializeToString()
160+
def _rec_serialize(self, index_map, nodes):
    """Add this node and everything reachable from it to `nodes`.

    Args:
        index_map: dict handed to `_to_proto`, which records node indices in it.
        nodes: dict mapping each visited node to its wrapped
            `_grammar_pb2.StatelessFunction` message; filled in place.

    Raises:
        Exception: if this node is not one of the serializable node types.
    """
    if self in nodes:
        # Already emitted: a node reachable more than once is stored only once,
        # and this is also what stops the recursion on repeated visits.
        return

    payload = self._to_proto(index_map)
    wrapper = _grammar_pb2.StatelessFunction()

    # (node class, wrapper field) pairs; the first matching class wins.
    field_by_type = (
        (Byte, "byte"),
        (ByteRange, "byte_range"),
        (Select, "select"),
        (Join, "join"),
        (ModelVariable, "model_variable"),
    )
    for node_type, field_name in field_by_type:
        if isinstance(self, node_type):
            getattr(wrapper, field_name).CopyFrom(payload)
            break
    else:
        raise Exception("Unknown node type")

    # Register before descending so children that point back here stop early.
    nodes[self] = wrapper
    if hasattr(self, "values"):
        for child in self.values:
            child._rec_serialize(index_map, nodes)
181+
@classmethod
def deserialize(cls, serialized_grammar):
    """Rebuild a grammar from the bytes produced by `serialize`.

    Works in two passes: first every entry of the parsed `Grammar` message is
    turned into a node object via the matching class's `_from_proto`; then,
    for each object that has a `values` attribute, the integer placeholders in
    it are replaced by the objects found at those positions.

    Args:
        serialized_grammar: encoded `_grammar_pb2.Grammar` message.

    Returns:
        The object built from the first entry, which is the grammar's root.

    Raises:
        Exception: if an entry has none of the known node fields set.
    """
    grammar = _grammar_pb2.Grammar()
    grammar.ParseFromString(serialized_grammar)

    # (wrapper field, node class) pairs, probed in this order for each entry.
    decoders = (
        ("byte", Byte),
        ("byte_range", ByteRange),
        ("select", Select),
        ("join", Join),
        ("model_variable", ModelVariable),
    )

    # Pass 1: build one object per serialized entry, preserving list order.
    objects = []
    for entry in grammar.nodes:
        for field_name, node_cls in decoders:
            if entry.HasField(field_name):
                built = node_cls._from_proto(getattr(entry, field_name))
                break
        else:
            raise Exception("Unknown node type")
        objects.append(built)

    # Pass 2: swap the integer child references for the actual objects.
    for obj in objects:
        if hasattr(obj, "values"):
            for slot, target in enumerate(obj.values):
                obj.values[slot] = objects[target]

    return objects[0]  # the first element is the root node of the grammar
152211
153212class Terminal (StatelessFunction ):
154213 def match_byte (self , byte ):
@@ -193,7 +252,9 @@ def match_byte(self, byte):
193252 def nullable (self ):
194253 return False
195254
196- def _to_proto (self ):
255+ def _to_proto (self , index_map ):
256+ if self not in index_map :
257+ index_map [self ] = len (index_map )
197258 data = _grammar_pb2 .Byte ()
198259 data .byte = self .byte
199260 data .hidden = self .hidden
@@ -202,15 +263,6 @@ def _to_proto(self):
202263 data .temperature = self .temperature
203264 return data
204265
205- def serialize (self ):
206- return self ._to_proto ().SerializeToString ()
207-
208- @staticmethod
209- def deserialize (data_bytes ):
210- data = _grammar_pb2 .Byte ()
211- data .ParseFromString (data_bytes )
212- return Byte ._from_proto (data )
213-
214266 @staticmethod
215267 def _from_proto (data ):
216268 out = Byte (data .byte )
@@ -258,17 +310,16 @@ def __repr__(self) -> str:
def __len__(self):
    """Return the length of this node, which is always 1."""
    # NOTE(review): presumably 1 because the node matches a single byte — confirm.
    return 1
260312
def _to_proto(self, index_map):
    """Build the `_grammar_pb2.ByteRange` message for this node.

    Args:
        index_map: dict of node -> index; this node is added with the next
            free index if it is not present yet.

    Returns:
        The populated `_grammar_pb2.ByteRange` message.
    """
    # setdefault leaves an existing index untouched and otherwise assigns the
    # current size of the map, i.e. the next unused index.
    index_map.setdefault(self, len(index_map))

    proto = _grammar_pb2.ByteRange()
    proto.byte_range = self.byte_range
    proto.hidden = self.hidden
    proto.commit_point = self.commit_point
    # A missing capture name is encoded as the empty string.
    proto.capture_name = self.capture_name if self.capture_name is not None else ""
    proto.temperature = self.temperature
    return proto
269-
270- def serialize (self ):
271- return self ._to_proto ().SerializeToString ()
272323
273324 @staticmethod
274325 def _from_proto (data ):
@@ -278,12 +329,6 @@ def _from_proto(data):
278329 out .capture_name = None if data .capture_name == "" else data .capture_name
279330 out .temperature = data .temperature
280331 return out
281-
282- @staticmethod
283- def deserialize (data_bytes ):
284- data = _grammar_pb2 .ByteRange ()
285- data .ParseFromString (data_bytes )
286- return ByteRange ._from_proto (data )
287332
288333class Null ():
289334 __slots__ = ("name" , "hidden" , "commit_point" , "capture_name" )
@@ -324,6 +369,24 @@ def __init__(self, name):
324369 self .capture_name = None
325370 self .nullable = False
326371
def _to_proto(self, index_map):
    """Build the `_grammar_pb2.ModelVariable` message for this node.

    Args:
        index_map: dict of node -> index; this node is added with the next
            free index if it is not present yet.

    Returns:
        The populated `_grammar_pb2.ModelVariable` message.
    """
    # Keep an index that was already assigned, else take the next free one.
    index_map.setdefault(self, len(index_map))

    proto = _grammar_pb2.ModelVariable()
    proto.name = self.name
    proto.hidden = self.hidden
    proto.commit_point = self.commit_point
    # A missing capture name is encoded as the empty string.
    proto.capture_name = self.capture_name if self.capture_name is not None else ""
    return proto
381+
@staticmethod
def _from_proto(data):
    """Create a `ModelVariable` from its protobuf message.

    Args:
        data: message providing `name`, `hidden`, `commit_point` and
            `capture_name`.

    Returns:
        A new `ModelVariable` carrying the same field values.
    """
    restored = ModelVariable(data.name)
    restored.hidden = data.hidden
    restored.commit_point = data.commit_point
    # The empty string is the wire encoding for "no capture name".
    captured = data.capture_name
    restored.capture_name = captured if captured != "" else None
    return restored
389+
327390def replace_grammar_node (grammar , target , replacement ):
328391 # Use a stack to keep track of the nodes to be visited
329392 stack = [grammar ]
@@ -492,7 +555,7 @@ class Join(StatelessFunction):
492555
493556 def __init__ (self , values , name = None , max_tokens = 100000000 ) -> None :
494557 values = [string (v ) if isinstance (v , (str , bytes )) else v for v in values ] # wrap raw strings
495- self .nullable = all (v . nullable for v in values )
558+ self .nullable = all (getattr ( v , " nullable" , False ) for v in values )
496559 self .values = [v for v in values if not isinstance (v , Null )]
497560 self .name = name if name is not None else StatelessFunction ._new_name ()
498561 self .hidden = False
@@ -511,46 +574,26 @@ def __repr__(self, indent="", done=None):
511574 s += v .__repr__ (indent , done )
512575 return s
513576
def _to_proto(self, index_map):
    """Build the `_grammar_pb2.Join` message for this node.

    Children are not embedded; each one is stored as its integer index from
    `index_map`, with a new index assigned on first sight.

    Args:
        index_map: dict of node -> index, updated in place for this node and
            for every child that has no index yet.

    Returns:
        The populated `_grammar_pb2.Join` message.
    """
    proto = _grammar_pb2.Join()
    proto.nullable = self.nullable

    # This node takes its index before its children take theirs.
    index_map.setdefault(self, len(index_map))
    for child in self.values:
        proto.values.append(index_map.setdefault(child, len(index_map)))

    proto.name = self.name
    proto.hidden = self.hidden
    proto.commit_point = self.commit_point
    # A missing capture name is encoded as the empty string.
    proto.capture_name = self.capture_name if self.capture_name is not None else ""
    proto.max_tokens = self.max_tokens
    return proto
534-
535- def serialize (self ):
536- return self ._to_proto ().SerializeToString ()
537592
538593 @staticmethod
539594 def _from_proto (data ):
540- values = []
541- for v in data .values :
542- if v .HasField ("byte" ):
543- values .append (Byte ._from_proto (v .byte ))
544- elif v .HasField ("byte_range" ):
545- values .append (ByteRange ._from_proto (v .byte_range ))
546- elif v .HasField ("join" ):
547- values .append (Join ._from_proto (v .join ))
548- elif v .HasField ("select" ):
549- values .append (Select ._from_proto (v .select ))
550- else :
551- raise Exception ("Unknown type of value" )
552595 out = Join (
553- values ,
596+ data . values , # we put ints in that will be replaced later by the deserialize method
554597 name = data .name ,
555598 max_tokens = data .max_tokens
556599 )
@@ -559,12 +602,6 @@ def _from_proto(data):
559602 out .commit_point = data .commit_point
560603 out .capture_name = None if data .capture_name == "" else data .capture_name
561604 return out
562-
563- @staticmethod
564- def deserialize (data_bytes ):
565- data = _grammar_pb2 .Join ()
566- data .ParseFromString (data_bytes )
567- return Join ._from_proto (data )
568605
569606
570607class Select (StatelessFunction ):
@@ -585,7 +622,7 @@ def values(self):
@values.setter
def values(self, vals):
    """Store the alternatives, updating `nullable` and dropping `Null` entries.

    Raw `str`/`bytes` items are wrapped with `string()` first. `nullable`
    becomes True when any item reports a truthy `nullable` attribute (items
    without that attribute count as False). `Null` instances are then removed
    from the stored list.
    """
    wrapped = []
    for item in vals:
        if isinstance(item, (str, bytes)):
            item = string(item)  # wrap raw strings
        wrapped.append(item)
    self._values = wrapped

    # Computed before Null entries are filtered out, so they still take part.
    self.nullable = any(getattr(item, "nullable", False) for item in wrapped)
    self._values = [item for item in wrapped if not isinstance(item, Null)]
590627
591628 def __repr__ (self , indent = "" , done = None ):
@@ -599,47 +636,28 @@ def __repr__(self, indent="", done=None):
599636 s += v .__repr__ (indent , done )
600637 return s
601638
def _to_proto(self, index_map):
    """Build the `_grammar_pb2.Select` message for this node.

    Children are not embedded; each one is stored as its integer index from
    `index_map`, with a new index assigned on first sight.

    Args:
        index_map: dict of node -> index, updated in place for this node and
            for every child that has no index yet.

    Returns:
        The populated `_grammar_pb2.Select` message.
    """
    proto = _grammar_pb2.Select()
    proto.nullable = self.nullable

    # This node takes its index before its children take theirs.
    index_map.setdefault(self, len(index_map))
    for option in self.values:
        proto.values.append(index_map.setdefault(option, len(index_map)))

    proto.name = self.name
    proto.hidden = self.hidden
    proto.commit_point = self.commit_point
    # A missing capture name is encoded as the empty string.
    proto.capture_name = self.capture_name if self.capture_name is not None else ""
    proto.max_tokens = self.max_tokens
    proto.recursive = self.recursive
    return proto
623-
624- def serialize (self ):
625- return self ._to_proto ().SerializeToString ()
626656
627657 @staticmethod
628658 def _from_proto (data ):
629- values = []
630- for v in data .values :
631- if v .HasField ("byte" ):
632- values .append (Byte ._from_proto (v .byte ))
633- elif v .HasField ("byte_range" ):
634- values .append (ByteRange ._from_proto (v .byte_range ))
635- elif v .HasField ("join" ):
636- values .append (Join ._from_proto (v .join ))
637- elif v .HasField ("select" ):
638- values .append (Select ._from_proto (v .select ))
639- else :
640- raise Exception ("Unknown type of value" )
641- out = Join (
642- values ,
659+ out = Select (
660+ data .values , # we put ints in that will be replaced later by the deserialize method
643661 name = data .name ,
644662 max_tokens = data .max_tokens
645663 )
@@ -649,12 +667,6 @@ def _from_proto(data):
649667 out .capture_name = None if data .capture_name == "" else data .capture_name
650668 out .recursive = data .recursive
651669 return out
652-
653- @staticmethod
654- def deserialize (data_bytes ):
655- data = _grammar_pb2 .Join ()
656- data .ParseFromString (data_bytes )
657- return Join ._from_proto (data )
658670
659671def string (value ):
660672 if isinstance (value , str ):
0 commit comments