diff --git a/docs/docs/python-sdk/guides/python-typing.mdx b/docs/docs/python-sdk/guides/python-typing.mdx index d05b9394..9bc2c323 100644 --- a/docs/docs/python-sdk/guides/python-typing.mdx +++ b/docs/docs/python-sdk/guides/python-typing.mdx @@ -111,7 +111,7 @@ When working with GraphQL queries, you can generate type-safe Pydantic models th Generated Pydantic models from GraphQL queries offer several important benefits: - **Type Safety**: Catch type errors during development time instead of at runtime -- **IDE Support**: Get autocomplete, type hints, and better IntelliSense in your IDE +- **IDE Support**: Get autocomplete, type hints, and better IntelliSense in your IDE - **Documentation**: Generated models serve as living documentation of your GraphQL API - **Validation**: Automatic validation of query responses against the expected schema @@ -120,32 +120,59 @@ Generated Pydantic models from GraphQL queries offer several important benefits: Use the `infrahubctl graphql generate-return-types` command to create Pydantic models from your GraphQL queries: ```shell -# Generate models for queries in current directory -infrahubctl graphql generate-return-types +# Generate models for queries in a directory +infrahubctl graphql generate-return-types queries/ # Generate models for specific query files -infrahubctl graphql generate-return-types queries/get_devices.gql +infrahubctl graphql generate-return-types queries/get_tags.gql ``` -> You can also export the GraphQL schema first using the `infrahubctl graphql export-schema` command: +> You can also export the GraphQL schema first using the `infrahubctl graphql export-schema` command. ### Example workflow -1. **Create your GraphQL queries** in `.gql` files: +1. **Create your GraphQL queries** in `.gql` files preferably in a directory (e.g., `queries/`): + + ```graphql + # queries/get_tags.gql + query GetAllTags { + BuiltinTag { + edges { + node { + __typename + name { + value + } + } + } + } + } + ``` -2. **Generate the Pydantic models**: +2. **Export the GraphQL schema**: + + ```shell + infrahubctl graphql export-schema + ``` + +3. **Generate the Pydantic models**: ```shell infrahubctl graphql generate-return-types queries/ ``` - The command will generate the Python file per query based on the name of the query. + :::warning Query names + + Ensure each of your GraphQL queries has a unique name, as the generated Python files will be named based on these query names. + Two queries with the same name will land in the same file, leading to potential overrides. + + ::: -3. **Use the generated models** in your Python code +4. **Use the generated models** in your Python code ```python - from .queries.get_devices import GetDevicesQuery + from .queries.get_tags import GetAllTagsQuery response = await client.execute_graphql(query=MY_QUERY) - data = GetDevicesQuery(**response) + data = GetAllTagsQuery(**response) ``` diff --git a/infrahub_sdk/ctl/graphql.py b/infrahub_sdk/ctl/graphql.py index 6229d42a..ea0158ce 100644 --- a/infrahub_sdk/ctl/graphql.py +++ b/infrahub_sdk/ctl/graphql.py @@ -22,7 +22,12 @@ from ..async_typer import AsyncTyper from ..ctl.client import initialize_client from ..ctl.utils import catch_exception -from ..graphql.utils import insert_fragments_inline, remove_fragment_import +from ..graphql.utils import ( + insert_fragments_inline, + remove_fragment_import, + strip_typename_from_fragment, + strip_typename_from_operation, +) from .parameters import CONFIG_PARAM app = AsyncTyper() @@ -152,12 +157,18 @@ async def generate_return_types( queries = filter_operations_definitions(definitions) fragments = filter_fragments_definitions(definitions) + # Strip __typename fields from operations and fragments before code generation. + # __typename is a GraphQL introspection meta-field that isn't part of the schema's + # type definitions, causing ariadne-codegen to fail with "Redefinition of reserved type 'String'" + stripped_queries = [strip_typename_from_operation(q) for q in queries] + stripped_fragments = [strip_typename_from_fragment(f) for f in fragments] + package_generator = get_package_generator( schema=graphql_schema, - fragments=fragments, + fragments=stripped_fragments, settings=ClientSettings( schema_path=str(schema), - target_package_name=directory.name, + target_package_name=directory.name or "graphql_client", queries_path=str(directory), include_comments=CommentsStrategy.NONE, ), @@ -166,7 +177,7 @@ async def generate_return_types( parsing_failed = False try: - for query_operation in queries: + for query_operation in stripped_queries: package_generator.add_operation(query_operation) except ParsingError as exc: console.print(f"[red]Unable to process {gql_file.name}: {exc}") diff --git a/infrahub_sdk/graphql/utils.py b/infrahub_sdk/graphql/utils.py index 0756460d..39e4aa4e 100644 --- a/infrahub_sdk/graphql/utils.py +++ b/infrahub_sdk/graphql/utils.py @@ -1,5 +1,90 @@ import ast +from graphql import ( + FieldNode, + FragmentDefinitionNode, + FragmentSpreadNode, + InlineFragmentNode, + OperationDefinitionNode, + SelectionNode, + SelectionSetNode, +) + + +def strip_typename_from_selection_set(selection_set: SelectionSetNode | None) -> SelectionSetNode | None: + """Recursively strip __typename fields from a SelectionSetNode. + + The __typename meta-field is an introspection field that is not part of the schema's + type definitions. When code generation tools like ariadne-codegen try to look up + __typename in the schema, they fail because it's a reserved introspection field. + + This function removes all __typename fields from the selection set, allowing + code generation to proceed without errors. + """ + if selection_set is None: + return None + + new_selections: list[SelectionNode] = [] + for selection in selection_set.selections: + if isinstance(selection, FieldNode): + # Skip __typename fields + if selection.name.value == "__typename": + continue + # Recursively process nested selection sets + new_field = FieldNode( + alias=selection.alias, + name=selection.name, + arguments=selection.arguments, + directives=selection.directives, + selection_set=strip_typename_from_selection_set(selection.selection_set), + ) + new_selections.append(new_field) + elif isinstance(selection, InlineFragmentNode): + # Process inline fragments + new_inline = InlineFragmentNode( + type_condition=selection.type_condition, + directives=selection.directives, + selection_set=strip_typename_from_selection_set(selection.selection_set), + ) + new_selections.append(new_inline) + elif isinstance(selection, FragmentSpreadNode): + # FragmentSpread references a named fragment - keep as-is + new_selections.append(selection) + else: + raise TypeError(f"Unexpected GraphQL selection node type '{type(selection).__name__}'.") + + return SelectionSetNode(selections=tuple(new_selections)) + + +def strip_typename_from_operation(operation: OperationDefinitionNode) -> OperationDefinitionNode: + """Strip __typename fields from an operation definition. + + Returns a new OperationDefinitionNode with all __typename fields removed + from its selection set and any nested selection sets. + """ + return OperationDefinitionNode( + operation=operation.operation, + name=operation.name, + variable_definitions=operation.variable_definitions, + directives=operation.directives, + selection_set=strip_typename_from_selection_set(operation.selection_set), + ) + + +def strip_typename_from_fragment(fragment: FragmentDefinitionNode) -> FragmentDefinitionNode: + """Strip __typename fields from a fragment definition. + + Returns a new FragmentDefinitionNode with all __typename fields removed + from its selection set and any nested selection sets. + """ + return FragmentDefinitionNode( + name=fragment.name, + type_condition=fragment.type_condition, + variable_definitions=fragment.variable_definitions, + directives=fragment.directives, + selection_set=strip_typename_from_selection_set(fragment.selection_set), + ) + def get_class_def_index(module: ast.Module) -> int: """Get the index of the first class definition in the module. diff --git a/tests/fixtures/unit/test_infrahubctl/graphql/invalid_query.gql b/tests/fixtures/unit/test_infrahubctl/graphql/invalid_query.gql new file mode 100644 index 00000000..8e45a1d0 --- /dev/null +++ b/tests/fixtures/unit/test_infrahubctl/graphql/invalid_query.gql @@ -0,0 +1,9 @@ +query InvalidQuery { + NonExistentType { + edges { + node { + id + } + } + } +} diff --git a/tests/fixtures/unit/test_infrahubctl/graphql/query_with_typename.gql b/tests/fixtures/unit/test_infrahubctl/graphql/query_with_typename.gql new file mode 100644 index 00000000..43c6d7c3 --- /dev/null +++ b/tests/fixtures/unit/test_infrahubctl/graphql/query_with_typename.gql @@ -0,0 +1,16 @@ +query GetTagsWithTypename($name: String!) { + BuiltinTag(name__value: $name) { + __typename + edges { + __typename + node { + __typename + id + name { + __typename + value + } + } + } + } +} diff --git a/tests/fixtures/unit/test_infrahubctl/graphql/test_schema.graphql b/tests/fixtures/unit/test_infrahubctl/graphql/test_schema.graphql new file mode 100644 index 00000000..e2a631f0 --- /dev/null +++ b/tests/fixtures/unit/test_infrahubctl/graphql/test_schema.graphql @@ -0,0 +1,47 @@ +"""Attribute of type Text""" +type TextAttribute implements AttributeInterface { + is_default: Boolean + is_inherited: Boolean + is_protected: Boolean + is_visible: Boolean + updated_at: DateTime + id: String + value: String +} + +interface AttributeInterface { + is_default: Boolean + is_inherited: Boolean + is_protected: Boolean + is_visible: Boolean + updated_at: DateTime +} + +scalar DateTime + +type BuiltinTag implements CoreNode { + """Unique identifier""" + id: String! + display_label: String + """Description""" + description: TextAttribute + """Name (required)""" + name: TextAttribute +} + +interface CoreNode { + id: String! +} + +type EdgedBuiltinTag { + node: BuiltinTag +} + +type PaginatedBuiltinTag { + count: Int! + edges: [EdgedBuiltinTag!]! +} + +type Query { + BuiltinTag(name__value: String, ids: [ID]): PaginatedBuiltinTag +} diff --git a/tests/fixtures/unit/test_infrahubctl/graphql/valid_query.gql b/tests/fixtures/unit/test_infrahubctl/graphql/valid_query.gql new file mode 100644 index 00000000..574e06d5 --- /dev/null +++ b/tests/fixtures/unit/test_infrahubctl/graphql/valid_query.gql @@ -0,0 +1,15 @@ +query GetTags($name: String!) { + BuiltinTag(name__value: $name) { + edges { + node { + id + name { + value + } + description { + value + } + } + } + } +} diff --git a/tests/unit/ctl/test_graphql_app.py b/tests/unit/ctl/test_graphql_app.py new file mode 100644 index 00000000..07af1d20 --- /dev/null +++ b/tests/unit/ctl/test_graphql_app.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +import os +import shutil +from pathlib import Path + +import pytest +from ariadne_codegen.schema import get_graphql_schema_from_path +from typer.testing import CliRunner + +from infrahub_sdk.ctl.graphql import app, find_gql_files, get_graphql_query +from tests.helpers.cli import remove_ansi_color + +runner = CliRunner() + +FIXTURES_DIR = Path(__file__).parent.parent.parent / "fixtures" / "unit" / "test_infrahubctl" / "graphql" + + +class TestFindGqlFiles: + """Tests for find_gql_files helper function.""" + + def test_find_gql_files_single_file(self, tmp_path: Path) -> None: + """Test finding a single .gql file when path points to a file.""" + query_file = tmp_path / "query.gql" + query_file.write_text("query Test { field }") + + result = find_gql_files(query_file) + + assert len(result) == 1 + assert result[0] == query_file + + def test_find_gql_files_directory(self, tmp_path: Path) -> None: + """Test finding multiple .gql files in a directory.""" + (tmp_path / "query1.gql").write_text("query Test1 { field }") + (tmp_path / "query2.gql").write_text("query Test2 { field }") + (tmp_path / "not_a_query.txt").write_text("not a query") + + result = find_gql_files(tmp_path) + + assert len(result) == 2 + assert all(f.suffix == ".gql" for f in result) + + def test_find_gql_files_nested_directory(self, tmp_path: Path) -> None: + """Test finding .gql files in nested directories.""" + subdir = tmp_path / "subdir" + subdir.mkdir() + (tmp_path / "query1.gql").write_text("query Test1 { field }") + (subdir / "query2.gql").write_text("query Test2 { field }") + + result = find_gql_files(tmp_path) + + assert len(result) == 2 + + def test_find_gql_files_nonexistent_path(self, tmp_path: Path) -> None: + """Test that FileNotFoundError is raised for non-existent path.""" + nonexistent = tmp_path / "nonexistent" + + with pytest.raises(FileNotFoundError, match="File or directory not found"): + find_gql_files(nonexistent) + + def test_find_gql_files_empty_directory(self, tmp_path: Path) -> None: + """Test finding no .gql files in an empty directory.""" + result = find_gql_files(tmp_path) + + assert len(result) == 0 + + +class TestGetGraphqlQuery: + """Tests for get_graphql_query helper function.""" + + def test_get_graphql_query_valid(self) -> None: + """Test parsing a valid GraphQL query.""" + schema = get_graphql_schema_from_path(str(FIXTURES_DIR / "test_schema.graphql")) + query_file = FIXTURES_DIR / "valid_query.gql" + + definitions = get_graphql_query(query_file, schema) + + assert len(definitions) == 1 + assert definitions[0].name.value == "GetTags" + + def test_get_graphql_query_invalid(self) -> None: + """Test that invalid query raises ValueError.""" + schema = get_graphql_schema_from_path(str(FIXTURES_DIR / "test_schema.graphql")) + query_file = FIXTURES_DIR / "invalid_query.gql" + + with pytest.raises(ValueError, match="Cannot query field"): + get_graphql_query(query_file, schema) + + def test_get_graphql_query_nonexistent_file(self) -> None: + """Test that FileNotFoundError is raised for non-existent file.""" + schema = get_graphql_schema_from_path(str(FIXTURES_DIR / "test_schema.graphql")) + nonexistent = FIXTURES_DIR / "nonexistent.gql" + + with pytest.raises(FileNotFoundError, match="File not found"): + get_graphql_query(nonexistent, schema) + + def test_get_graphql_query_directory_instead_of_file(self) -> None: + """Test that ValueError is raised when path is a directory.""" + schema = get_graphql_schema_from_path(str(FIXTURES_DIR / "test_schema.graphql")) + + with pytest.raises(ValueError, match="is not a file"): + get_graphql_query(FIXTURES_DIR, schema) + + +class TestGenerateReturnTypesCommand: + """Tests for the generate-return-types CLI command.""" + + def test_generate_return_types_success(self, tmp_path: Path) -> None: + """Test successful generation of return types from a valid query.""" + # Copy fixtures to temp directory + schema_file = tmp_path / "schema.graphql" + query_file = tmp_path / "query.gql" + + shutil.copy(FIXTURES_DIR / "test_schema.graphql", schema_file) + shutil.copy(FIXTURES_DIR / "valid_query.gql", query_file) + + # Run the command + result = runner.invoke( + app, ["generate-return-types", str(query_file), "--schema", str(schema_file)], catch_exceptions=False + ) + + assert result.exit_code == 0 + clean_output = remove_ansi_color(result.stdout) + assert "Generated" in clean_output + + # Check that a file was generated + generated_files = list(tmp_path.glob("*.py")) + assert len(generated_files) >= 1 + + def test_generate_return_types_directory(self, tmp_path: Path) -> None: + """Test generation when providing a directory of queries.""" + # Copy fixtures to temp directory + schema_file = tmp_path / "schema.graphql" + query_dir = tmp_path / "queries" + query_dir.mkdir() + + shutil.copy(FIXTURES_DIR / "test_schema.graphql", schema_file) + shutil.copy(FIXTURES_DIR / "valid_query.gql", query_dir / "query.gql") + + # Run the command with directory + result = runner.invoke( + app, ["generate-return-types", str(query_dir), "--schema", str(schema_file)], catch_exceptions=False + ) + + assert result.exit_code == 0 + clean_output = remove_ansi_color(result.stdout) + assert "Generated" in clean_output + + def test_generate_return_types_missing_schema(self, tmp_path: Path) -> None: + """Test error when schema file is missing.""" + query_file = tmp_path / "query.gql" + shutil.copy(FIXTURES_DIR / "valid_query.gql", query_file) + + result = runner.invoke(app, ["generate-return-types", str(query_file), "--schema", "nonexistent.graphql"]) + + assert result.exit_code == 1 + clean_output = remove_ansi_color(result.stdout) + assert "not found" in clean_output.lower() + + def test_generate_return_types_invalid_query(self, tmp_path: Path) -> None: + """Test handling of invalid query (should print error and continue).""" + schema_file = tmp_path / "schema.graphql" + query_file = tmp_path / "query.gql" + + shutil.copy(FIXTURES_DIR / "test_schema.graphql", schema_file) + shutil.copy(FIXTURES_DIR / "invalid_query.gql", query_file) + + result = runner.invoke(app, ["generate-return-types", str(query_file), "--schema", str(schema_file)]) + + # Should exit successfully but print error message for invalid query + assert result.exit_code == 0 + clean_output = remove_ansi_color(result.stdout) + assert "Error" in clean_output + + def test_generate_return_types_with_typename(self, tmp_path: Path) -> None: + """Test that __typename fields are properly stripped during generation.""" + schema_file = tmp_path / "schema.graphql" + query_file = tmp_path / "query.gql" + + shutil.copy(FIXTURES_DIR / "test_schema.graphql", schema_file) + shutil.copy(FIXTURES_DIR / "query_with_typename.gql", query_file) + + result = runner.invoke( + app, ["generate-return-types", str(query_file), "--schema", str(schema_file)], catch_exceptions=False + ) + + assert result.exit_code == 0 + clean_output = remove_ansi_color(result.stdout) + assert "Generated" in clean_output + + def test_generate_return_types_default_cwd(self, tmp_path: Path) -> None: + """Test that command defaults to current directory when no query path provided.""" + # Copy fixtures to temp directory + schema_file = tmp_path / "schema.graphql" + query_file = tmp_path / "query.gql" + + shutil.copy(FIXTURES_DIR / "test_schema.graphql", schema_file) + shutil.copy(FIXTURES_DIR / "valid_query.gql", query_file) + + # Change to temp directory and run without specifying query path + original_dir = os.getcwd() + try: + os.chdir(tmp_path) + result = runner.invoke(app, ["generate-return-types", "--schema", str(schema_file)], catch_exceptions=False) + finally: + os.chdir(original_dir) + + assert result.exit_code == 0 + clean_output = remove_ansi_color(result.stdout) + assert "Generated" in clean_output + + def test_generate_return_types_no_gql_files(self, tmp_path: Path) -> None: + """Test when directory has no .gql files.""" + schema_file = tmp_path / "schema.graphql" + shutil.copy(FIXTURES_DIR / "test_schema.graphql", schema_file) + + empty_dir = tmp_path / "empty" + empty_dir.mkdir() + + result = runner.invoke(app, ["generate-return-types", str(empty_dir), "--schema", str(schema_file)]) + + # Should exit successfully with no output + assert result.exit_code == 0 + + def test_generate_return_types_multiple_queries_same_dir(self, tmp_path: Path) -> None: + """Test generation with multiple query files in the same directory.""" + schema_file = tmp_path / "schema.graphql" + shutil.copy(FIXTURES_DIR / "test_schema.graphql", schema_file) + + # Create multiple valid queries + query1 = tmp_path / "query1.gql" + query2 = tmp_path / "query2.gql" + + query1.write_text(""" +query GetAllTags { + BuiltinTag { + edges { + node { + id + name { value } + } + } + } +} +""") + query2.write_text(""" +query GetTagByName($name: String!) { + BuiltinTag(name__value: $name) { + edges { + node { + id + description { value } + } + } + } +} +""") + + result = runner.invoke( + app, ["generate-return-types", str(tmp_path), "--schema", str(schema_file)], catch_exceptions=False + ) + + assert result.exit_code == 0 + clean_output = remove_ansi_color(result.stdout) + # Should generate files for both queries + assert clean_output.count("Generated") >= 2 diff --git a/tests/unit/ctl/test_graphql_utils.py b/tests/unit/ctl/test_graphql_utils.py new file mode 100644 index 00000000..d67549aa --- /dev/null +++ b/tests/unit/ctl/test_graphql_utils.py @@ -0,0 +1,242 @@ +from graphql import parse, print_ast + +from infrahub_sdk.graphql.utils import ( + strip_typename_from_fragment, + strip_typename_from_operation, + strip_typename_from_selection_set, +) + + +class TestStripTypename: + def test_strip_typename_from_simple_query(self) -> None: + query = """ + query Test { + BuiltinTag { + __typename + name + } + } + """ + doc = parse(query) + operation = doc.definitions[0] + result = strip_typename_from_operation(operation) + + result_str = print_ast(result) + assert "__typename" not in result_str + assert "name" in result_str + assert "BuiltinTag" in result_str + + def test_strip_typename_from_nested_query(self) -> None: + query = """ + query Test { + BuiltinTag { + edges { + node { + __typename + name { + value + } + } + } + } + } + """ + doc = parse(query) + operation = doc.definitions[0] + result = strip_typename_from_operation(operation) + + result_str = print_ast(result) + assert "__typename" not in result_str + assert "name" in result_str + assert "value" in result_str + assert "edges" in result_str + assert "node" in result_str + + def test_strip_typename_from_inline_fragment(self) -> None: + query = """ + query Test { + BuiltinTag { + edges { + node { + __typename + ... on Tag { + __typename + name { + value + } + } + } + } + } + } + """ + doc = parse(query) + operation = doc.definitions[0] + result = strip_typename_from_operation(operation) + + result_str = print_ast(result) + assert "__typename" not in result_str + assert "... on Tag" in result_str + assert "name" in result_str + + def test_strip_typename_from_fragment_definition(self) -> None: + query = """ + fragment TagFields on Tag { + __typename + name { + value + } + description { + value + } + } + """ + doc = parse(query) + fragment = doc.definitions[0] + result = strip_typename_from_fragment(fragment) + + result_str = print_ast(result) + assert "__typename" not in result_str + assert "name" in result_str + assert "description" in result_str + assert "TagFields" in result_str + + def test_strip_typename_preserves_fragment_spread(self) -> None: + query = """ + query Test { + BuiltinTag { + ...TagFields + __typename + } + } + """ + doc = parse(query) + operation = doc.definitions[0] + result = strip_typename_from_operation(operation) + + result_str = print_ast(result) + assert "__typename" not in result_str + assert "...TagFields" in result_str + + def test_strip_typename_from_empty_selection_set(self) -> None: + result = strip_typename_from_selection_set(None) + assert result is None + + def test_strip_typename_multiple_occurrences(self) -> None: + query = """ + query Test { + __typename + BuiltinTag { + __typename + edges { + __typename + node { + __typename + name { + __typename + value + } + } + } + } + } + """ + doc = parse(query) + operation = doc.definitions[0] + result = strip_typename_from_operation(operation) + + result_str = print_ast(result) + assert "__typename" not in result_str + # Should still have the actual fields + assert "BuiltinTag" in result_str + assert "edges" in result_str + assert "node" in result_str + assert "name" in result_str + assert "value" in result_str + + def test_strip_typename_preserves_aliases(self) -> None: + query = """ + query Test { + tags: BuiltinTag { + __typename + tagName: name { + value + } + } + } + """ + doc = parse(query) + operation = doc.definitions[0] + result = strip_typename_from_operation(operation) + + result_str = print_ast(result) + assert "__typename" not in result_str + assert "tags: BuiltinTag" in result_str + assert "tagName: name" in result_str + + def test_strip_typename_preserves_arguments(self) -> None: + query = """ + query Test { + BuiltinTag(first: 10, name__value: "test") { + __typename + edges { + node { + name { + value + } + } + } + } + } + """ + doc = parse(query) + operation = doc.definitions[0] + result = strip_typename_from_operation(operation) + + result_str = print_ast(result) + assert "__typename" not in result_str + assert "first: 10" in result_str + assert 'name__value: "test"' in result_str + + def test_strip_typename_preserves_directives(self) -> None: + query = """ + query Test { + BuiltinTag @include(if: true) { + __typename + name { + value + } + } + } + """ + doc = parse(query) + operation = doc.definitions[0] + result = strip_typename_from_operation(operation) + + result_str = print_ast(result) + assert "__typename" not in result_str + assert "@include(if: true)" in result_str + + def test_query_without_typename_unchanged(self) -> None: + query = """ + query Test { + BuiltinTag { + edges { + node { + name { + value + } + } + } + } + } + """ + doc = parse(query) + operation = doc.definitions[0] + result = strip_typename_from_operation(operation) + + # The structure should be effectively the same (modulo formatting) + original_str = print_ast(operation) + result_str = print_ast(result) + # Normalize whitespace for comparison + assert original_str.split() == result_str.split()