Skip to content

Commit 8f0e9aa

Browse files
committed
Rename grammar classes
1 parent b87c1c1 commit 8f0e9aa

File tree

6 files changed

+89
-89
lines changed

6 files changed

+89
-89
lines changed

guidance/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import requests
88

99
from . import models
10-
from ._grammar import (Placeholder, StatefulFunction, StatelessFunction,
10+
from ._grammar import (Placeholder, RawFunction, GrammarFunction,
1111
Terminal, replace_grammar_node, string)
1212
from ._utils import load, strip_multiline_string_indents
1313
from ._server import Server
@@ -79,7 +79,7 @@ def wrapped(*args, **kwargs):
7979

8080
# otherwise must be stateful (which means we can't be inside a select() call)
8181
else:
82-
return StatefulFunction(f, args, kwargs)
82+
return RawFunction(f, args, kwargs)
8383

8484
# attach this as a method of the model class (if given)
8585
# if model is not None:

guidance/_grammar.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def deserialize(cls, serialized_grammar):
4242
raise NotImplementedError()
4343

4444

45-
class StatefulFunction(Function):
45+
class RawFunction(Function):
4646
__slots__ = ("f", "args", "kwargs")
4747

4848
def __init__(self, f, args, kwargs):
@@ -63,11 +63,11 @@ def __add__(model):
6363
model = self(model)
6464
if model is None:
6565
raise Exception(f"The guidance function `{self.f.__name__}` did not return a model object! You need to return an updated model object at the end of your guidance function.")
66-
if isinstance(other, StatelessFunction):
66+
if isinstance(other, GrammarFunction):
6767
return model + other
6868
else:
6969
return other(model)
70-
return StatefulFunction(__add__, [], {})
70+
return RawFunction(__add__, [], {})
7171

7272
def __radd__(self, other):
7373

@@ -76,14 +76,14 @@ def __radd__(self, other):
7676
return other + str(self)
7777

7878
def __radd__(model):
79-
if isinstance(other, StatelessFunction):
79+
if isinstance(other, GrammarFunction):
8080
model += other
8181
else:
8282
model = other(model)
8383
return self(model)
84-
return StatefulFunction(__radd__, [], {})
84+
return RawFunction(__radd__, [], {})
8585

86-
class StatelessFunction(Function):
86+
class GrammarFunction(Function):
8787
num_used_names = 0
8888

8989
def __add__(self, value):
@@ -96,7 +96,7 @@ def __add__(self, value):
9696
value = string(value)
9797

9898
# see if we can keep building a stateless grammar
99-
if isinstance(value, StatelessFunction):
99+
if isinstance(value, GrammarFunction):
100100
return Join([self, value])
101101

102102
# otherwise we let the stateful object handle things
@@ -113,19 +113,19 @@ def __radd__(self, value):
113113
value = string(value)
114114

115115
# see if we can keep building a stateless grammar
116-
if isinstance(value, StatelessFunction):
116+
if isinstance(value, GrammarFunction):
117117
return Join([value, self])
118118

119119
# otherwise we let the stateful object handle things
120120
else:
121121
return value.__add__(self)
122122

123123
def __getitem__(self, value):
124-
raise StatefulException("StatelessFunctions can't access state!")
124+
raise StatefulException("GrammarFunctions can't access state!")
125125

126126
@staticmethod
127127
def _new_name():
128-
num_used = StatelessFunction.num_used_names
128+
num_used = GrammarFunction.num_used_names
129129

130130
a_ord = ord('a')
131131

@@ -138,7 +138,7 @@ def _new_name():
138138
if num_used >= 17576:
139139
name = chr(a_ord + (num_used % 456976) // 17576) + name
140140

141-
StatelessFunction.num_used_names += 1
141+
GrammarFunction.num_used_names += 1
142142

143143
return name
144144

@@ -169,7 +169,7 @@ def _rec_create_index_map(self, index_map):
169169
def _rec_serialize(self, index_map, nodes):
170170
if self not in nodes:
171171
v = self._to_proto(index_map)
172-
node = _serialization_pb2.StatelessFunction()
172+
node = _serialization_pb2.GrammarFunction()
173173
if isinstance(self, Byte):
174174
node.byte.CopyFrom(v)
175175
elif isinstance(self, ByteRange):
@@ -217,7 +217,7 @@ def deserialize(cls, serialized_grammar):
217217

218218
return values[0] # the first element in the root node of the grammar
219219

220-
class Terminal(StatelessFunction):
220+
class Terminal(GrammarFunction):
221221
def match_byte(self, byte):
222222
pass # abstract
223223

@@ -358,7 +358,7 @@ def __add__(self, other):
358358
def __radd__(self, other):
359359
return self.__add__(other) # left vs right makes no difference since we are null
360360

361-
class ModelVariable(StatelessFunction):
361+
class ModelVariable(GrammarFunction):
362362
'''This represents a variable that will be read from the model object when this grammar is executed.
363363
364364
Note that the name is the name of the attribute on the model object this node
@@ -510,7 +510,7 @@ def _wrap_as_grammar(value):
510510
'''This takes whatever value was given and tries to turn in into a guidance grammar.'''
511511

512512
# if it is already a valid grammar we have no need to wrap it
513-
if isinstance(value, StatelessFunction):
513+
if isinstance(value, GrammarFunction):
514514
return value
515515

516516
# if it is already a valid grammar we have no need to wrap it
@@ -546,20 +546,20 @@ def _rec_hide(grammar):
546546
for g in grammar.values:
547547
_rec_hide(g)
548548

549-
class Placeholder(StatelessFunction):
549+
class Placeholder(GrammarFunction):
550550
__slots__ = tuple("nullable")
551551
def __init__(self):
552552
self.nullable = False
553553

554554

555-
class Join(StatelessFunction):
555+
class Join(GrammarFunction):
556556
__slots__ = ("nullable", "values", "name", "hidden", "commit_point", "capture_name", "max_tokens")
557557

558558
def __init__(self, values, name=None, max_tokens=100000000) -> None:
559559
values = [string(v) if isinstance(v, (str, bytes)) else v for v in values] # wrap raw strings
560560
self.nullable = all(getattr(v, "nullable", False) for v in values)
561561
self.values = [v for v in values if not isinstance(v, Null)]
562-
self.name = name if name is not None else StatelessFunction._new_name()
562+
self.name = name if name is not None else GrammarFunction._new_name()
563563
self.hidden = False
564564
self.commit_point = False
565565
self.capture_name = None
@@ -602,12 +602,12 @@ def _from_proto(data):
602602
return out
603603

604604

605-
class Select(StatelessFunction):
605+
class Select(GrammarFunction):
606606
__slots__ = ("nullable", "_values", "name", "hidden", "commit_point", "capture_name", "max_tokens", "recursive")
607607

608608
def __init__(self, values, capture_name=None, name=None, max_tokens=10000000, recursive=False) -> None:
609609
self.values = values
610-
self.name = name if name is not None else StatelessFunction._new_name()
610+
self.name = name if name is not None else GrammarFunction._new_name()
611611
self.hidden = False
612612
self.commit_point = False
613613
self.capture_name = capture_name
@@ -681,7 +681,7 @@ def select(options, name=None, list_append=False, recurse=False, skip_checks=Fal
681681
# TODO: also the full probabilites distribution over all items. We can implement this using the prob of the selected item by repeating the call, removing the selected item each time
682682
if not skip_checks:
683683
for i, value in enumerate(options):
684-
assert not isinstance(value, StatefulFunction), "You cannot select between stateful functions in the current guidance implementation!"
684+
assert not isinstance(value, RawFunction), "You cannot select between stateful functions in the current guidance implementation!"
685685
assert not isinstance(value, types.FunctionType), "Did you pass a function without calling it to select? You need to pass the results of a called guidance function to select."
686686
if isinstance(value, int) or isinstance(value, float):
687687
options[i] = str(value)
@@ -791,10 +791,10 @@ def str_to_grammar(value):
791791
# lm.suffix = parts[i+1]
792792
if is_id:
793793
call = _call_pool[part]
794-
if isinstance(call, StatelessFunction):
794+
if isinstance(call, GrammarFunction):
795795
partial_grammar += _call_pool[part]
796796
else:
797-
partial_grammar = StatefulFunction(lambda lm, g, call: call(lm + g), partial_grammar, _call_pool[part])
797+
partial_grammar = RawFunction(lambda lm, g, call: call(lm + g), partial_grammar, _call_pool[part])
798798
# lm += partial_grammar
799799
# lm = _call_pool[part](lm)
800800
# partial_grammar = _null_grammar

guidance/_serialization.proto

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ syntax = "proto3";
33
package guidance;
44

55
message Grammar {
6-
repeated StatelessFunction nodes = 1;
6+
repeated GrammarFunction nodes = 1;
77
}
88

99
message EngineCallResponse {
@@ -77,7 +77,7 @@ message Select {
7777
// }
7878
// }
7979

80-
message StatelessFunction {
80+
message GrammarFunction {
8181
oneof function_type {
8282
Join join = 1;
8383
Select select = 2;

0 commit comments

Comments
 (0)