@@ -25,7 +25,7 @@ def replace_base_model_import(cls, module: ast.Module) -> ast.Module:
2525 if base_model_index == - 1 :
2626 raise ValueError ("BaseModel not found in module" )
2727 module .body [base_model_index ] = ast .ImportFrom (
28- module = "infrahub_sdk.graphql" , names = [ast .alias (name = GraphQLReturnTypeModel .__name__ )]
28+ module = "infrahub_sdk.graphql" , names = [ast .alias (name = GraphQLReturnTypeModel .__name__ )], level = 2
2929 )
3030 return module
3131
@@ -41,9 +41,34 @@ def replace_base_model_class(module: ast.Module) -> ast.Module:
4141 base .id = GraphQLReturnTypeModel .__name__
4242 return module
4343
44- def insert_future_annotation (self , module : ast .Module ) -> ast .Module :
44+ @staticmethod
45+ def insert_future_annotation (module : ast .Module ) -> ast .Module :
4546 """Insert the future annotation at the beginning of the module."""
46- module .body .insert (0 , ast .ImportFrom (module = "__future__" , names = [ast .alias (name = "annotations" )]))
47+ module .body .insert (0 , ast .ImportFrom (module = "__future__" , names = [ast .alias (name = "annotations" )], level = 0 ))
48+ return module
49+
50+ @classmethod
51+ def replace_list_in_subscript (cls , subscript : ast .Subscript ) -> ast .Subscript :
52+ if isinstance (subscript .value , ast .Name ) and subscript .value .id == "List" :
53+ subscript .value .id = "list"
54+ if isinstance (subscript .slice , ast .Subscript ):
55+ subscript .slice = cls .replace_list_in_subscript (subscript .slice )
56+
57+ return subscript
58+
59+ @classmethod
60+ def replace_list_annotations (cls , module : ast .Module ) -> ast .Module :
61+ for item in module .body :
62+ if not isinstance (item , ast .ClassDef ):
63+ continue
64+
65+ # replace List with list in the annotations when list is used as a type
66+ for class_item in item .body :
67+ if not isinstance (class_item , ast .AnnAssign ):
68+ continue
69+ if isinstance (class_item .annotation , ast .Subscript ):
70+ class_item .annotation = cls .replace_list_in_subscript (class_item .annotation )
71+
4772 return module
4873
4974 def generate_result_types_module (
@@ -55,4 +80,6 @@ def generate_result_types_module(
5580 module = self .replace_base_model_import (module )
5681 module = self .replace_base_model_class (module )
5782
83+ module = self .replace_list_annotations (module )
84+
5885 return module
0 commit comments