@@ -511,15 +511,38 @@ def _generate_import_statement(
511511 Returns:
512512 A formatted, multi-line import statement string.
513513 """
514+
514515 names = sorted (list (set ([item [key ] for item in context ])))
515516 names_str = ",\n " .join (names )
516517 return f"from { package } import (\n { names_str } \n )"
517518
518519
520+ def _get_request_class_name (method_name : str , config : Dict [str , Any ]) -> str :
521+ """Gets the inferred request class name, applying overrides from config."""
522+ inferred_request_name = name_utils .method_to_request_class_name (method_name )
523+ method_overrides = config .get ("filter" , {}).get ("methods" , {}).get ("overrides" , {})
524+ if method_name in method_overrides :
525+ return method_overrides [method_name ].get (
526+ "request_class_name" , inferred_request_name
527+ )
528+ return inferred_request_name
529+
530+
531+ def _find_fq_request_name (
532+ request_name : str , request_arg_schema : Dict [str , List [str ]]
533+ ) -> str :
534+ """Finds the fully qualified request name in the schema."""
535+ for key in request_arg_schema .keys ():
536+ if key .endswith (f".{ request_name } " ):
537+ return key
538+ return ""
539+
540+
519541def generate_code (config : Dict [str , Any ], analysis_results : tuple ) -> None :
520542 """
521543 Generates source code files using Jinja2 templates.
522544 """
545+
523546 data , all_imports , all_types , request_arg_schema = analysis_results
524547 project_root = config ["project_root" ]
525548 config_dir = config ["config_dir" ]
@@ -539,27 +562,11 @@ def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None:
539562 "return_type" : method_info ["return_type" ],
540563 }
541564
542- # Infer the request class and find its schema.
543- inferred_request_name = name_utils .method_to_request_class_name (
544- method_name
545- )
546-
547- # Check for a request class name override in the config.
548- method_overrides = (
549- config .get ("filter" , {}).get ("methods" , {}).get ("overrides" , {})
565+ request_name = _get_request_class_name (method_name , config )
566+ fq_request_name = _find_fq_request_name (
567+ request_name , request_arg_schema
550568 )
551- if method_name in method_overrides :
552- inferred_request_name = method_overrides [method_name ].get (
553- "request_class_name" , inferred_request_name
554- )
555-
556- fq_request_name = ""
557- for key in request_arg_schema .keys ():
558- if key .endswith (f".{ inferred_request_name } " ):
559- fq_request_name = key
560- break
561569
562- # If found, augment the method context.
563570 if fq_request_name :
564571 context ["request_class_full_name" ] = fq_request_name
565572 context ["request_id_args" ] = request_arg_schema [fq_request_name ]
0 commit comments