Skip to content

Commit f27fe6c

Browse files
committed
add infrahubctl graphql generate-return-types command
1 parent ac60214 commit f27fe6c

File tree

11 files changed

+492
-168
lines changed

11 files changed

+492
-168
lines changed

infrahub_sdk/ctl/cli_commands.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..ctl.client import initialize_client, initialize_client_sync
2727
from ..ctl.exceptions import QueryNotFoundError
2828
from ..ctl.generator import run as run_generator
29+
from ..ctl.graphql import app as graphql_app
2930
from ..ctl.menu import app as menu_app
3031
from ..ctl.object import app as object_app
3132
from ..ctl.render import list_jinja2_transforms, print_template_errors
@@ -63,6 +64,7 @@
6364
app.add_typer(repository_app, name="repository")
6465
app.add_typer(menu_app, name="menu")
6566
app.add_typer(object_app, name="object")
67+
app.add_typer(graphql_app, name="graphql")
6668

6769
app.command(name="dump")(dump)
6870
app.command(name="load")(load)

infrahub_sdk/ctl/graphql.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from __future__ import annotations
2+
3+
from collections import defaultdict
4+
from pathlib import Path
5+
6+
import typer
7+
from ariadne_codegen.client_generators.package import get_package_generator
8+
from ariadne_codegen.plugins.explorer import get_plugins_types
9+
from ariadne_codegen.plugins.manager import PluginManager
10+
from ariadne_codegen.schema import (
11+
filter_operations_definitions,
12+
get_graphql_queries,
13+
get_graphql_schema_from_path,
14+
)
15+
from ariadne_codegen.settings import ClientSettings, CommentsStrategy
16+
from rich.console import Console
17+
18+
from ..async_typer import AsyncTyper
19+
from ..ctl.utils import catch_exception
20+
from .parameters import CONFIG_PARAM
21+
22+
app = AsyncTyper()
23+
console = Console()
24+
25+
ARIADNE_PLUGINS = ["infrahub_sdk.graphql.plugin.InfrahubPlugin"]
26+
27+
28+
def find_gql_files(query_path: Path) -> list[Path]:
29+
"""
30+
Find all files with .gql extension in the specified directory.
31+
32+
Args:
33+
query_path: Path to the directory to search for .gql files
34+
35+
Returns:
36+
List of Path objects for all .gql files found
37+
"""
38+
if not query_path.exists():
39+
raise FileNotFoundError(f"Directory not found: {query_path}")
40+
41+
if not query_path.is_dir() and query_path.is_file():
42+
yield query_path
43+
44+
else:
45+
yield from query_path.glob("*/*.gql")
46+
47+
48+
@app.callback()
49+
def callback() -> None:
50+
"""
51+
Various GraphQL related commands.
52+
"""
53+
54+
55+
@app.command()
56+
@catch_exception(console=console)
57+
async def generate_return_types(
58+
query: Path = typer.Argument(Path.cwd(), help="Location of the GraphQL query file(s)."),
59+
schema: Path = typer.Option("schema.graphql", help="Path to the GraphQL schema file."),
60+
_: str = CONFIG_PARAM,
61+
) -> None:
62+
"""Create Pydantic Models for GraphQL query return types"""
63+
64+
# Load the GraphQL schema
65+
if not schema.exists():
66+
raise FileNotFoundError(f"GraphQL Schema file not found: {schema}")
67+
graphql_schema = get_graphql_schema_from_path(schema_path=str(schema))
68+
69+
# Initialize the plugin manager
70+
plugin_manager = PluginManager(
71+
schema=graphql_schema,
72+
plugins_types=get_plugins_types(plugins_strs=ARIADNE_PLUGINS),
73+
)
74+
75+
# Find the GraphQL files and organize them by directory
76+
gql_files = find_gql_files(query)
77+
gql_per_directory: dict[Path, list[Path]] = defaultdict(list)
78+
for gql_file in gql_files:
79+
gql_per_directory[gql_file.parent].append(gql_file)
80+
81+
# Generate the Pydantic Models for the GraphQL queries
82+
for directory in gql_per_directory.keys():
83+
package_generator = get_package_generator(
84+
schema=graphql_schema,
85+
fragments=[],
86+
settings=ClientSettings(
87+
schema_path=str(schema),
88+
target_package_name=directory.name,
89+
queries_path=str(directory),
90+
include_comments=CommentsStrategy.NONE,
91+
),
92+
plugin_manager=plugin_manager,
93+
)
94+
95+
definitions = get_graphql_queries(queries_path=str(directory), schema=graphql_schema)
96+
queries = filter_operations_definitions(definitions)
97+
98+
for query_operation in queries:
99+
package_generator.add_operation(query_operation)
100+
101+
package_generator._generate_result_types()
102+
103+
for file_name in package_generator._result_types_files.keys():
104+
print(f"Generated {file_name} in {directory}")

infrahub_sdk/graphql/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from .constants import VARIABLE_TYPE_MAPPING
2+
from .query import Mutation, Query
3+
from .renderers import render_input_block, render_query_block, render_variables_to_string
4+
from .return_type import GraphQLReturnTypeModel
5+
6+
__all__ = [
7+
"VARIABLE_TYPE_MAPPING",
8+
"GraphQLReturnTypeModel",
9+
"Mutation",
10+
"Query",
11+
"render_input_block",
12+
"render_query_block",
13+
"render_variables_to_string",
14+
]

infrahub_sdk/graphql/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARIABLE_TYPE_MAPPING = ((str, "String!"), (int, "Int!"), (float, "Float!"), (bool, "Boolean!"))

infrahub_sdk/graphql/plugin.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from typing import TYPE_CHECKING
5+
6+
from ariadne_codegen.plugins.base import Plugin
7+
8+
from .return_type import GraphQLReturnTypeModel
9+
10+
if TYPE_CHECKING:
11+
from graphql import ExecutableDefinitionNode
12+
13+
14+
class InfrahubPlugin(Plugin):
15+
@staticmethod
16+
def find_base_model_index(module: ast.Module) -> int:
17+
for idx, item in enumerate(module.body):
18+
if isinstance(item, ast.ImportFrom) and item.module == "base_model":
19+
return idx
20+
return -1
21+
22+
@classmethod
23+
def replace_base_model_import(cls, module: ast.Module) -> ast.Module:
24+
base_model_index = cls.find_base_model_index(module)
25+
if base_model_index == -1:
26+
raise ValueError("BaseModel not found in module")
27+
module.body[base_model_index] = ast.ImportFrom(
28+
module="infrahub_sdk.graphql", names=[ast.alias(name=GraphQLReturnTypeModel.__name__)]
29+
)
30+
return module
31+
32+
@staticmethod
33+
def replace_base_model_class(module: ast.Module) -> ast.Module:
34+
"""Replace the BaseModel inserted by Ariadne with the GraphQLReturnTypeModel class."""
35+
for item in module.body:
36+
if not isinstance(item, ast.ClassDef):
37+
continue
38+
39+
for base in item.bases:
40+
if isinstance(base, ast.Name) and base.id == "BaseModel":
41+
base.id = GraphQLReturnTypeModel.__name__
42+
return module
43+
44+
def insert_future_annotation(self, module: ast.Module) -> ast.Module:
45+
"""Insert the future annotation at the beginning of the module."""
46+
module.body.insert(0, ast.ImportFrom(module="__future__", names=[ast.alias(name="annotations")]))
47+
return module
48+
49+
def generate_result_types_module(
50+
self,
51+
module: ast.Module,
52+
operation_definition: ExecutableDefinitionNode, # noqa: ARG002
53+
) -> ast.Module:
54+
module = self.insert_future_annotation(module)
55+
module = self.replace_base_model_import(module)
56+
module = self.replace_base_model_class(module)
57+
58+
return module

infrahub_sdk/graphql/query.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from .renderers import render_input_block, render_query_block, render_variables_to_string
6+
7+
8+
class BaseGraphQLQuery:
9+
query_type: str = "not-defined"
10+
indentation: int = 4
11+
12+
def __init__(self, query: dict, variables: dict | None = None, name: str | None = None):
13+
self.query = query
14+
self.variables = variables
15+
self.name = name or ""
16+
17+
def render_first_line(self) -> str:
18+
first_line = self.query_type
19+
20+
if self.name:
21+
first_line += " " + self.name
22+
23+
if self.variables:
24+
first_line += f" ({render_variables_to_string(self.variables)})"
25+
26+
first_line += " {"
27+
28+
return first_line
29+
30+
31+
class Query(BaseGraphQLQuery):
32+
query_type = "query"
33+
34+
def render(self, convert_enum: bool = False) -> str:
35+
lines = [self.render_first_line()]
36+
lines.extend(
37+
render_query_block(
38+
data=self.query, indentation=self.indentation, offset=self.indentation, convert_enum=convert_enum
39+
)
40+
)
41+
lines.append("}")
42+
43+
return "\n" + "\n".join(lines) + "\n"
44+
45+
46+
class Mutation(BaseGraphQLQuery):
47+
query_type = "mutation"
48+
49+
def __init__(self, *args: Any, mutation: str, input_data: dict, **kwargs: Any):
50+
self.input_data = input_data
51+
self.mutation = mutation
52+
super().__init__(*args, **kwargs)
53+
54+
def render(self, convert_enum: bool = False) -> str:
55+
lines = [self.render_first_line()]
56+
lines.append(" " * self.indentation + f"{self.mutation}(")
57+
lines.extend(
58+
render_input_block(
59+
data=self.input_data,
60+
indentation=self.indentation,
61+
offset=self.indentation * 2,
62+
convert_enum=convert_enum,
63+
)
64+
)
65+
lines.append(" " * self.indentation + "){")
66+
lines.extend(
67+
render_query_block(
68+
data=self.query,
69+
indentation=self.indentation,
70+
offset=self.indentation * 2,
71+
convert_enum=convert_enum,
72+
)
73+
)
74+
lines.append(" " * self.indentation + "}")
75+
lines.append("}")
76+
77+
return "\n" + "\n".join(lines) + "\n"

0 commit comments

Comments
 (0)