Skip to content

Commit 642ce44

Browse files
authored
feat: make PDL functions callable as Python and jinja functions (#1070)
Signed-off-by: Louis Mandel <[email protected]>
1 parent 49f86a3 commit 642ce44

File tree

5 files changed

+213
-29
lines changed

5 files changed

+213
-29
lines changed

src/pdl/pdl_ast.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
RootModel,
2525
TypeAdapter,
2626
)
27-
from pydantic.json_schema import SkipJsonSchema
2827
from typing_extensions import TypeAliasType
2928

3029
from .pdl_context import PDLContext
@@ -319,6 +318,7 @@ class Block(BaseModel):
319318
extra="forbid",
320319
use_attribute_docstrings=True,
321320
arbitrary_types_allowed=True,
321+
validate_by_name=True,
322322
)
323323

324324
description: Optional[str] = None
@@ -398,9 +398,6 @@ class FunctionBlock(LeafBlock):
398398
"""Function signature computed from the function definition.
399399
"""
400400

401-
# Field for internal use
402-
pdl__scope: SkipJsonSchema[Optional[ScopeType]] = Field(default=None, repr=False)
403-
404401

405402
class CallBlock(LeafBlock):
406403
"""Calling a function."""

src/pdl/pdl_interpreter.py

Lines changed: 81 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
warnings.filterwarnings("ignore", "Valid config keys have changed in V2")
1818

1919
from pathlib import Path # noqa: E402
20-
from typing import Any, Generator, Optional, Sequence, TypeVar # noqa: E402
20+
from typing import Any, Generator, Generic, Optional, Sequence, TypeVar # noqa: E402
2121

2222
import httpx # noqa: E402
2323
import json_repair # noqa: E402
@@ -33,6 +33,7 @@
3333
from jinja2.nodes import TemplateData # noqa: E402
3434
from jinja2.runtime import Undefined # noqa: E402
3535
from pydantic import BaseModel, ConfigDict, Field # noqa: E402
36+
from pydantic.json_schema import SkipJsonSchema # noqa: E402
3637

3738
from .pdl_ast import ( # noqa: E402
3839
AdvancedBlockType,
@@ -131,19 +132,39 @@
131132
empty_scope: ScopeType = PdlDict({"pdl_context": DependentContext([])})
132133

133134

135+
RefT = TypeVar("RefT")
136+
137+
138+
class Ref(Generic[RefT]):
139+
def __init__(self, ref: RefT):
140+
self.ref = ref
141+
142+
134143
class InterpreterState(BaseModel):
135144
model_config = ConfigDict(arbitrary_types_allowed=True)
136145

137146
yield_result: bool = False
147+
"""Stream the result on the standard output as soon as possible."""
138148
yield_background: bool = False
149+
"""Stream the toplevel pdl_context on the standard output as soon as possible."""
139150
batch: int = 1
140-
# batch=0: streaming
141-
# batch=1: call to generate with `input`
151+
"""
152+
Stream the output of the LLM
153+
- batch=0: streaming
154+
- batch=1: call to generate with `input`
155+
"""
142156
role: RoleType = "user"
157+
"""Current role to add messages in the context."""
143158
cwd: Path = Path.cwd()
144-
# background_tasks = {}
159+
"""Current working directory."""
145160
id_stack: list[str] = []
161+
"""Id generator for the UI."""
162+
163+
# The following are shared variable that should be modified by side effects
146164
event_loop: AbstractEventLoop = Field(default_factory=create_event_loop_thread)
165+
"""Event loop to schedule LLM calls."""
166+
current_pdl_context: Ref[LazyMessages] = Ref(DependentContext([]))
167+
"""Current value of the context set at the beginning of the execution of the block."""
147168

148169
def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState":
149170
return self.model_copy(update={"yield_result": b})
@@ -168,6 +189,19 @@ def with_pop(self: "InterpreterState") -> "InterpreterState":
168189
return self.model_copy(update={"id_stack": stack})
169190

170191

192+
class ClosureBlock(FunctionBlock):
193+
pdl__scope: SkipJsonSchema[Optional[ScopeType]] = Field(repr=False)
194+
pdl__state: SkipJsonSchema[InterpreterState] = Field(repr=False)
195+
196+
def __call__(self, **kwds):
197+
state = self.pdl__state.with_yield_result(False).with_yield_background(False)
198+
current_context = state.current_pdl_context.ref
199+
result, _, _ = execute_call(
200+
state, current_context, self, kwds, empty_block_location
201+
)
202+
return result
203+
204+
171205
def generate(
172206
pdl_file: str | Path,
173207
state: Optional[InterpreterState],
@@ -246,6 +280,7 @@ def process_block(
246280
background: LazyMessages
247281
trace: BlockType
248282
try:
283+
state.current_pdl_context.ref = scope["pdl_context"] # type: ignore
249284
if not isinstance(block, Block):
250285
start = time.time_ns()
251286
try:
@@ -436,7 +471,7 @@ def process_advanced_block(
436471
result.result()
437472
break
438473
except Exception as exc:
439-
err_msg = exc.args[0]
474+
err_msg = traceback.format_exc()
440475
do_retry = (
441476
block.retry
442477
and trial_idx + 1 < trial_total
@@ -915,7 +950,23 @@ def process_block_body(
915950
result, background, scope, trace = process_import(state, scope, block, loc)
916951

917952
case FunctionBlock():
918-
closure = block.model_copy()
953+
closure = ClosureBlock( # pyright: ignore
954+
description=block.description,
955+
spec=block.spec,
956+
defs=block.defs,
957+
def_=block.def_, # pyright: ignore
958+
contribute=block.contribute,
959+
parser=block.parser,
960+
fallback=block.fallback,
961+
retry=block.retry,
962+
trace_error_on_retry=block.trace_error_on_retry,
963+
role=block.role,
964+
function=block.function,
965+
return_=block.return_, # pyright: ignore
966+
pdl__location=loc,
967+
pdl__scope=None,
968+
pdl__state=state,
969+
)
919970
if block.def_ is not None:
920971
scope = scope | {block.def_: closure}
921972
closure.pdl__scope = scope
@@ -1872,7 +1923,7 @@ def process_call(
18721923
background: LazyMessages = DependentContext([])
18731924
args, block = process_expr_of(block, "args", scope, loc)
18741925
closure, _ = process_expr_of(block, "call", scope, loc)
1875-
if not isinstance(closure, FunctionBlock):
1926+
if not isinstance(closure, ClosureBlock):
18761927
msg = f"Type error: {block.call} is of type {type(closure)} but should be a function."
18771928
if isinstance(closure, str) and isinstance(scope.get(closure), FunctionBlock):
18781929
msg += " You might want to call `${ " + str(block.call) + " }`."
@@ -1890,12 +1941,28 @@ def process_call(
18901941
loc=args_loc,
18911942
trace=block.model_copy(),
18921943
)
1944+
current_context = scope.data["pdl_context"]
1945+
try:
1946+
result, background, call_trace = execute_call(
1947+
state, current_context, closure, args, loc
1948+
)
1949+
except PDLRuntimeError as exc:
1950+
raise PDLRuntimeError(
1951+
exc.message,
1952+
loc=exc.loc or closure.pdl__location,
1953+
trace=block.model_copy(update={"pdl__trace": exc.pdl__trace}),
1954+
) from exc
1955+
trace = block.model_copy(update={"pdl__trace": call_trace})
1956+
return result, background, scope, trace
1957+
1958+
1959+
def execute_call(state, current_context, closure, args, loc):
18931960
if "pdl_context" in args:
1894-
args["pdl_context"] = deserialize(args["pdl_context"])
1961+
args = args | {"pdl_context": deserialize(args["pdl_context"])}
18951962
f_body = closure.return_
18961963
f_scope = (
18971964
(closure.pdl__scope or PdlDict({}))
1898-
| PdlDict({"pdl_context": scope.data["pdl_context"]})
1965+
| PdlDict({"pdl_context": current_context})
18991966
| PdlDict((args or {}))
19001967
)
19011968
if closure.pdl__location is not None:
@@ -1906,27 +1973,19 @@ def process_call(
19061973
)
19071974
else:
19081975
fun_loc = empty_block_location
1909-
try:
1910-
result, background, _, f_trace = process_block(state, f_scope, f_body, fun_loc)
1911-
except PDLRuntimeError as exc:
1912-
raise PDLRuntimeError(
1913-
exc.message,
1914-
loc=exc.loc or fun_loc,
1915-
trace=block.model_copy(update={"pdl__trace": exc.pdl__trace}),
1916-
) from exc
1917-
trace = block.model_copy(update={"pdl__trace": f_trace})
1976+
result, background, _, f_trace = process_block(state, f_scope, f_body, fun_loc)
19181977
if closure.spec is not None:
19191978
result = lazy_apply(
19201979
lambda r: result_with_type_checking(
19211980
r,
19221981
closure.spec,
1923-
f"Type errors in result of function call to {block.call}:",
1924-
loc,
1925-
trace,
1982+
f"Type errors in result of the function{' ' + closure.signature.get('name', '') if closure.signature is not None else ''}:",
1983+
fun_loc,
1984+
f_trace,
19261985
),
19271986
result,
19281987
)
1929-
return result, background, scope, trace
1988+
return result, background, f_trace
19301989

19311990

19321991
def process_input(

tests/test_function.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pdl.pdl import exec_dict, exec_file
1+
from pdl.pdl import exec_dict, exec_file, exec_str
22

33
hello_def = {
44
"def": "hello",
@@ -126,3 +126,111 @@ def test_call_expression_args():
126126
result
127127
== "FN::get_current_stock:: 'Simple call!'\n{'product_name': 'from_object'}\nFN::get_current_stock:: 'from_object'\n"
128128
)
129+
130+
131+
def test_call_from_code_01():
132+
prog = """
133+
defs:
134+
f:
135+
function:
136+
x:
137+
y:
138+
return:
139+
${x + 1}
140+
array:
141+
- call: ${f}
142+
args:
143+
x: 1
144+
y: 1
145+
- ${ f(x=1, y=2) }
146+
- lang: python
147+
code:
148+
result = f(x=1, y=2)
149+
"""
150+
result = exec_str(prog)
151+
assert result == [2, 2, 2]
152+
153+
154+
def test_call_from_code_02():
155+
prog = """
156+
defs:
157+
f:
158+
function:
159+
return:
160+
${pdl_context}
161+
lastOf:
162+
- Hello
163+
- context: independent
164+
array:
165+
- call: ${f}
166+
- ${ f() }
167+
- lang: python
168+
code:
169+
result = f()
170+
"""
171+
result = exec_str(prog)
172+
assert [ctx.serialize("litellm") for ctx in result] == [
173+
[{"role": "user", "content": "Hello", "pdl__defsite": "lastOf.0"}],
174+
[{"role": "user", "content": "Hello", "pdl__defsite": "lastOf.0"}],
175+
[{"role": "user", "content": "Hello", "pdl__defsite": "lastOf.0"}],
176+
]
177+
178+
179+
def test_call_from_code_03():
180+
prog = """
181+
defs:
182+
f:
183+
function:
184+
return:
185+
${pdl_context}
186+
lastOf:
187+
- Hello
188+
- context: independent
189+
array:
190+
- call: ${f}
191+
args:
192+
pdl_context: []
193+
- ${ f(pdl_context=[]) }
194+
- lang: python
195+
code:
196+
result = f(pdl_context=[])
197+
"""
198+
result = exec_str(prog)
199+
assert [ctx.serialize("litellm") for ctx in result] == [
200+
[],
201+
[],
202+
[],
203+
]
204+
205+
206+
def test_call_from_code_04():
207+
prog = """
208+
defs:
209+
f:
210+
function:
211+
return:
212+
lastOf:
213+
- How are you?
214+
- Bye
215+
lastOf:
216+
- Hello
217+
- context: independent
218+
array:
219+
- text:
220+
- call: ${f}
221+
- ${pdl_context}
222+
- text:
223+
- ${f()}
224+
- ${pdl_context}
225+
- text:
226+
- lang: python
227+
code:
228+
result = f()
229+
- ${pdl_context}
230+
"""
231+
result = exec_str(prog)
232+
assert result == [
233+
"Bye[{'role': 'user', 'content': 'Hello', 'pdl__defsite': 'lastOf.0'},{'role': 'user', 'content': 'How are you?', 'pdl__defsite': 'lastOf.1.array.0.text.0.call.lastOf.0'},{'role': 'user', 'content': 'Bye', 'pdl__defsite': 'lastOf.1.array.0.text.0.call.lastOf.1'}]",
234+
"Bye[{'role': 'user', 'content': 'Hello', 'pdl__defsite': 'lastOf.0'},{'role': 'user', 'content': 'Bye', 'pdl__defsite': 'lastOf.1.array.1.text.0'}]",
235+
"Bye[{'role': 'user', 'content': 'Hello', 'pdl__defsite': 'lastOf.0'},{'role': 'user', 'content': 'Bye', 'pdl__defsite': 'lastOf.1.array.2.text.0.code'}]",
236+
]

tests/test_line_table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ def test_line13(capsys: CaptureFixture[str]):
168168
"file": "tests/data/line/hello14.pdl",
169169
"errors": [
170170
"",
171-
"tests/data/line/hello14.pdl:25 - Type errors in result of function call to ${ translate }:",
172-
"tests/data/line/hello14.pdl:25 - Bonjour le monde! should be of type <class 'int'>",
171+
"tests/data/line/hello14.pdl:16 - Type errors in result of the function translate:",
172+
"tests/data/line/hello14.pdl:16 - Bonjour le monde! should be of type <class 'int'>",
173173
],
174174
}
175175

tests/test_type_checking.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,3 +605,23 @@ def test_deprecated(capsys: pytest.CaptureFixture[str]):
605605
do_test_stderr(
606606
capsys, prog, ["Deprecated type syntax: use integer instead of int.", ""]
607607
)
608+
609+
610+
def test_function_call_jinja_19():
611+
prog = """
612+
defs:
613+
f:
614+
function:
615+
x:
616+
y:
617+
return:
618+
${x + 1}
619+
array:
620+
- call: ${f}
621+
args:
622+
x: 1
623+
y: 1
624+
- ${ f(1, 2) }
625+
"""
626+
with pytest.raises(PDLRuntimeError):
627+
exec_str(prog)

0 commit comments

Comments
 (0)