Skip to content

Commit dabcbef

Browse files
committed
Implement feedback from PR review
1 parent 5fdfa71 commit dabcbef

File tree

6 files changed

+72
-115
lines changed

6 files changed

+72
-115
lines changed

docs/docs/python-sdk/guides/python-typing.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ from lib.protocols import MyOwnObject
100100
my_object = client.get(MyOwnObject, name__value="example")
101101
```
102102

103-
> if you don't have your own Python module, it's possible to use relative path by having the `procotols.py` in the same directory as your script/transform/generator
103+
> if you don't have your own Python module, it's possible to use relative path by having the `protocols.py` in the same directory as your script/transform/generator
104104
105105
## Generating Pydantic models from GraphQL queries
106106

@@ -110,7 +110,7 @@ When working with GraphQL queries, you can generate type-safe Pydantic models th
110110

111111
Generated Pydantic models from GraphQL queries offer several important benefits:
112112

113-
- **Type Safety**: Catch type errors at development time instead of runtime
113+
- **Type Safety**: Catch type errors during development time instead of at runtime
114114
- **IDE Support**: Get autocomplete, type hints, and better IntelliSense in your IDE
115115
- **Documentation**: Generated models serve as living documentation of your GraphQL API
116116
- **Validation**: Automatic validation of query responses against the expected schema

infrahub_sdk/ctl/graphql.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def find_gql_files(query_path: Path) -> list[Path]:
4747
List of Path objects for all .gql files found
4848
"""
4949
if not query_path.exists():
50-
raise FileNotFoundError(f"Directory not found: {query_path}")
50+
raise FileNotFoundError(f"File or directory not found: {query_path}")
5151

5252
if not query_path.is_dir() and query_path.is_file():
5353
return [query_path]
@@ -105,15 +105,19 @@ async def export_schema(
105105
"""Export the GraphQL schema to a file."""
106106

107107
client = initialize_client()
108-
response = await client._get(url=f"{client.address}/schema.graphql")
109-
destination.write_text(response.text)
108+
schema_text = await client.schema.get_graphql_schema()
109+
110+
destination.parent.mkdir(parents=True, exist_ok=True)
111+
destination.write_text(schema_text)
110112
console.print(f"[green]Schema exported to {destination}")
111113

112114

113115
@app.command()
114116
@catch_exception(console=console)
115117
async def generate_return_types(
116-
query: Optional[Path] = typer.Argument(None, help="Location of the GraphQL query file(s)."),
118+
query: Optional[Path] = typer.Argument(
119+
None, help="Location of the GraphQL query file(s). Defaults to current directory if not specified."
120+
),
117121
schema: Path = typer.Option("schema.graphql", help="Path to the GraphQL schema file."),
118122
_: str = CONFIG_PARAM,
119123
) -> None:
@@ -144,7 +148,7 @@ async def generate_return_types(
144148
try:
145149
definitions = get_graphql_query(queries_path=gql_file, schema=graphql_schema)
146150
except ValueError as exc:
147-
print(f"Error generating result types for {gql_file}: {exc}")
151+
console.print(f"[red]Error generating result types for {gql_file}: {exc}")
148152
continue
149153
queries = filter_operations_definitions(definitions)
150154
fragments = filter_fragments_definitions(definitions)
@@ -161,11 +165,17 @@ async def generate_return_types(
161165
plugin_manager=plugin_manager,
162166
)
163167

168+
parsing_failed = False
164169
try:
165170
for query_operation in queries:
166171
package_generator.add_operation(query_operation)
167172
except ParsingError as exc:
168173
console.print(f"[red]Unable to process {gql_file.name}: {exc}")
174+
parsing_failed = True
175+
176+
if parsing_failed:
177+
continue
178+
169179
module_fragment = package_generator.fragments_generator.generate()
170180

171181
generate_result_types(directory=directory, package=package_generator, fragment=module_fragment)

infrahub_sdk/graphql/plugin.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ class FutureAnnotationPlugin(Plugin):
1414
def insert_future_annotation(module: ast.Module) -> ast.Module:
1515
# First check if the future annotation is already present
1616
for item in module.body:
17-
if isinstance(item, ast.ImportFrom) and item.module == "__future__" and item.names[0].name == "annotations":
18-
return module
17+
if isinstance(item, ast.ImportFrom) and item.module == "__future__":
18+
if any(alias.name == "annotations" for alias in item.names):
19+
return module
1920

2021
module.body.insert(0, ast.ImportFrom(module="__future__", names=[ast.alias(name="annotations")], level=0))
2122
return module
@@ -68,13 +69,11 @@ def find_base_model_index(module: ast.Module) -> int:
6869
for idx, item in enumerate(module.body):
6970
if isinstance(item, ast.ImportFrom) and item.module == "base_model":
7071
return idx
71-
return -1
72+
raise ValueError("BaseModel not found in module")
7273

7374
@classmethod
7475
def replace_base_model_import(cls, module: ast.Module) -> ast.Module:
7576
base_model_index = cls.find_base_model_index(module)
76-
if base_model_index == -1:
77-
raise ValueError("BaseModel not found in module")
7877
module.body[base_model_index] = ast.ImportFrom(module="pydantic", names=[ast.alias(name="BaseModel")], level=0)
7978
return module
8079

infrahub_sdk/graphql/utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,22 @@ def get_class_def_index(module: ast.Module) -> int:
1111

1212

1313
def insert_fragments_inline(module: ast.Module, fragment: ast.Module) -> ast.Module:
14-
"""Insert the Pydantic classes for the fragments inline into the module."""
14+
"""Insert the Pydantic classes for the fragments inline into the module.
15+
16+
If no class definitions exist in module, fragments are appended to the end.
17+
"""
1518
module_class_def_index = get_class_def_index(module)
1619

1720
fragment_classes: list[ast.ClassDef] = [item for item in fragment.body if isinstance(item, ast.ClassDef)]
18-
for idx, item in enumerate(fragment_classes):
19-
module.body.insert(module_class_def_index + idx, item)
21+
22+
# Handle edge case when no class definitions exist
23+
if module_class_def_index == -1:
24+
# Append fragments to the end of the module
25+
module.body.extend(fragment_classes)
26+
else:
27+
# Insert fragments before the first class definition
28+
for idx, item in enumerate(fragment_classes):
29+
module.body.insert(module_class_def_index + idx, item)
2030

2131
return module
2232

infrahub_sdk/schema/__init__.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,25 @@ async def fetch(
474474

475475
return branch_schema.nodes
476476

477+
async def get_graphql_schema(self, branch: str | None = None) -> str:
478+
"""Get the GraphQL schema as a string.
479+
480+
Args:
481+
branch: The branch to get the schema for. Defaults to default_branch.
482+
483+
Returns:
484+
The GraphQL schema as a string.
485+
"""
486+
branch = branch or self.client.default_branch
487+
url = f"{self.client.address}/schema.graphql"
488+
489+
response = await self.client._get(url=url)
490+
491+
if response.status_code != 200:
492+
raise ValueError(f"Failed to fetch GraphQL schema: HTTP {response.status_code} - {response.text}")
493+
494+
return response.text
495+
477496
async def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema:
478497
url_parts = [("branch", branch)]
479498
if namespaces:
@@ -697,6 +716,25 @@ def fetch(
697716

698717
return branch_schema.nodes
699718

719+
def get_graphql_schema(self, branch: str | None = None) -> str:
720+
"""Get the GraphQL schema as a string.
721+
722+
Args:
723+
branch: The branch to get the schema for. Defaults to default_branch.
724+
725+
Returns:
726+
The GraphQL schema as a string.
727+
"""
728+
branch = branch or self.client.default_branch
729+
url = f"{self.client.address}/schema.graphql"
730+
731+
response = self.client._get(url=url)
732+
733+
if response.status_code != 200:
734+
raise ValueError(f"Failed to fetch GraphQL schema: HTTP {response.status_code} - {response.text}")
735+
736+
return response.text
737+
700738
def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema:
701739
url_parts = [("branch", branch)]
702740
if namespaces:

tests/unit/sdk/graphql/test_renderer.py

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,6 @@
1-
from enum import Enum
2-
3-
import pytest
4-
51
from infrahub_sdk.graphql.renderers import render_input_block, render_query_block
62

73

8-
class MyStrEnum(str, Enum):
9-
VALUE1 = "value1"
10-
VALUE2 = "value2"
11-
12-
13-
class MyIntEnum(int, Enum):
14-
VALUE1 = 12
15-
VALUE2 = 24
16-
17-
18-
@pytest.fixture
19-
def query_data_no_filter():
20-
data = {
21-
"device": {
22-
"name": {"value": None},
23-
"description": {"value": None},
24-
"interfaces": {"name": {"value": None}},
25-
}
26-
}
27-
28-
return data
29-
30-
31-
@pytest.fixture
32-
def query_data_alias():
33-
data = {
34-
"device": {
35-
"name": {"@alias": "new_name", "value": None},
36-
"description": {"value": {"@alias": "myvalue"}},
37-
"interfaces": {"@alias": "myinterfaces", "name": {"value": None}},
38-
}
39-
}
40-
41-
return data
42-
43-
44-
@pytest.fixture
45-
def query_data_fragment():
46-
data = {
47-
"device": {
48-
"name": {"value": None},
49-
"...on Builtin": {
50-
"description": {"value": None},
51-
"interfaces": {"name": {"value": None}},
52-
},
53-
}
54-
}
55-
56-
return data
57-
58-
59-
@pytest.fixture
60-
def query_data_empty_filter():
61-
data = {
62-
"device": {
63-
"@filters": {},
64-
"name": {"value": None},
65-
"description": {"value": None},
66-
"interfaces": {"name": {"value": None}},
67-
}
68-
}
69-
70-
return data
71-
72-
73-
@pytest.fixture
74-
def query_data_filters_01():
75-
data = {
76-
"device": {
77-
"@filters": {"name__value": "$name"},
78-
"name": {"value": None},
79-
"description": {"value": None},
80-
"interfaces": {
81-
"@filters": {"enabled__value": "$enabled"},
82-
"name": {"value": None},
83-
},
84-
}
85-
}
86-
return data
87-
88-
89-
@pytest.fixture
90-
def query_data_filters_02():
91-
data = {
92-
"device": {
93-
"@filters": {"name__value": "myname", "integer__value": 44, "enumstr__value": MyStrEnum.VALUE2},
94-
"name": {"value": None},
95-
"interfaces": {
96-
"@filters": {"enabled__value": True, "enumint__value": MyIntEnum.VALUE1},
97-
"name": {"value": None},
98-
},
99-
}
100-
}
101-
return data
102-
103-
1044
def test_render_query_block(query_data_no_filter) -> None:
1055
lines = render_query_block(data=query_data_no_filter)
1066

0 commit comments

Comments
 (0)