1010
1111# TODO: temporarily disabling warnings to mute a pydantic warning from liteLLM
1212import warnings
13+ from functools import partial
1314from os import getenv
1415
1516warnings .filterwarnings ("ignore" , "Valid config keys have changed in V2" )
99100from .pdl_parser import PDLParseError , parse_file , parse_str # noqa: E402
100101from .pdl_python_repl import PythonREPL # noqa: E402
101102from .pdl_scheduler import yield_background , yield_result # noqa: E402
103+ from .pdl_schema_utils import get_json_schema # noqa: E402
102104from .pdl_schema_validator import type_check_args , type_check_spec # noqa: E402
103105from .pdl_utils import ( # noqa: E402
104106 GeneratorWrapper ,
@@ -341,6 +343,34 @@ def identity(result):
341343 return identity
342344
343345
346+ def set_error_to_scope_for_retry (
347+ scope : ScopeType , error , block_id : Optional [str ] = ""
348+ ) -> ScopeType :
349+ repeating_same_error = False
350+ pdl_context : Optional [LazyMessages ] = scope .get ("pdl_context" )
351+ if pdl_context is None :
352+ return scope
353+ if pdl_context and isinstance (pdl_context , list ):
354+ last_msg = pdl_context [- 1 ]
355+ last_error = last_msg ["content" ]
356+ if last_error .endswith (error ):
357+ repeating_same_error = True
358+ if repeating_same_error :
359+ error = "The previous error occurs multiple times."
360+ err_msg = {
361+ "role" : "assistant" ,
362+ "content" : error ,
363+ "defsite" : block_id ,
364+ }
365+ scope = scope | {
366+ "pdl_context" : lazy_messages_concat (
367+ pdl_context ,
368+ PdlList ([err_msg ]),
369+ )
370+ }
371+ return scope
372+
373+
344374def process_advanced_block (
345375 state : InterpreterState ,
346376 scope : ScopeType ,
@@ -361,52 +391,85 @@ def process_advanced_block(
361391 state = state .with_yield_background (
362392 state .yield_background and context_in_contribute (block )
363393 )
364- try :
365- result , background , new_scope , trace = process_block_body (
366- state , scope , block , loc
367- )
368- result = lazy_apply (id_with_set_first_use_nanos (block .pdl__timing ), result )
369- background = lazy_apply (
370- id_with_set_first_use_nanos (block .pdl__timing ), background
371- )
372- trace = trace .model_copy (update = {"pdl__result" : result })
373- if block .parser is not None :
374- parser = block .parser
375- result = lazy_apply (lambda r : parse_result (parser , r ), result )
376- if init_state .yield_result and ContributeTarget .RESULT :
377- yield_result (result , block .kind )
378- if block .spec is not None and not isinstance (block , FunctionBlock ):
379- result = lazy_apply (
380- lambda r : result_with_type_checking (
381- r , block .spec , "Type errors during spec checking:" , loc , trace
382- ),
383- result ,
394+
395+ # Bind result variables here with empty values
396+ result : PdlLazy [Any ] = PdlConst (None )
397+ background : LazyMessages = PdlList ([{}])
398+ new_scope : ScopeType = PdlDict ({})
399+ trace : AdvancedBlockType = EmptyBlock ()
400+
401+ max_retry = block .retry if block .retry else 0
402+ trial_total = max_retry + 1
403+ for trial_idx in range (trial_total ):
404+ try :
405+ result , background , new_scope , trace = process_block_body (
406+ state , scope , block , loc
384407 )
385- if block .fallback is not None :
386- result .result ()
387- except Exception as exc :
388- if block .fallback is None :
389- raise exc from exc
390- (
391- result ,
392- background ,
393- new_scope ,
394- trace ,
395- ) = process_block_of (
396- block ,
397- "fallback" ,
398- state ,
399- scope ,
400- loc = loc ,
401- )
402- if block .spec is not None and not isinstance (block , FunctionBlock ):
403- loc = append (loc , "fallback" )
404- result = lazy_apply (
405- lambda r : result_with_type_checking (
406- r , block .spec , "Type errors during spec checking:" , loc , trace
407- ),
408+ result = lazy_apply (id_with_set_first_use_nanos (block .pdl__timing ), result )
409+ background = lazy_apply (
410+ id_with_set_first_use_nanos (block .pdl__timing ), background
411+ )
412+ trace = trace .model_copy (update = {"pdl__result" : result })
413+ if block .parser is not None :
414+ # Use partial to create a function with fixed arguments
415+ parser_func = partial (parse_result , block .parser )
416+ result = lazy_apply (parser_func , result )
417+ if init_state .yield_result and ContributeTarget .RESULT :
418+ yield_result (result , block .kind )
419+ if block .spec is not None and not isinstance (block , FunctionBlock ):
420+ # Use partial to create a function with fixed arguments
421+ checker = partial (
422+ result_with_type_checking ,
423+ spec = block .spec ,
424+ msg = "Type errors during spec checking:" ,
425+ loc = loc ,
426+ trace = trace ,
427+ )
428+ result = lazy_apply (checker , result )
429+ if block .fallback is not None :
430+ result .result ()
431+ break
432+ except Exception as exc :
433+ err_msg = exc .args [0 ]
434+ do_retry = (
435+ block .retry
436+ and trial_idx + 1 < trial_total
437+ and "Keyboard Interrupt" not in err_msg
438+ )
439+ if block .fallback is None and not do_retry :
440+ raise exc from exc
441+ if do_retry :
442+ error = f"An error occurred in a PDL block. Error details: { err_msg } "
443+ print (
444+ f"\n \033 [0;31m[Retry { trial_idx + 1 } /{ max_retry } ] { error } \033 [0m\n " ,
445+ file = sys .stderr ,
446+ )
447+ if block .trace_error_on_retry :
448+ scope = set_error_to_scope_for_retry (scope , error , block .pdl__id )
449+ continue
450+ (
408451 result ,
452+ background ,
453+ new_scope ,
454+ trace ,
455+ ) = process_block_of (
456+ block ,
457+ "fallback" ,
458+ state ,
459+ scope ,
460+ loc = loc ,
409461 )
462+ if block .spec is not None and not isinstance (block , FunctionBlock ):
463+ loc = append (loc , "fallback" )
464+ # Use partial to create a function with fixed arguments
465+ checker = partial (
466+ result_with_type_checking ,
467+ spec = block .spec ,
468+ msg = "Type errors during spec checking:" ,
469+ loc = loc ,
470+ trace = trace ,
471+ )
472+ result = lazy_apply (checker , result )
410473 if block .def_ is not None :
411474 var = block .def_
412475 new_scope = new_scope | PdlDict ({var : result })
@@ -832,6 +895,16 @@ def process_block_body(
832895 if block .def_ is not None :
833896 scope = scope | {block .def_ : closure }
834897 closure .pdl__scope = scope
898+ signature : dict [str , Any ] = {"type" : "function" }
899+ if block .def_ is not None :
900+ signature ["name" ] = block .def_
901+ if block .description is not None :
902+ signature ["description" ] = block .description
903+ if block .function is not None :
904+ signature ["parameters" ] = get_json_schema (block .function , False ) or {}
905+ else :
906+ signature ["parameters" ] = {}
907+ closure .signature = signature
835908 result = PdlConst (closure )
836909 background = PdlList ([])
837910 trace = closure .model_copy (update = {})
@@ -914,6 +987,8 @@ def process_defs(
914987 state = state .with_iter (idx )
915988 state = state .with_yield_result (False )
916989 state = state .with_yield_background (False )
990+ if isinstance (block , FunctionBlock ) and block .def_ is None :
991+ block = block .model_copy (update = {"def_" : x })
917992 result , _ , _ , block_trace = process_block (state , scope , block , newloc )
918993 scope = scope | PdlDict ({x : result })
919994 defs_trace [x ] = block_trace
0 commit comments