@@ -334,6 +334,36 @@ def _infer_schema_unions() -> None:
334334 _INFERRED_SCHEMA_PATH .write_text (json .dumps (processed_schema , indent = 2 ))
335335
336336
337+ # Maps generated type names to the public names they are renamed to. "Aliases" in the code generator
338+ # do not amount to full type renaming, so these overrides are applied during the AST transformation step instead
339+ _DATA_MODEL_NAME_OVERRIDES = {
340+ # Friendlier names for the chat history types (including their "Dict" counterparts)
341+ "ChatMessageData" : "AnyChatMessage" ,
342+ "ChatMessageDataUser" : "UserMessage" ,
343+ "ChatMessageDataSystem" : "SystemPrompt" ,
344+ "ChatMessageDataAssistant" : "AssistantResponse" ,
345+ "ChatMessageDataTool" : "ToolResultMessage" ,
346+ "ChatMessageDataUserDict" : "UserMessageDict" ,
347+ "ChatMessageDataSystemDict" : "SystemPromptDict" ,
348+ "ChatMessageDataAssistantDict" : "AssistantResponseDict" ,
349+ "ChatMessageDataToolDict" : "ToolResultMessageDict" ,
350+ "ChatMessagePartFileData" : "FileHandle" ,
351+ "ChatMessagePartFileDataDict" : "FileHandleDict" ,
352+ "ChatMessagePartTextData" : "TextData" ,
353+ "ChatMessagePartTextDataDict" : "TextDataDict" ,
354+ "ChatMessagePartToolCallRequestData" : "ToolCallRequestData" ,
355+ "ChatMessagePartToolCallRequestDataDict" : "ToolCallRequestDataDict" ,
356+ "ChatMessagePartToolCallResultData" : "ToolCallResultData" ,
357+ "ChatMessagePartToolCallResultDataDict" : "ToolCallResultDataDict" ,
358+ "FunctionToolCallRequestDict" : "ToolCallRequestDict" ,
359+ # Friendlier names for the channel creation request types (including their "Dict" counterparts)
360+ "LlmChannelPredictCreationParameter" : "PredictionChannelRequest" ,
361+ "LlmChannelPredictCreationParameterDict" : "PredictionChannelRequestDict" ,
362+ "RepositoryChannelDownloadModelCreationParameter" : "DownloadModelChannelRequest" ,
363+ "RepositoryChannelDownloadModelCreationParameterDict" : "DownloadModelChannelRequestDict" ,
364+ }
365+
366+
337367def _generate_data_model_from_json_schema () -> None :
338368 """Produce Python data model classes from the exported JSON schema file."""
339369 if not _CACHED_SCHEMA_PATH .exists ():
@@ -387,42 +417,73 @@ def _generate_data_model_from_json_schema() -> None:
387417 model_ast = ast .parse (model_source )
388418 dict_token_replacements : dict [str , str ] = {}
389419 exported_names : list [str ] = []
420+ # Scan all nodes in the AST (only in-place node changes are valid here)
421+ for node in ast .walk (model_ast ):
422+ match node :
423+ case ast .Name (id = name ) as name_node :
424+ # Override names when looked up or assigned directly
425+ override_name = _DATA_MODEL_NAME_OVERRIDES .get (name , None )
426+ if override_name is not None :
427+ name_node .id = override_name
428+ case ast .Constant (value = name ) as name_constant :
429+ # Override names when they appear as type hint forward references
430+ override_name = _DATA_MODEL_NAME_OVERRIDES .get (name , None )
431+ if override_name is not None :
432+ name_constant .value = override_name
433+ # Scan top level nodes only (allows for adding & removing top level nodes)
390434 for node in model_ast .body :
391435 match node :
392436 case ast .ClassDef (name = name ):
393- name = node .name
437+ # Override names when defining classes
438+ override_name = _DATA_MODEL_NAME_OVERRIDES .get (name , None )
439+ if override_name is not None :
440+ generated_name = name
441+ name = node .name = override_name
394442 exported_names .append (name )
395443 if name .endswith ("Dict" ):
396444 struct_name = name .removesuffix ("Dict" )
397445 dict_token_replacements [struct_name ] = name
446+ if override_name is not None :
447+ # Fix up docstring reference back to corresponding struct type
448+ expr_node = node .body [0 ]
449+ assert isinstance (expr_node , ast .Expr )
450+ docstring_node = expr_node .value
451+ assert isinstance (docstring_node , ast .Constant )
452+ docstring = docstring_node .value
453+ assert isinstance (docstring , str )
454+ docstring_node .value = docstring .replace (generated_name , name )
398455 case ast .Assign (targets = [ast .Name (id = alias )], value = expr ):
399- # We don't want to require the specific aliased types for dict inputs
400456 match expr :
457+ # For dict fields, replace builtin type aliases with the builtin type names
401458 case (
459+ # alias = name
402460 ast .Name (id = name )
461+ # alias = Annotated[name, ...]
403462 | ast .Subscript (
404463 value = ast .Name (id = "Annotated" ),
405464 slice = ast .Tuple (elts = [ast .Name (id = name ), * _]),
406465 )
407466 ):
408467 if hasattr (builtins , name ):
409468 dict_token_replacements [alias ] = name
410-
469+ # Write any AST level changes back to the source file
470+ # TODO: Move more changes to the AST rather than relying on raw text replacement
471+ _MODEL_PATH .write_text (ast .unparse (model_ast ))
411472 # Additional type union names to be translated
412473 # Inject the dict versions of required type unions
413474 # (This is a brute force hack, but it's good enough while there's only a few that matter)
414475 _single_line_union = (" = " , " | " , "" )
415- _multi_line_union = (" = (\n " , "\n | " , "\n )" )
476+ # _multi_line_union = (" = (\n ", "\n | ", "\n)")
416477 _dict_unions = (
417478 (
418- "ChatMessageData " ,
479+ "AnyChatMessage " ,
419480 (
420- "ChatMessageDataAssistant " ,
421- "ChatMessageDataUser " ,
422- "ChatMessageDataSystem " ,
423- "ChatMessageDataTool " ,
481+ "AssistantResponse " ,
482+ "UserMessage " ,
483+ "SystemPrompt " ,
484+ "ToolResultMessage " ,
424485 ),
425- _multi_line_union ,
486+ _single_line_union ,
426487 ),
427488 (
428489 "LlmToolUseSetting" ,
0 commit comments