@@ -355,6 +355,7 @@ def _infer_schema_unions() -> None:
355355 "ChatMessagePartToolCallRequestDataDict" : "ToolCallRequestDataDict" ,
356356 "ChatMessagePartToolCallResultData" : "ToolCallResultData" ,
357357 "ChatMessagePartToolCallResultDataDict" : "ToolCallResultDataDict" ,
358+ "FunctionToolCallRequest" : "ToolCallRequest" ,
358359 "FunctionToolCallRequestDict" : "ToolCallRequestDict" ,
359360 # Prettier channel creation type names
360361 "LlmChannelPredictCreationParameter" : "PredictionChannelRequest" ,
@@ -431,7 +432,9 @@ def _generate_data_model_from_json_schema() -> None:
431432 if override_name is not None :
432433 name_constant .value = override_name
433434 # Scan top level nodes only (allows for adding & removing top level nodes)
434- for node in model_ast .body :
435+ declared_structs : set [str ] = set ()
436+ additional_nodes : list [tuple [int , ast .stmt ]] = []
437+ for body_idx , node in enumerate (model_ast .body ):
435438 match node :
436439 case ast .ClassDef (name = name ):
437440 # Override names when defining classes
@@ -440,8 +443,11 @@ def _generate_data_model_from_json_schema() -> None:
440443 generated_name = name
441444 name = node .name = override_name
442445 exported_names .append (name )
443- if name .endswith ("Dict" ):
446+ if not name .endswith ("Dict" ):
447+ declared_structs .add (name )
448+ else :
444449 struct_name = name .removesuffix ("Dict" )
450+ assert struct_name in declared_structs , struct_name
445451 dict_token_replacements [struct_name ] = name
446452 if override_name is not None :
447453 # Fix up docstring reference back to corresponding struct type
@@ -454,7 +460,9 @@ def _generate_data_model_from_json_schema() -> None:
454460 docstring_node .value = docstring .replace (generated_name , name )
455461 case ast .Assign (targets = [ast .Name (id = alias )], value = expr ):
456462 match expr :
457- # For dict fields, replace builtin type aliases with the builtin type names
463+ # For dict fields, replace all type aliases with the original type name
464+ # This covers both builtin type aliases (as these will be accepted),
465+ # and struct type aliases (for mapping to their TypedDict counterparts)
458466 case (
459467 # alias = name
460468 ast .Name (id = name )
@@ -465,59 +473,85 @@ def _generate_data_model_from_json_schema() -> None:
465473 )
466474 ):
467475 if hasattr (builtins , name ):
476+ # Simple alias for builtins
468477 dict_token_replacements [alias ] = name
478+ else :
479+ dict_name = dict_token_replacements .get (name , None )
480+ if dict_name is not None :
481+ dict_token_replacements [alias ] = dict_name
482+ # Unions require additional handling to add dict variants of the union
483+ case ast .BinOp (op = ast .BitOr ()) as union_node :
484+ named_union_members : list [str ] = []
485+ other_union_members : list [ast .expr ] = []
486+ optional_union = False
487+ needs_dict_alias = False
488+ for union_child in ast .walk (union_node ):
489+ match union_child :
490+ case ast .Name (id = name ):
491+ named_union_members .append (name )
492+ if not needs_dict_alias :
493+ needs_dict_alias = (
494+ name in dict_token_replacements
495+ )
496+ case ast .Subscript (value = ast .Name (id = "Mapping" )):
497+ other_union_members .append (union_child )
498+ case ast .Constant (value = None ):
499+ optional_union = True
500+ # Ignore expected structural elements
501+ case (
502+ ast .BinOp (op = ast .BitOr ())
503+ | ast .BitOr ()
504+ | ast .Load ()
505+ | ast .Store ()
506+ | ast .Tuple (
507+ elts = [ast .Name (id = "str" ), ast .Name (id = "str" )]
508+ )
509+ ):
510+ continue
511+ case _:
512+ raise RuntimeError (
513+ f"Failed to parse union node: { ast .dump (union_child )} in { ast .dump (node )} "
514+ )
515+ if needs_dict_alias :
516+ dict_alias = f"{ alias } Dict"
517+ dict_token_replacements [alias ] = dict_alias
518+ struct_union_member = named_union_members [0 ]
519+ dict_union_member = dict_token_replacements .get (
520+ struct_union_member , struct_union_member
521+ )
522+ dict_union : ast .expr = ast .Name (
523+ dict_union_member , ast .Load ()
524+ )
525+ for struct_union_member in named_union_members [1 :]:
526+ dict_union_member = dict_token_replacements .get (
527+ struct_union_member , struct_union_member
528+ )
529+ union_rhs = ast .Name (dict_union_member , ast .Load ())
530+ dict_union = ast .BinOp (
531+ dict_union , ast .BitOr (), union_rhs
532+ )
533+ for other_union_member in other_union_members :
534+ dict_union = ast .BinOp (
535+ dict_union , ast .BitOr (), other_union_member
536+ )
537+ if optional_union :
538+ dict_union = ast .BinOp (
539+ dict_union , ast .BitOr (), ast .Constant (None )
540+ )
541+ # Insert the dict alias assignment after the struct alias assignment
542+ dict_alias_target = ast .Name (dict_alias , ast .Store ())
543+ dict_alias_node = ast .Assign (
544+ [dict_alias_target ], dict_union
545+ )
546+ additional_nodes .append ((body_idx + 1 , dict_alias_node ))
547+
469548 # Write any AST level changes back to the source file
470- # TODO: Move more changes to the AST rather than relying on raw text replacement
549+ for insertion_idx , node in reversed (additional_nodes ):
550+ model_ast .body [insertion_idx :insertion_idx ] = (node ,)
551+ ast .fix_missing_locations (model_ast )
471552 _MODEL_PATH .write_text (ast .unparse (model_ast ))
472- # Additional type union names to be translated
473- # Inject the dict versions of required type unions
474- # (This is a brute force hack, but it's good enough while there's only a few that matter)
475- _single_line_union = (" = " , " | " , "" )
476- # _multi_line_union = (" = (\n ", "\n | ", "\n)")
477- _dict_unions = (
478- (
479- "AnyChatMessage" ,
480- (
481- "AssistantResponse" ,
482- "UserMessage" ,
483- "SystemPrompt" ,
484- "ToolResultMessage" ,
485- ),
486- _single_line_union ,
487- ),
488- (
489- "LlmToolUseSetting" ,
490- ("LlmToolUseSettingNone" , "LlmToolUseSettingToolArray" ),
491- _single_line_union ,
492- ),
493- (
494- "ModelSpecifier" ,
495- ("ModelSpecifierQuery" , "ModelSpecifierInstanceReference" ),
496- _single_line_union ,
497- ),
498- )
499- combined_union_defs : dict [str , str ] = {}
500- for union_name , union_members , (assign_sep , union_sep , union_end ) in _dict_unions :
501- dict_union_name = f"{ union_name } Dict"
502- dict_token_replacements [union_name ] = dict_union_name
503- if dict_union_name != f"{ union_name } Dict" :
504- raise RuntimeError (
505- f"Union { union_name !r} mapped to unexpected name { dict_union_name !r} "
506- )
507- union_def = (
508- f"{ union_name } { assign_sep } { union_sep .join (union_members )} { union_end } "
509- )
510- dict_union_def = f"{ dict_union_name } { assign_sep } { ('Dict' + union_sep ).join (union_members )} Dict{ union_end } "
511- combined_union_defs [union_def ] = f"{ union_def } \n { dict_union_def } "
512- # Additional type aliases for translation
513- # TODO: Rather than setting these on an ad hoc basis, record all the pure aliases
514- # during the AST scan, and add the extra dict token replacements automatically
515- dict_token_replacements ["PromptTemplate" ] = "LlmPromptTemplateDict"
516- dict_token_replacements ["ReasoningParsing" ] = "LlmReasoningParsingDict"
517- dict_token_replacements ["RawTools" ] = "LlmToolUseSettingDict"
518- dict_token_replacements ["LlmTool" ] = "LlmToolFunctionDict"
519- dict_token_replacements ["LlmToolParameters" ] = "LlmToolParametersObjectDict"
520553 # Replace struct names in TypedDict definitions with their dict counterparts
554+ # Also replace other type alias names with the original type (as dict inputs will be translated as needed)
521555 model_tokens = tokenize .tokenize (_MODEL_PATH .open ("rb" ).readline )
522556 updated_tokens : list [tokenize .TokenInfo ] = []
523557 checking_class_header = False
@@ -529,7 +563,7 @@ def _generate_data_model_from_json_schema() -> None:
529563 assert token_type == tokenize .NAME
530564 if token .endswith ("Dict" ):
531565 processing_typed_dict = True
532- # Either way, not checking the class header anymore
566+ # Either way, not checking the class header any more
533567 checking_class_header = False
534568 elif processing_typed_dict :
535569 # Stop processing at the next dedent (no methods in the typed dicts)
@@ -545,9 +579,6 @@ def _generate_data_model_from_json_schema() -> None:
545579 checking_class_header = True
546580 updated_tokens .append (token_info )
547581 updated_source : str = tokenize .untokenize (updated_tokens ).decode ("utf-8" )
548- # Inject the dict versions of required type unions
549- for union_def , combined_def in combined_union_defs .items ():
550- updated_source = updated_source .replace (union_def , combined_def )
551582 # Insert __all__ between the imports and the schema definitions
552583 name_lines = (f' "{ name } ",' for name in (sorted (exported_names )))
553584 lines_to_insert = ["__all__ = [" , * name_lines , "]" , "" , "" ]
0 commit comments