|
5 | 5 |
|
6 | 6 | from ariadne_codegen.plugins.base import Plugin |
7 | 7 |
|
8 | | -from .return_type import GraphQLReturnTypeModel |
9 | | - |
10 | 8 | if TYPE_CHECKING: |
11 | 9 | from graphql import ExecutableDefinitionNode |
12 | 10 |
|
13 | 11 |
|
14 | | -class InfrahubPlugin(Plugin): |
| 12 | +class FutureAnnotationPlugin(Plugin): |
15 | 13 | @staticmethod |
16 | | - def find_base_model_index(module: ast.Module) -> int: |
17 | | - for idx, item in enumerate(module.body): |
18 | | - if isinstance(item, ast.ImportFrom) and item.module == "base_model": |
19 | | - return idx |
20 | | - return -1 |
| 14 | + def insert_future_annotation(module: ast.Module) -> ast.Module: |
| 15 | + # First check if the future annotation is already present |
| 16 | + for item in module.body: |
| 17 | + if isinstance(item, ast.ImportFrom) and item.module == "__future__" and item.names[0].name == "annotations": |
| 18 | + return module |
21 | 19 |
|
22 | | - @classmethod |
23 | | - def replace_base_model_import(cls, module: ast.Module) -> ast.Module: |
24 | | - base_model_index = cls.find_base_model_index(module) |
25 | | - if base_model_index == -1: |
26 | | - raise ValueError("BaseModel not found in module") |
27 | | - module.body[base_model_index] = ast.ImportFrom( |
28 | | - module="infrahub_sdk.graphql", names=[ast.alias(name=GraphQLReturnTypeModel.__name__)], level=2 |
29 | | - ) |
| 20 | + module.body.insert(0, ast.ImportFrom(module="__future__", names=[ast.alias(name="annotations")], level=0)) |
30 | 21 | return module |
31 | 22 |
|
32 | | - @staticmethod |
33 | | - def replace_base_model_class(module: ast.Module) -> ast.Module: |
34 | | - """Replace the BaseModel inserted by Ariadne with the GraphQLReturnTypeModel class.""" |
35 | | - for item in module.body: |
36 | | - if not isinstance(item, ast.ClassDef): |
37 | | - continue |
| 23 | + def generate_result_types_module( |
| 24 | + self, |
| 25 | + module: ast.Module, |
| 26 | + operation_definition: ExecutableDefinitionNode, # noqa: ARG002 |
| 27 | + ) -> ast.Module: |
| 28 | + module = self.insert_future_annotation(module) |
38 | 29 |
|
39 | | - for base in item.bases: |
40 | | - if isinstance(base, ast.Name) and base.id == "BaseModel": |
41 | | - base.id = GraphQLReturnTypeModel.__name__ |
42 | 30 | return module |
43 | 31 |
|
44 | | - @staticmethod |
45 | | - def insert_future_annotation(module: ast.Module) -> ast.Module: |
46 | | - """Insert the future annotation at the beginning of the module.""" |
47 | | - module.body.insert(0, ast.ImportFrom(module="__future__", names=[ast.alias(name="annotations")], level=0)) |
48 | | - return module |
49 | 32 |
|
| 33 | +class StandardTypeHintPlugin(Plugin): |
50 | 34 | @classmethod |
51 | 35 | def replace_list_in_subscript(cls, subscript: ast.Subscript) -> ast.Subscript: |
52 | 36 | if isinstance(subscript.value, ast.Name) and subscript.value.id == "List": |
@@ -76,10 +60,33 @@ def generate_result_types_module( |
76 | 60 | module: ast.Module, |
77 | 61 | operation_definition: ExecutableDefinitionNode, # noqa: ARG002 |
78 | 62 | ) -> ast.Module: |
79 | | - module = self.insert_future_annotation(module) |
80 | | - module = self.replace_base_model_import(module) |
81 | | - module = self.replace_base_model_class(module) |
82 | | - |
| 63 | + module = FutureAnnotationPlugin.insert_future_annotation(module) |
83 | 64 | module = self.replace_list_annotations(module) |
84 | 65 |
|
85 | 66 | return module |
| 67 | + |
| 68 | + |
| 69 | +class PydanticBaseModelPlugin(Plugin): |
| 70 | + @staticmethod |
| 71 | + def find_base_model_index(module: ast.Module) -> int: |
| 72 | + for idx, item in enumerate(module.body): |
| 73 | + if isinstance(item, ast.ImportFrom) and item.module == "base_model": |
| 74 | + return idx |
| 75 | + return -1 |
| 76 | + |
| 77 | + @classmethod |
| 78 | + def replace_base_model_import(cls, module: ast.Module) -> ast.Module: |
| 79 | + base_model_index = cls.find_base_model_index(module) |
| 80 | + if base_model_index == -1: |
| 81 | + raise ValueError("BaseModel not found in module") |
| 82 | + module.body[base_model_index] = ast.ImportFrom(module="pydantic", names=[ast.alias(name="BaseModel")], level=0) |
| 83 | + return module |
| 84 | + |
| 85 | + def generate_result_types_module( |
| 86 | + self, |
| 87 | + module: ast.Module, |
| 88 | + operation_definition: ExecutableDefinitionNode, # noqa: ARG002 |
| 89 | + ) -> ast.Module: |
| 90 | + module = self.replace_base_model_import(module) |
| 91 | + |
| 92 | + return module |
0 commit comments