Skip to content

Commit 4b5f442

Browse files
authored
Protobuf round trip unit test (#938)
Adds a unit test that does a round-trip serialization then deserialization of a grammar. This way we can cut out the server entirely, hopefully resulting in a more reliable test.
1 parent d6cb286 commit 4b5f442

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

tests/unit/test_protobuf.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import pytest
2+
from itertools import chain
3+
4+
from guidance import (
5+
byte_range,
6+
char_range,
7+
commit_point,
8+
select,
9+
string,
10+
token_limit,
11+
with_temperature,
12+
)
13+
from guidance._grammar import (
14+
Byte,
15+
ByteRange,
16+
GrammarFunction,
17+
Join,
18+
ModelVariable,
19+
Select,
20+
)
21+
22+
23+
def compare_grammars(g1: GrammarFunction, g2: GrammarFunction) -> bool:
24+
"""Recursively compare two GrammarFunction objects for equivalence."""
25+
26+
if type(g1) != type(g2):
27+
return False
28+
29+
# Compare attributes based on type
30+
if isinstance(g1, (Byte, ByteRange, ModelVariable)):
31+
slots = chain.from_iterable(getattr(cls, '__slots__', []) for cls in type(g1).mro())
32+
return all(getattr(g1, slot) == getattr(g2, slot) for slot in slots)
33+
elif isinstance(g1, (Join, Select)):
34+
slots = chain.from_iterable(getattr(cls, '__slots__', []) for cls in type(g1).mro())
35+
return (all(getattr(g1, slot) == getattr(g2, slot) for slot in slots if 'values' not in slot)
36+
and len(g1.values) == len(g2.values) # Check both have same number of child nodes
37+
and all(compare_grammars(v1, v2) for v1, v2 in zip(g1.values, g2.values)) # Recursively compare child nodes
38+
)
39+
else:
40+
raise ValueError(f"Unsupported grammar type: {type(g1)}")
41+
42+
43+
@pytest.mark.parametrize(
44+
"grammar",
45+
[
46+
string("Hello, world!"),
47+
Byte(b"a"),
48+
byte_range(b"\x00", b"\xff"),
49+
char_range("a", "z"),
50+
select(["option1", "option2", "option3"]),
51+
commit_point(string("commit"), hidden=True),
52+
token_limit(string("limited"), max_tokens=5),
53+
with_temperature(string("temp"), temperature=0.5),
54+
ModelVariable("my_variable"),
55+
Join([string("part1"), string("part2")]),
56+
select(
57+
[
58+
string("option1"),
59+
Join([string("part1"), string("part2")]),
60+
]
61+
),
62+
],
63+
)
64+
def test_grammar_protobuf_roundtrip(grammar: GrammarFunction):
65+
"""Test that grammars can be round-tripped through protobuf serialization."""
66+
serialized_grammar = grammar.serialize()
67+
deserialized_grammar = GrammarFunction.deserialize(serialized_grammar)
68+
69+
# Recursively compare the grammars
70+
assert compare_grammars(
71+
grammar, deserialized_grammar
72+
), f"Deserialized grammar does not match original:\nOriginal: {grammar}\nDeserialized: {deserialized_grammar}\n"
73+

0 commit comments

Comments
 (0)