Skip to content

Commit 2d235cd

Browse files
committed
Improve support for sub directory
1 parent f27fe6c commit 2d235cd

File tree

1 file changed

+47
-11
lines changed

1 file changed

+47
-11
lines changed

infrahub_sdk/ctl/graphql.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44
from pathlib import Path
55

66
import typer
7-
from ariadne_codegen.client_generators.package import get_package_generator
7+
from ariadne_codegen.client_generators.package import PackageGenerator, get_package_generator
88
from ariadne_codegen.plugins.explorer import get_plugins_types
99
from ariadne_codegen.plugins.manager import PluginManager
1010
from ariadne_codegen.schema import (
1111
filter_operations_definitions,
12-
get_graphql_queries,
1312
get_graphql_schema_from_path,
1413
)
1514
from ariadne_codegen.settings import ClientSettings, CommentsStrategy
15+
from ariadne_codegen.utils import ast_to_str
16+
from graphql import DefinitionNode, GraphQLSchema, NoUnusedFragmentsRule, parse, specified_rules, validate
1617
from rich.console import Console
1718

1819
from ..async_typer import AsyncTyper
@@ -39,10 +40,39 @@ def find_gql_files(query_path: Path) -> list[Path]:
3940
raise FileNotFoundError(f"Directory not found: {query_path}")
4041

4142
if not query_path.is_dir() and query_path.is_file():
42-
yield query_path
43+
return [query_path]
4344

44-
else:
45-
yield from query_path.glob("*/*.gql")
45+
return list(query_path.glob("**/*.gql"))
46+
47+
48+
def get_graphql_query(queries_path: Path, schema: GraphQLSchema) -> tuple[DefinitionNode, ...]:
49+
"""Get graphql queries definitions from a single GraphQL file."""
50+
51+
if not queries_path.exists():
52+
raise FileNotFoundError(f"File not found: {queries_path}")
53+
if not queries_path.is_file():
54+
raise ValueError(f"{queries_path} is not a file")
55+
56+
queries_str = queries_path.read_text(encoding="utf-8")
57+
queries_ast = parse(queries_str)
58+
validation_errors = validate(
59+
schema=schema,
60+
document_ast=queries_ast,
61+
rules=[r for r in specified_rules if r is not NoUnusedFragmentsRule],
62+
)
63+
if validation_errors:
64+
raise ValueError("\n\n".join(error.message for error in validation_errors))
65+
return queries_ast.definitions
66+
67+
68+
def generate_result_types(directory: Path, package: PackageGenerator) -> None:
69+
for file_name, module in package._result_types_files.items():
70+
file_path = directory / file_name
71+
code = package._add_comments_to_code(ast_to_str(module), package.queries_source)
72+
if package.plugin_manager:
73+
code = package.plugin_manager.generate_result_types_code(code)
74+
file_path.write_text(code)
75+
package._generated_files.append(file_path.name)
4676

4777

4878
@app.callback()
@@ -79,7 +109,7 @@ async def generate_return_types(
79109
gql_per_directory[gql_file.parent].append(gql_file)
80110

81111
# Generate the Pydantic Models for the GraphQL queries
82-
for directory in gql_per_directory.keys():
112+
for directory, gql_files in gql_per_directory.items():
83113
package_generator = get_package_generator(
84114
schema=graphql_schema,
85115
fragments=[],
@@ -92,13 +122,19 @@ async def generate_return_types(
92122
plugin_manager=plugin_manager,
93123
)
94124

95-
definitions = get_graphql_queries(queries_path=str(directory), schema=graphql_schema)
96-
queries = filter_operations_definitions(definitions)
125+
for gql_file in gql_files:
126+
try:
127+
definitions = get_graphql_query(queries_path=gql_file, schema=graphql_schema)
128+
except ValueError as e:
129+
print(f"Error generating result types for {gql_file}: {e}")
130+
continue
131+
queries = filter_operations_definitions(definitions)
97132

98-
for query_operation in queries:
99-
package_generator.add_operation(query_operation)
133+
for query_operation in queries:
134+
package_generator.add_operation(query_operation)
100135

101-
package_generator._generate_result_types()
136+
# package_generator._generate_result_types()
137+
generate_result_types(directory=directory, package=package_generator)
102138

103139
for file_name in package_generator._result_types_files.keys():
104140
print(f"Generated {file_name} in {directory}")

0 commit comments

Comments
 (0)