Skip to content

Commit c1e0a71

Browse files
committed
Merge branch 'main' into anonymous
2 parents 79ff622 + 12ff796 commit c1e0a71

File tree

11 files changed

+1656
-294
lines changed

11 files changed

+1656
-294
lines changed

pdl-live-react/src/pdl_ast.d.ts

Lines changed: 638 additions & 48 deletions
Large diffs are not rendered by default.

src/pdl/pdl-schema.json

Lines changed: 687 additions & 193 deletions
Large diffs are not rendered by default.

src/pdl/pdl_ast.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
BeforeValidator,
2020
ConfigDict,
2121
Field,
22+
Json,
2223
RootModel,
2324
TypeAdapter,
2425
)
@@ -228,23 +229,32 @@ class OptionalPdlType(PdlType):
228229
optional: "PdlTypeType"
229230

230231

232+
class JsonSchemaTypePdlType(PdlType):
233+
"""Json Schema type"""
234+
235+
model_config = ConfigDict(extra="allow")
236+
type: str | list[str]
237+
238+
231239
class ObjPdlType(PdlType):
232-
"""Optional type."""
240+
"""Object type."""
233241

234242
obj: Optional[dict[str, "PdlTypeType"]]
235243

236244

237245
PdlTypeType = TypeAliasType(
238246
"PdlTypeType",
239247
Annotated[
240-
"Union[BasePdlType," # pyright: ignore
248+
"Union[None," # pyright: ignore
249+
" BasePdlType,"
241250
" EnumPdlType,"
242251
" StrPdlType,"
243252
" FloatPdlType,"
244253
" IntPdlType,"
245254
" ListPdlType,"
246255
" list['PdlTypeType'],"
247256
" OptionalPdlType,"
257+
" JsonSchemaTypePdlType,"
248258
" ObjPdlType,"
249259
" dict[str, 'PdlTypeType']]",
250260
Field(union_mode="left_to_right"),
@@ -261,7 +271,7 @@ class Parser(BaseModel):
261271
description: Optional[str] = None
262272
"""Documentation associated to the parser.
263273
"""
264-
spec: Optional[PdlTypeType] = None
274+
spec: PdlTypeType = None
265275
"""Expected type of the parsed value.
266276
"""
267277

@@ -340,7 +350,7 @@ class Block(BaseModel):
340350
description: Optional[str] = None
341351
"""Documentation associated to the block.
342352
"""
343-
spec: Optional[PdlTypeType] = None
353+
spec: PdlTypeType = None
344354
"""Type specification of the result of the block.
345355
"""
346356
defs: dict[str, "BlockType"] = {}
@@ -360,6 +370,12 @@ class Block(BaseModel):
360370
fallback: Optional["BlockType"] = None
361371
"""Block to execute in case of error.
362372
"""
373+
retry: Optional[int] = None
374+
"""The maximum number of times to retry when an error occurs within a block.
375+
"""
376+
trace_error_on_retry: Optional[bool] | str = None
377+
"""Whether to add the errors while retrying to the trace. Set this to true to use retry feature for multiple LLM trials.
378+
"""
363379
role: RoleType = None
364380
"""Role associated to the block and sub-blocks.
365381
Typical roles are `system`, `user`, and `assistant`,
@@ -402,8 +418,12 @@ class FunctionBlock(LeafBlock):
402418
"""Functions parameters with their types.
403419
"""
404420
returns: "BlockType" = Field(..., alias="return")
405-
"""Body of the function
421+
"""Body of the function.
406422
"""
423+
signature: Optional[Json] = None
424+
"""Function signature computed from the function definition.
425+
"""
426+
407427
# Field for internal use
408428
pdl__scope: SkipJsonSchema[Optional[ScopeType]] = Field(default=None, repr=False)
409429

src/pdl/pdl_dumper.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
IntPdlType,
3333
JoinText,
3434
JoinType,
35+
JsonSchemaTypePdlType,
3536
LastOfBlock,
3637
ListPdlType,
3738
ListPdlTypeConstraints,
@@ -124,6 +125,12 @@ def block_to_dict( # noqa: C901
124125
d["defs"] = {
125126
x: block_to_dict(b, json_compatible) for x, b in block.defs.items()
126127
}
128+
if block.retry is not None:
129+
d["retry"] = expr_to_dict(block.retry, json_compatible)
130+
if block.trace_error_on_retry is not None:
131+
d["trace_error_on_retry"] = expr_to_dict(
132+
block.trace_error_on_retry, json_compatible
133+
)
127134
if isinstance(block, StructuredBlock):
128135
d["context"] = block.context
129136

@@ -310,8 +317,10 @@ def expr_to_dict(expr: ExpressionType, json_compatible: bool):
310317

311318

312319
def type_to_dict(t: PdlTypeType):
313-
d: str | list | dict
320+
d: None | str | list | dict
314321
match t:
322+
case None:
323+
d = None
315324
case "null" | "bool" | "str" | "float" | "int" | "list" | "obj":
316325
d = t
317326
case EnumPdlType():
@@ -377,6 +386,8 @@ def type_to_dict(t: PdlTypeType):
377386
assert False, "list must have only one element"
378387
case OptionalPdlType():
379388
d = {"optional": type_to_dict(t.optional)}
389+
case JsonSchemaTypePdlType():
390+
d = t.model_dump()
380391
case ObjPdlType():
381392
if t.obj is None:
382393
d = "obj"

src/pdl/pdl_interpreter.py

Lines changed: 118 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
# TODO: temporarily disabling warnings to mute a pydantic warning from liteLLM
1212
import warnings
13+
from functools import partial
1314
from os import getenv
1415

1516
warnings.filterwarnings("ignore", "Valid config keys have changed in V2")
@@ -99,6 +100,7 @@
99100
from .pdl_parser import PDLParseError, parse_file, parse_str # noqa: E402
100101
from .pdl_python_repl import PythonREPL # noqa: E402
101102
from .pdl_scheduler import yield_background, yield_result # noqa: E402
103+
from .pdl_schema_utils import get_json_schema # noqa: E402
102104
from .pdl_schema_validator import type_check_args, type_check_spec # noqa: E402
103105
from .pdl_utils import ( # noqa: E402
104106
GeneratorWrapper,
@@ -341,6 +343,34 @@ def identity(result):
341343
return identity
342344

343345

346+
def set_error_to_scope_for_retry(
347+
scope: ScopeType, error, block_id: Optional[str] = ""
348+
) -> ScopeType:
349+
repeating_same_error = False
350+
pdl_context: Optional[LazyMessages] = scope.get("pdl_context")
351+
if pdl_context is None:
352+
return scope
353+
if pdl_context and isinstance(pdl_context, list):
354+
last_msg = pdl_context[-1]
355+
last_error = last_msg["content"]
356+
if last_error.endswith(error):
357+
repeating_same_error = True
358+
if repeating_same_error:
359+
error = "The previous error occurs multiple times."
360+
err_msg = {
361+
"role": "assistant",
362+
"content": error,
363+
"defsite": block_id,
364+
}
365+
scope = scope | {
366+
"pdl_context": lazy_messages_concat(
367+
pdl_context,
368+
PdlList([err_msg]),
369+
)
370+
}
371+
return scope
372+
373+
344374
def process_advanced_block(
345375
state: InterpreterState,
346376
scope: ScopeType,
@@ -361,52 +391,85 @@ def process_advanced_block(
361391
state = state.with_yield_background(
362392
state.yield_background and context_in_contribute(block)
363393
)
364-
try:
365-
result, background, new_scope, trace = process_block_body(
366-
state, scope, block, loc
367-
)
368-
result = lazy_apply(id_with_set_first_use_nanos(block.pdl__timing), result)
369-
background = lazy_apply(
370-
id_with_set_first_use_nanos(block.pdl__timing), background
371-
)
372-
trace = trace.model_copy(update={"pdl__result": result})
373-
if block.parser is not None:
374-
parser = block.parser
375-
result = lazy_apply(lambda r: parse_result(parser, r), result)
376-
if init_state.yield_result and ContributeTarget.RESULT:
377-
yield_result(result, block.kind)
378-
if block.spec is not None and not isinstance(block, FunctionBlock):
379-
result = lazy_apply(
380-
lambda r: result_with_type_checking(
381-
r, block.spec, "Type errors during spec checking:", loc, trace
382-
),
383-
result,
394+
395+
# Bind result variables here with empty values
396+
result: PdlLazy[Any] = PdlConst(None)
397+
background: LazyMessages = PdlList([{}])
398+
new_scope: ScopeType = PdlDict({})
399+
trace: AdvancedBlockType = EmptyBlock()
400+
401+
max_retry = block.retry if block.retry else 0
402+
trial_total = max_retry + 1
403+
for trial_idx in range(trial_total):
404+
try:
405+
result, background, new_scope, trace = process_block_body(
406+
state, scope, block, loc
384407
)
385-
if block.fallback is not None:
386-
result.result()
387-
except Exception as exc:
388-
if block.fallback is None:
389-
raise exc from exc
390-
(
391-
result,
392-
background,
393-
new_scope,
394-
trace,
395-
) = process_block_of(
396-
block,
397-
"fallback",
398-
state,
399-
scope,
400-
loc=loc,
401-
)
402-
if block.spec is not None and not isinstance(block, FunctionBlock):
403-
loc = append(loc, "fallback")
404-
result = lazy_apply(
405-
lambda r: result_with_type_checking(
406-
r, block.spec, "Type errors during spec checking:", loc, trace
407-
),
408+
result = lazy_apply(id_with_set_first_use_nanos(block.pdl__timing), result)
409+
background = lazy_apply(
410+
id_with_set_first_use_nanos(block.pdl__timing), background
411+
)
412+
trace = trace.model_copy(update={"pdl__result": result})
413+
if block.parser is not None:
414+
# Use partial to create a function with fixed arguments
415+
parser_func = partial(parse_result, block.parser)
416+
result = lazy_apply(parser_func, result)
417+
if init_state.yield_result and ContributeTarget.RESULT:
418+
yield_result(result, block.kind)
419+
if block.spec is not None and not isinstance(block, FunctionBlock):
420+
# Use partial to create a function with fixed arguments
421+
checker = partial(
422+
result_with_type_checking,
423+
spec=block.spec,
424+
msg="Type errors during spec checking:",
425+
loc=loc,
426+
trace=trace,
427+
)
428+
result = lazy_apply(checker, result)
429+
if block.fallback is not None:
430+
result.result()
431+
break
432+
except Exception as exc:
433+
err_msg = exc.args[0]
434+
do_retry = (
435+
block.retry
436+
and trial_idx + 1 < trial_total
437+
and "Keyboard Interrupt" not in err_msg
438+
)
439+
if block.fallback is None and not do_retry:
440+
raise exc from exc
441+
if do_retry:
442+
error = f"An error occurred in a PDL block. Error details: {err_msg}"
443+
print(
444+
f"\n\033[0;31m[Retry {trial_idx+1}/{max_retry}] {error}\033[0m\n",
445+
file=sys.stderr,
446+
)
447+
if block.trace_error_on_retry:
448+
scope = set_error_to_scope_for_retry(scope, error, block.pdl__id)
449+
continue
450+
(
408451
result,
452+
background,
453+
new_scope,
454+
trace,
455+
) = process_block_of(
456+
block,
457+
"fallback",
458+
state,
459+
scope,
460+
loc=loc,
409461
)
462+
if block.spec is not None and not isinstance(block, FunctionBlock):
463+
loc = append(loc, "fallback")
464+
# Use partial to create a function with fixed arguments
465+
checker = partial(
466+
result_with_type_checking,
467+
spec=block.spec,
468+
msg="Type errors during spec checking:",
469+
loc=loc,
470+
trace=trace,
471+
)
472+
result = lazy_apply(checker, result)
410473
if block.def_ is not None:
411474
var = block.def_
412475
new_scope = new_scope | PdlDict({var: result})
@@ -832,6 +895,16 @@ def process_block_body(
832895
if block.def_ is not None:
833896
scope = scope | {block.def_: closure}
834897
closure.pdl__scope = scope
898+
signature: dict[str, Any] = {"type": "function"}
899+
if block.def_ is not None:
900+
signature["name"] = block.def_
901+
if block.description is not None:
902+
signature["description"] = block.description
903+
if block.function is not None:
904+
signature["parameters"] = get_json_schema(block.function, False) or {}
905+
else:
906+
signature["parameters"] = {}
907+
closure.signature = signature
835908
result = PdlConst(closure)
836909
background = PdlList([])
837910
trace = closure.model_copy(update={})
@@ -914,6 +987,8 @@ def process_defs(
914987
state = state.with_iter(idx)
915988
state = state.with_yield_result(False)
916989
state = state.with_yield_background(False)
990+
if isinstance(block, FunctionBlock) and block.def_ is None:
991+
block = block.model_copy(update={"def_": x})
917992
result, _, _, block_trace = process_block(state, scope, block, newloc)
918993
scope = scope | PdlDict({x: result})
919994
defs_trace[x] = block_trace

src/pdl/pdl_llms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def generate_text_stream(
183183

184184

185185
def set_structured_decoding_parameters(
186-
spec: Optional[PdlTypeType],
186+
spec: PdlTypeType,
187187
parameters: Optional[dict[str, Any]],
188188
) -> dict[str, Any]:
189189
if parameters is None:

src/pdl/pdl_schema_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
EnumPdlType,
66
FloatPdlType,
77
IntPdlType,
8+
JsonSchemaTypePdlType,
89
ListPdlType,
910
ListPdlTypeConstraints,
1011
ObjPdlType,
@@ -43,7 +44,7 @@ def convert_to_json_type(a_type):
4344

4445

4546
def pdltype_to_jsonschema(
46-
pdl_type: Optional[PdlTypeType], additional_properties: bool
47+
pdl_type: PdlTypeType, additional_properties: bool
4748
) -> dict[str, Any]:
4849
schema: dict[str, Any]
4950
match pdl_type:
@@ -111,6 +112,12 @@ def pdltype_to_jsonschema(
111112
case OptionalPdlType(optional=t):
112113
t_schema = pdltype_to_jsonschema(t, additional_properties)
113114
schema = {"anyOf": [t_schema, "null"]}
115+
case JsonSchemaTypePdlType(type=t):
116+
if pdl_type.__pydantic_extra__ is None:
117+
extra = {}
118+
else:
119+
extra = pdl_type.__pydantic_extra__
120+
schema = {"type": t, **extra}
114121
case ObjPdlType(obj=pdl_props):
115122
if pdl_props is None:
116123
schema = {"type": "object"}

src/pdl/pdl_schema_validator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99

1010
def type_check_args(
11-
args: Optional[dict[str, Any]], params: Optional[dict[str, Any]], loc
11+
args: Optional[dict[str, Any]],
12+
params: Optional[dict[str, PdlTypeType]],
13+
loc,
1214
) -> list[str]:
1315
if (args == {} or args is None) and (params is None or params == {}):
1416
return []
@@ -35,7 +37,7 @@ def type_check_args(
3537
return type_check(args_copy, schema, loc)
3638

3739

38-
def type_check_spec(result: Any, spec: Optional[PdlTypeType], loc) -> list[str]:
40+
def type_check_spec(result: Any, spec: PdlTypeType, loc) -> list[str]:
3941
schema = pdltype_to_jsonschema(spec, False)
4042
if schema is None:
4143
return ["Error obtaining a valid schema from spec"]

0 commit comments

Comments
 (0)