17
17
warnings .filterwarnings ("ignore" , "Valid config keys have changed in V2" )
18
18
19
19
from pathlib import Path # noqa: E402
20
- from typing import Any , Generator , Optional , Sequence , TypeVar # noqa: E402
20
+ from typing import Any , Generator , Generic , Optional , Sequence , TypeVar # noqa: E402
21
21
22
22
import httpx # noqa: E402
23
23
import json_repair # noqa: E402
33
33
from jinja2 .nodes import TemplateData # noqa: E402
34
34
from jinja2 .runtime import Undefined # noqa: E402
35
35
from pydantic import BaseModel , ConfigDict , Field # noqa: E402
36
+ from pydantic .json_schema import SkipJsonSchema # noqa: E402
36
37
37
38
from .pdl_ast import ( # noqa: E402
38
39
AdvancedBlockType ,
131
132
empty_scope : ScopeType = PdlDict ({"pdl_context" : DependentContext ([])})
132
133
133
134
135
+ RefT = TypeVar ("RefT" )
136
+
137
+
138
+ class Ref (Generic [RefT ]):
139
+ def __init__ (self , ref : RefT ):
140
+ self .ref = ref
141
+
142
+
134
143
class InterpreterState (BaseModel ):
135
144
model_config = ConfigDict (arbitrary_types_allowed = True )
136
145
137
146
yield_result : bool = False
147
+ """Stream the result on the standard output as soon as possible."""
138
148
yield_background : bool = False
149
+ """Stream the toplevel pdl_context on the standard output as soon as possible."""
139
150
batch : int = 1
140
- # batch=0: streaming
141
- # batch=1: call to generate with `input`
151
+ """
152
+ Stream the output of the LLM
153
+ - batch=0: streaming
154
+ - batch=1: call to generate with `input`
155
+ """
142
156
role : RoleType = "user"
157
+ """Current role to add messages in the context."""
143
158
cwd : Path = Path .cwd ()
144
- # background_tasks = {}
159
+ """Current working directory."""
145
160
id_stack : list [str ] = []
161
+ """Id generator for the UI."""
162
+
163
+ # The following are shared variable that should be modified by side effects
146
164
event_loop : AbstractEventLoop = Field (default_factory = create_event_loop_thread )
165
+ """Event loop to schedule LLM calls."""
166
+ current_pdl_context : Ref [LazyMessages ] = Ref (DependentContext ([]))
167
+ """Current value of the context set at the beginning of the execution of the block."""
147
168
148
169
def with_yield_result (self : "InterpreterState" , b : bool ) -> "InterpreterState" :
149
170
return self .model_copy (update = {"yield_result" : b })
@@ -168,6 +189,19 @@ def with_pop(self: "InterpreterState") -> "InterpreterState":
168
189
return self .model_copy (update = {"id_stack" : stack })
169
190
170
191
192
+ class ClosureBlock (FunctionBlock ):
193
+ pdl__scope : SkipJsonSchema [Optional [ScopeType ]] = Field (repr = False )
194
+ pdl__state : SkipJsonSchema [InterpreterState ] = Field (repr = False )
195
+
196
+ def __call__ (self , ** kwds ):
197
+ state = self .pdl__state .with_yield_result (False ).with_yield_background (False )
198
+ current_context = state .current_pdl_context .ref
199
+ result , _ , _ = execute_call (
200
+ state , current_context , self , kwds , empty_block_location
201
+ )
202
+ return result
203
+
204
+
171
205
def generate (
172
206
pdl_file : str | Path ,
173
207
state : Optional [InterpreterState ],
@@ -246,6 +280,7 @@ def process_block(
246
280
background : LazyMessages
247
281
trace : BlockType
248
282
try :
283
+ state .current_pdl_context .ref = scope ["pdl_context" ] # type: ignore
249
284
if not isinstance (block , Block ):
250
285
start = time .time_ns ()
251
286
try :
@@ -436,7 +471,7 @@ def process_advanced_block(
436
471
result .result ()
437
472
break
438
473
except Exception as exc :
439
- err_msg = exc . args [ 0 ]
474
+ err_msg = traceback . format_exc ()
440
475
do_retry = (
441
476
block .retry
442
477
and trial_idx + 1 < trial_total
@@ -915,7 +950,23 @@ def process_block_body(
915
950
result , background , scope , trace = process_import (state , scope , block , loc )
916
951
917
952
case FunctionBlock ():
918
- closure = block .model_copy ()
953
+ closure = ClosureBlock ( # pyright: ignore
954
+ description = block .description ,
955
+ spec = block .spec ,
956
+ defs = block .defs ,
957
+ def_ = block .def_ , # pyright: ignore
958
+ contribute = block .contribute ,
959
+ parser = block .parser ,
960
+ fallback = block .fallback ,
961
+ retry = block .retry ,
962
+ trace_error_on_retry = block .trace_error_on_retry ,
963
+ role = block .role ,
964
+ function = block .function ,
965
+ return_ = block .return_ , # pyright: ignore
966
+ pdl__location = loc ,
967
+ pdl__scope = None ,
968
+ pdl__state = state ,
969
+ )
919
970
if block .def_ is not None :
920
971
scope = scope | {block .def_ : closure }
921
972
closure .pdl__scope = scope
@@ -1872,7 +1923,7 @@ def process_call(
1872
1923
background : LazyMessages = DependentContext ([])
1873
1924
args , block = process_expr_of (block , "args" , scope , loc )
1874
1925
closure , _ = process_expr_of (block , "call" , scope , loc )
1875
- if not isinstance (closure , FunctionBlock ):
1926
+ if not isinstance (closure , ClosureBlock ):
1876
1927
msg = f"Type error: { block .call } is of type { type (closure )} but should be a function."
1877
1928
if isinstance (closure , str ) and isinstance (scope .get (closure ), FunctionBlock ):
1878
1929
msg += " You might want to call `${ " + str (block .call ) + " }`."
@@ -1890,12 +1941,28 @@ def process_call(
1890
1941
loc = args_loc ,
1891
1942
trace = block .model_copy (),
1892
1943
)
1944
+ current_context = scope .data ["pdl_context" ]
1945
+ try :
1946
+ result , background , call_trace = execute_call (
1947
+ state , current_context , closure , args , loc
1948
+ )
1949
+ except PDLRuntimeError as exc :
1950
+ raise PDLRuntimeError (
1951
+ exc .message ,
1952
+ loc = exc .loc or closure .pdl__location ,
1953
+ trace = block .model_copy (update = {"pdl__trace" : exc .pdl__trace }),
1954
+ ) from exc
1955
+ trace = block .model_copy (update = {"pdl__trace" : call_trace })
1956
+ return result , background , scope , trace
1957
+
1958
+
1959
+ def execute_call (state , current_context , closure , args , loc ):
1893
1960
if "pdl_context" in args :
1894
- args [ "pdl_context" ] = deserialize (args ["pdl_context" ])
1961
+ args = args | { "pdl_context" : deserialize (args ["pdl_context" ])}
1895
1962
f_body = closure .return_
1896
1963
f_scope = (
1897
1964
(closure .pdl__scope or PdlDict ({}))
1898
- | PdlDict ({"pdl_context" : scope . data [ "pdl_context" ] })
1965
+ | PdlDict ({"pdl_context" : current_context })
1899
1966
| PdlDict ((args or {}))
1900
1967
)
1901
1968
if closure .pdl__location is not None :
@@ -1906,27 +1973,19 @@ def process_call(
1906
1973
)
1907
1974
else :
1908
1975
fun_loc = empty_block_location
1909
- try :
1910
- result , background , _ , f_trace = process_block (state , f_scope , f_body , fun_loc )
1911
- except PDLRuntimeError as exc :
1912
- raise PDLRuntimeError (
1913
- exc .message ,
1914
- loc = exc .loc or fun_loc ,
1915
- trace = block .model_copy (update = {"pdl__trace" : exc .pdl__trace }),
1916
- ) from exc
1917
- trace = block .model_copy (update = {"pdl__trace" : f_trace })
1976
+ result , background , _ , f_trace = process_block (state , f_scope , f_body , fun_loc )
1918
1977
if closure .spec is not None :
1919
1978
result = lazy_apply (
1920
1979
lambda r : result_with_type_checking (
1921
1980
r ,
1922
1981
closure .spec ,
1923
- f"Type errors in result of function call to { block . call } :" ,
1924
- loc ,
1925
- trace ,
1982
+ f"Type errors in result of the function{ ' ' + closure . signature . get ( 'name' , '' ) if closure . signature is not None else '' } :" ,
1983
+ fun_loc ,
1984
+ f_trace ,
1926
1985
),
1927
1986
result ,
1928
1987
)
1929
- return result , background , scope , trace
1988
+ return result , background , f_trace
1930
1989
1931
1990
1932
1991
def process_input (
0 commit comments