Draft
Changes from all commits (61 commits)
a980e0a
pending ast
hudson-ai Apr 8, 2025
1cd3ab1
call _run before accessing state
hudson-ai Apr 8, 2025
07de0c3
poc of async model
hudson-ai Apr 10, 2025
eed1e59
put everything on top of async backend
hudson-ai Apr 15, 2025
fbbe722
make everything way more lazy
hudson-ai Apr 15, 2025
2df1c39
AsyncFunction
hudson-ai Apr 15, 2025
55ef91e
subtle -- blocks need to be applied before running
hudson-ai Apr 15, 2025
636e2ba
make sure that async + sync function application doesn't run sync ins…
hudson-ai Apr 15, 2025
032ef61
centralize run_async_maybe_in_thread logic
hudson-ai Apr 15, 2025
3f3ac5d
make engine interpreter async (well, as close as we can get right now)
hudson-ai Apr 15, 2025
d7d46ee
fix engine interpreter role start/end
hudson-ai Apr 15, 2025
d171106
change test_text_closer
hudson-ai Apr 15, 2025
dea8020
add some tests
hudson-ai Apr 16, 2025
7b3aae6
fix sync/async eval loop (recall -- async can be inside of sync)
hudson-ai Apr 16, 2025
7db5363
fix test
hudson-ai Apr 17, 2025
bb931f3
make our async re-entrant with greenlets
hudson-ai Apr 16, 2025
be03577
doc
hudson-ai Apr 17, 2025
763e10b
fix wrong comment and use higher-level asyncio run
hudson-ai Apr 17, 2025
63bb655
don't try to await_ in a thread
hudson-ai Apr 17, 2025
ffa5c6d
close the coro if we're not going to run it
hudson-ai Apr 17, 2025
ba547f6
black, isort, mypy
hudson-ai Apr 17, 2025
427f714
assert we actually get an exception when using sync accessors in asyn…
hudson-ai Apr 17, 2025
394d415
copy context vars into greenlet
hudson-ai Apr 21, 2025
70ee3f9
more tests
hudson-ai Apr 21, 2025
850ebd1
entry point decorator to help determine whether to run reentrant awai…
hudson-ai Apr 21, 2025
61c28ee
factor out run in bg thread
hudson-ai Apr 21, 2025
7698c26
clean up bridge a bit
hudson-ai Apr 21, 2025
ce3d7a0
make bg_async generic
hudson-ai Apr 21, 2025
4839742
move run in bg async to bridge
hudson-ai Apr 21, 2025
4d8d3ee
guidance/_bridge.py -> guidance/_reentrant_async.py
hudson-ai Apr 21, 2025
8c28873
clear active blocks in repeated function application
hudson-ai Apr 21, 2025
06c309c
make sure to clear pending blocks after function application
hudson-ai Apr 22, 2025
111efc0
add 'batched' entrypoints
hudson-ai Apr 22, 2025
edee836
Merge branch 'main' into greenlet_eval
hudson-ai Apr 22, 2025
bc389e4
add greenlet dependency
hudson-ai Apr 22, 2025
c9a922b
fix capture blocks
hudson-ai Apr 22, 2025
7a75071
fix smoke tests by adding str(lm) to trigger execution
hudson-ai Apr 22, 2025
788097b
visual async test -- don't rely on there being an existing event loop
hudson-ai Apr 22, 2025
aaa1191
lm.get_token_count()
hudson-ai Apr 22, 2025
07da88f
fix associativity test -- use str to run lm
hudson-ai Apr 22, 2025
58b4817
remove print
hudson-ai Apr 22, 2025
4ad5a14
remove leftover generic
hudson-ai Apr 23, 2025
d4fd6cb
remove leftover generic (missed some)
hudson-ai Apr 23, 2025
b7650b8
make ModelStream play nicely with lazy Model
hudson-ai Apr 24, 2025
9e12ff7
call str to trigger execution for tests (todo: hide metrics behind ac…
hudson-ai Apr 24, 2025
439355f
Fix attribution to greenletio
hudson-ai Apr 24, 2025
f280ad9
Remove double import
hudson-ai Apr 24, 2025
6edf276
refactor before merge
hudson-ai May 6, 2025
cb8c613
Merge branch 'main' into greenlet_eval
hudson-ai May 6, 2025
d4c8b0e
async openai
hudson-ai May 6, 2025
4ff3d3d
clean up openai a bit
hudson-ai May 6, 2025
4125e01
fix openai with Concatenate ast
hudson-ai May 6, 2025
5a649ed
bring vllm up to speed
hudson-ai May 6, 2025
31b0fc8
Merge branch 'main' into greenlet_eval
hudson-ai May 6, 2025
1650092
regain my sanity -- refactor Model __init__ with dataclass
hudson-ai May 6, 2025
3a29a83
restore vis support
hudson-ai May 6, 2025
4bd5c20
generalize interpreter run to yield InputAttr in addition to OutputAttr
hudson-ai May 6, 2025
14af49c
fix mixin order with openai
hudson-ai May 6, 2025
62950b4
fix openai audio gen
hudson-ai May 6, 2025
716014c
Merge branch 'main' into greenlet_eval
hudson-ai May 7, 2025
9202113
Merge branch 'main' into greenlet_eval
hudson-ai May 8, 2025
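
Note (not part of the PR itself): the commits "make our async re-entrant with greenlets", "guidance/_bridge.py -> guidance/_reentrant_async.py", and "Fix attribution to greenletio" describe bridging synchronous entry points onto the new async backend so that sync code can drive coroutines even while an event loop is already running. The following is a minimal sketch of that greenletio-style bridge; the names `async_` and `await_` are illustrative and not necessarily the API in `guidance/_reentrant_async.py`.

import greenlet


def await_(awaitable):
    """Block synchronous code on an awaitable by switching to the parent greenlet."""
    current = greenlet.getcurrent()
    # The parent greenlet is the async_ wrapper below; hand it the awaitable
    # and suspend here until it switches the result back in.
    return current.parent.switch(awaitable)


async def async_(sync_fn, *args, **kwargs):
    """Run a sync function that may call await_() without blocking the running loop."""
    glet = greenlet.greenlet(sync_fn)
    result = glet.switch(*args, **kwargs)
    while not glet.dead:
        # `result` is an awaitable handed over by await_(); await it on the
        # running event loop, then resume the sync code with its value.
        result = glet.switch(await result)
    return result


# Illustrative usage:
#   def sync_work():
#       return await_(asyncio.sleep(0.1, result=41)) + 1
#   result = await async_(sync_work)   # 42, without blocking the event loop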
176 changes: 150 additions & 26 deletions guidance/_ast.py
@@ -7,6 +7,7 @@
TYPE_CHECKING,
Any,
Callable,
AsyncIterable,
Iterator,
Optional,
Sequence,
@@ -17,7 +18,9 @@
from typing_extensions import assert_never

from ._parser import ByteParser, ByteParserException
from .trace import OutputAttr
from .trace import InputAttr, OutputAttr, RoleOpenerInput, RoleCloserInput

NodeAttr = Union[InputAttr, OutputAttr]

if TYPE_CHECKING:
from .models._base import Interpreter, State
@@ -116,13 +119,13 @@ def __call__(self, model):
return model

def __add__(self, other):
if not isinstance(other, (str, GrammarNode, Function)):
if not isinstance(other, (str, ASTNode, Function)):
return NotImplemented

if isinstance(other, str):
other = _parse_tags(other)

if isinstance(other, GrammarNode) and other.is_null:
if isinstance(other, ASTNode) and other.is_null:
return self

def __add__(model):
@@ -131,78 +134,199 @@ def __add__(model):
return Function(__add__, [], {})

def __radd__(self, other):
if not isinstance(other, (str, GrammarNode, Function)):
if not isinstance(other, (str, ASTNode, Function)):
return NotImplemented

if isinstance(other, str):
other = _parse_tags(other)

if isinstance(other, GrammarNode) and other.is_null:
if isinstance(other, ASTNode) and other.is_null:
return self

def __radd__(model):
return self(model + other)

return Function(__radd__, [], {})

@dataclass
class AsyncFunction(Tagged):
name: str = field(init=False)
f: Callable
args: tuple[Any, ...]
kwargs: dict[str, Any]

S = TypeVar("S", bound="State")
def __post_init__(self):
self.name = self.f.__name__

async def __call__(self, model):
model = await self.f(model, *self.args, **self.kwargs)
if model is None:
raise Exception(
f"The guidance function `{self.f.__name__}` did not return a model object! You need to return an updated model object at the end of your guidance function."
)
return model

def __add__(self, other):
if not isinstance(other, (str, ASTNode, Function, AsyncFunction)):
return NotImplemented

if isinstance(other, str):
other = _parse_tags(other)

if isinstance(other, ASTNode) and other.is_null:
return self

async def __add__(model):
return (await self(model)) + other

return AsyncFunction(__add__, [], {})

def __radd__(self, other):
if not isinstance(other, (str, ASTNode, Function, AsyncFunction)):
return NotImplemented

if isinstance(other, str):
other = _parse_tags(other)

if isinstance(other, ASTNode) and other.is_null:
return self

async def __radd__(model):
return await self(model + other)

return AsyncFunction(__radd__, [], {})


S = TypeVar("S", bound="State")

class ASTNode(ABC):
@abstractmethod
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]:
pass

def simplify(self) -> "ASTNode":
return self

@property
def is_null(self) -> bool:
return False

def __add__(self, other):
if isinstance(other, str):
other = _parse_tags(other)

if isinstance(other, ASTNode):
return Concatenate((self, other))

return NotImplemented

def __radd__(self, other):
if isinstance(other, str):
other = _parse_tags(other)

if isinstance(other, ASTNode):
return Concatenate((other, self))

return NotImplemented

@classmethod
def null(cls) -> "ASTNode":
return Concatenate(())

@dataclass(frozen=True)
class Concatenate(ASTNode):
nodes: tuple[ASTNode, ...]

async def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]:
buffer: Optional[GrammarNode] = None
for child in self:
assert not isinstance(child, Concatenate) # iter should be flat
if isinstance(child, GrammarNode):
if buffer is None:
buffer = child
else:
buffer = buffer + child
else:
if buffer is not None:
async for attr in interpreter.run(buffer, **kwargs):
yield attr
buffer = None
async for attr in interpreter.run(child, **kwargs):
yield attr
if buffer is not None:
async for attr in interpreter.run(buffer, **kwargs):
yield attr

def __iter__(self) -> Iterator[ASTNode]:
for node in self.nodes:
if isinstance(node, Concatenate):
yield from node
else:
yield node

@dataclass
class RoleStart(ASTNode):
role: str

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
return interpreter._role_start(self, **kwargs)
async def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]:
yield RoleOpenerInput(name=self.role)
async for output_attr in interpreter._role_start(self, **kwargs):
yield output_attr


@dataclass
class RoleEnd(ASTNode):
role: str

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
return interpreter._role_end(self, **kwargs)
async def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]:
yield RoleCloserInput(name=self.role)
async for output_attr in interpreter._role_end(self, **kwargs):
yield output_attr


@dataclass
class CaptureStart(ASTNode):
name: str

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.capture_start(self, **kwargs)

@dataclass
class CaptureEnd(ASTNode):
name: str

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.capture_end(self, **kwargs)

@dataclass
class ImageBlob(ASTNode):
data: str

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.image_blob(self, **kwargs)


@dataclass
class ImageUrl(ASTNode):
url: str

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.image_url(self, **kwargs)


@dataclass
class AudioBlob(ASTNode):
data: str

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.audio_blob(self, **kwargs)


class GenAudio(ASTNode):
def __init__(self, kwargs: dict[str, Any]):
self.kwargs = kwargs

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.gen_audio(self, **kwargs)


@@ -317,15 +441,15 @@ class LiteralNode(GrammarNode):
def is_null(self) -> bool:
return self.value == ""

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.text(self, **kwargs)


@dataclass(frozen=True)
class RegexNode(GrammarNode):
regex: Optional[str]

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.regex(self, **kwargs)


@@ -353,7 +477,7 @@ def simplify(self) -> "GrammarNode":
def children(self) -> Sequence["GrammarNode"]:
return self.alternatives

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.select(self, **kwargs)


@@ -376,7 +500,7 @@ def simplify(self) -> "GrammarNode":
def children(self) -> Sequence["GrammarNode"]:
return self.nodes

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.join(self, **kwargs)


@@ -402,7 +526,7 @@ def children(self) -> Sequence["GrammarNode"]:
def simplify(self) -> GrammarNode:
return RepeatNode(self.node.simplify(), self.min, self.max)

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.repeat(self, **kwargs)


@@ -415,7 +539,7 @@ def is_terminal(self) -> bool:
# this can be used as part of bigger regexes
return True

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.substring(self, **kwargs)


@@ -466,7 +590,7 @@ def is_terminal(self) -> bool:
def children(self) -> Sequence["GrammarNode"]:
return (self.value,)

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.rule(self, **kwargs)

@dataclass(frozen=True, eq=False)
@@ -485,7 +609,7 @@ def is_terminal(self) -> bool:
# so it should never be terminal.
return False

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
if self.target is None:
raise ValueError("RuleRefNode target not set")
return interpreter.rule(self.target)
@@ -501,23 +625,23 @@ class SubgrammarNode(BaseSubgrammarNode):
body: GrammarNode
skip_regex: Optional[str] = None

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.subgrammar(self, **kwargs)


@dataclass(frozen=True, eq=False)
class JsonNode(BaseSubgrammarNode):
schema: dict[str, Any]

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.json(self, **kwargs)


@dataclass(frozen=True, eq=False)
class LarkNode(BaseSubgrammarNode):
lark_grammar: str

def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]:
def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]:
return interpreter.lark(self, **kwargs)

class LarkSerializer:
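
For orientation (not part of the diff): with `_run` now returning `AsyncIterable[NodeAttr]`, the whole node tree is consumed as one async generator chain. A hypothetical consumer is sketched below; `collect_attrs` and `my_interpreter` are illustrative names, assuming some concrete `Interpreter[S]` instance is available.

import asyncio  # used by the commented example below


async def collect_attrs(node, interpreter):
    """Drive a node's _run async generator and collect every attr it yields."""
    attrs = []
    async for attr in node._run(interpreter):
        attrs.append(attr)
    return attrs


# Hypothetical usage, given a concrete interpreter instance `my_interpreter`:
# attrs = asyncio.run(collect_attrs(LiteralNode("hello"), my_interpreter))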
11 changes: 6 additions & 5 deletions guidance/_bg/__init__.py
@@ -6,15 +6,16 @@
import asyncio
import threading
from asyncio import AbstractEventLoop, Future, Task
from typing import Tuple, Coroutine
from typing import Coroutine, Any, TypeVar

T = TypeVar("T")

def _start_asyncio_loop(loop: AbstractEventLoop):
asyncio.set_event_loop(loop)
loop.run_forever()


def _asyncio_background_thread() -> Tuple[threading.Thread, AbstractEventLoop]:
def _asyncio_background_thread() -> tuple[threading.Thread, AbstractEventLoop]:
loop = asyncio.new_event_loop()
thread = threading.Thread(target=_start_asyncio_loop, args=(loop,))
thread.daemon = True
@@ -29,7 +30,7 @@ def __init__(self):
self._loop = None
self._thread = None

def _thread_and_loop(self) -> Tuple[threading.Thread, AbstractEventLoop]:
def _thread_and_loop(self) -> tuple[threading.Thread, AbstractEventLoop]:
if self._loop is None:
self._thread, self._loop = _asyncio_background_thread()
self._thread.start()
@@ -41,7 +42,7 @@ def call_soon_threadsafe(self, cb, *args, context = None):
_, loop = self._thread_and_loop()
return loop.call_soon_threadsafe(cb, *args, context=context)

def run_async_coroutine(self, coroutine: Coroutine) -> Future:
def run_async_coroutine(self, coroutine: Coroutine[Any, Any, T]) -> Future[T]:
""" Runs an asynchronous coroutine in the visual thread.

Args:
@@ -55,7 +56,7 @@ def run_async_coroutine(self, coroutine: Coroutine) -> Future:
return future

@staticmethod
async def async_task(coroutine: Coroutine) -> Task:
async def async_task(coroutine: Coroutine[Any, Any, T]) -> Task[T]:
""" Creates an asyncio task from coroutine.

Args:
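
The `_bg` helpers above keep a dedicated asyncio loop alive on a daemon thread and submit coroutines to it with `asyncio.run_coroutine_threadsafe`. A self-contained sketch of that pattern follows; the diff truncates the surrounding class, so the names here are illustrative rather than the module's actual ones.

import asyncio
import threading


def start_background_loop() -> tuple[threading.Thread, asyncio.AbstractEventLoop]:
    """Start an event loop on a daemon thread and return both handles."""
    loop = asyncio.new_event_loop()

    def _run() -> None:
        asyncio.set_event_loop(loop)
        loop.run_forever()

    thread = threading.Thread(target=_run, daemon=True)
    thread.start()
    return thread, loop


async def answer() -> int:
    await asyncio.sleep(0.1)
    return 42


_thread, _loop = start_background_loop()
# run_coroutine_threadsafe returns a concurrent.futures.Future that the calling
# (synchronous) thread can block on, even if it has its own event loop elsewhere.
future = asyncio.run_coroutine_threadsafe(answer(), _loop)
print(future.result())  # 42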