Skip to content

Commit b7f7404

Browse files
feat: user-specified validation rule skips (#340)
* user-specified validation rule skips This feature allows for the user to specify which validation rules they want to skip. Currently we only support 2 rules (the implicit no unused fragments rule and the UniqueFragmentNames rule). * standarize validation rules configuration to be spread to all possible not only to two --------- Co-authored-by: Damian Czajkowski <d0.czajkowski@gmail.com>
1 parent aa3b82d commit b7f7404

File tree

4 files changed

+169
-6
lines changed

4 files changed

+169
-6
lines changed

ariadne_codegen/main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
get_graphql_schema_from_path,
2020
get_graphql_schema_from_url,
2121
)
22-
from .settings import Strategy
22+
from .settings import Strategy, get_validation_rule
2323

2424

2525
@click.command()
@@ -66,7 +66,11 @@ def client(config_dict):
6666
fragments = []
6767
queries = []
6868
if settings.queries_path:
69-
definitions = get_graphql_queries(settings.queries_path, schema)
69+
definitions = get_graphql_queries(
70+
settings.queries_path,
71+
schema,
72+
[get_validation_rule(e) for e in settings.skip_validation_rules],
73+
)
7074
queries = filter_operations_definitions(definitions)
7175
fragments = filter_fragments_definitions(definitions)
7276

ariadne_codegen/schema.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Generator
1+
from collections.abc import Generator, Sequence
22
from dataclasses import asdict
33
from pathlib import Path
44
from typing import Optional, cast
@@ -23,6 +23,7 @@
2323
specified_rules,
2424
validate,
2525
)
26+
from typing_extensions import Any
2627

2728
from .client_generators.constants import MIXIN_FROM_NAME, MIXIN_IMPORT_NAME, MIXIN_NAME
2829
from .exceptions import (
@@ -48,15 +49,17 @@ def filter_fragments_definitions(
4849

4950

5051
def get_graphql_queries(
51-
queries_path: str, schema: GraphQLSchema
52+
queries_path: str,
53+
schema: GraphQLSchema,
54+
skip_rules: Sequence[Any] = (NoUnusedFragmentsRule,),
5255
) -> tuple[DefinitionNode, ...]:
5356
"""Get graphql queries definitions build from provided path."""
5457
queries_str = load_graphql_files_from_path(Path(queries_path))
5558
queries_ast = parse(queries_str)
5659
validation_errors = validate(
5760
schema=schema,
5861
document_ast=queries_ast,
59-
rules=[r for r in specified_rules if r is not NoUnusedFragmentsRule],
62+
rules=[r for r in specified_rules if r not in skip_rules],
6063
)
6164
if validation_errors:
6265
raise InvalidOperationForSchema(

ariadne_codegen/settings.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from pathlib import Path
77
from textwrap import dedent
88

9+
from graphql.validation import specified_rules
10+
911
from .client_generators.constants import (
1012
DEFAULT_ASYNC_BASE_CLIENT_NAME,
1113
DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_NAME,
@@ -26,6 +28,21 @@ class CommentsStrategy(str, enum.Enum):
2628
TIMESTAMP = "timestamp"
2729

2830

31+
VALIDATION_RULES_MAP = {
32+
rule.__name__.removesuffix("Rule"): rule for rule in specified_rules
33+
}
34+
35+
36+
def get_validation_rule(rule: str):
37+
try:
38+
return VALIDATION_RULES_MAP[rule]
39+
except KeyError as exc:
40+
supported_rules = ", ".join(sorted(VALIDATION_RULES_MAP))
41+
raise ValueError(
42+
f"Unknown validation rule: {rule}. Supported values are: {supported_rules}"
43+
) from exc
44+
45+
2946
class Strategy(str, enum.Enum):
3047
CLIENT = "client"
3148
GRAPHQL_SCHEMA = "graphqlschema"
@@ -125,6 +142,11 @@ class ClientSettings(BaseSettings):
125142
include_all_enums: bool = True
126143
async_client: bool = True
127144
opentelemetry_client: bool = False
145+
skip_validation_rules: list[str] = field(
146+
default_factory=lambda: [
147+
"NoUnusedFragments",
148+
]
149+
)
128150
files_to_include: list[str] = field(default_factory=list)
129151
scalars: dict[str, ScalarData] = field(default_factory=dict)
130152
default_optional_fields_to_none: bool = False

tests/test_schema.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
read_graphql_file,
1818
walk_graphql_files,
1919
)
20-
from ariadne_codegen.settings import IntrospectionSettings
20+
from ariadne_codegen.settings import IntrospectionSettings, get_validation_rule
2121

2222

2323
@pytest.fixture
@@ -67,6 +67,53 @@ def test_query_2_str():
6767
"""
6868

6969

70+
@pytest.fixture
71+
def test_fragment_str():
72+
return """
73+
fragment fragmentA on Custom {
74+
node
75+
}
76+
query testQuery2 {
77+
test {
78+
default
79+
...fragmentA
80+
}
81+
}
82+
"""
83+
84+
85+
@pytest.fixture
86+
def test_duplicate_fragment_str():
87+
return """
88+
fragment fragmentA on Custom {
89+
node
90+
}
91+
fragment fragmentA on Custom {
92+
node
93+
}
94+
query testQuery2 {
95+
test {
96+
default
97+
...fragmentA
98+
}
99+
}
100+
"""
101+
102+
103+
@pytest.fixture
104+
def test_unused_fragment_str():
105+
return """
106+
fragment fragmentA on Custom {
107+
node
108+
}
109+
query testQuery2 {
110+
test {
111+
default
112+
}
113+
}
114+
"""
115+
116+
70117
@pytest.fixture
71118
def single_file_schema(tmp_path_factory, schema_str):
72119
file_ = tmp_path_factory.mktemp("schema").joinpath("schema.graphql")
@@ -136,6 +183,37 @@ def single_file_query(tmp_path_factory, test_query_str):
136183
return file_
137184

138185

186+
@pytest.fixture
187+
def single_file_query_with_fragment(
188+
tmp_path_factory, test_query_str, test_fragment_str
189+
):
190+
file_ = tmp_path_factory.mktemp("queries").joinpath("query1_fragment.graphql")
191+
file_.write_text(test_query_str + test_fragment_str, encoding="utf-8")
192+
return file_
193+
194+
195+
@pytest.fixture
196+
def single_file_query_with_duplicate_fragment(
197+
tmp_path_factory, test_query_str, test_duplicate_fragment_str
198+
):
199+
file_ = tmp_path_factory.mktemp("queries").joinpath(
200+
"query1_duplicate_fragment.graphql"
201+
)
202+
file_.write_text(test_query_str + test_duplicate_fragment_str, encoding="utf-8")
203+
return file_
204+
205+
206+
@pytest.fixture
207+
def single_file_query_with_unused_fragment(
208+
tmp_path_factory, test_query_str, test_unused_fragment_str
209+
):
210+
file_ = tmp_path_factory.mktemp("queries").joinpath(
211+
"query1_unused_fragment.graphql"
212+
)
213+
file_.write_text(test_query_str + test_unused_fragment_str, encoding="utf-8")
214+
return file_
215+
216+
139217
@pytest.fixture
140218
def invalid_syntax_query_file(tmp_path_factory):
141219
file_ = tmp_path_factory.mktemp("queries").joinpath("query.graphql")
@@ -449,6 +527,62 @@ def test_get_graphql_queries_with_invalid_query_for_schema_raises_invalid_operat
449527
)
450528

451529

530+
def test_get_graphql_queries_with_fragment_returns_schema_definitions(
531+
single_file_query_with_fragment, schema_str
532+
):
533+
queries = get_graphql_queries(
534+
single_file_query_with_fragment.as_posix(), build_schema(schema_str)
535+
)
536+
537+
assert len(queries) == 3
538+
539+
540+
def test_get_graphql_queries_with_duplicate_fragment_raises_invalid_operation(
541+
single_file_query_with_duplicate_fragment, schema_str
542+
):
543+
with pytest.raises(InvalidOperationForSchema):
544+
get_graphql_queries(
545+
single_file_query_with_duplicate_fragment.as_posix(),
546+
build_schema(schema_str),
547+
)
548+
549+
550+
def test_unused_fragment_without_skips_raises_invalid_operation(
551+
single_file_query_with_unused_fragment,
552+
schema_str,
553+
):
554+
with pytest.raises(InvalidOperationForSchema):
555+
get_graphql_queries(
556+
single_file_query_with_unused_fragment.as_posix(),
557+
build_schema(schema_str),
558+
[],
559+
)
560+
561+
562+
def test_duplicate_fragment_passes_when_skip_rule_enabled(
563+
single_file_query_with_duplicate_fragment,
564+
schema_str,
565+
):
566+
get_graphql_queries(
567+
single_file_query_with_duplicate_fragment.as_posix(),
568+
build_schema(schema_str),
569+
[
570+
get_validation_rule("NoUnusedFragments"),
571+
get_validation_rule("UniqueFragmentNames"),
572+
],
573+
)
574+
575+
576+
def test_get_validation_rule_accepts_all_specified_rule_names():
577+
rule = get_validation_rule("NoUnusedVariables")
578+
assert rule.__name__ == "NoUnusedVariablesRule"
579+
580+
581+
def test_get_validation_rule_with_unknown_rule_raises_value_error():
582+
with pytest.raises(ValueError):
583+
get_validation_rule("UnknownRule")
584+
585+
452586
def test_introspect_remote_schema_passes_introspection_settings_to_introspection_query(
453587
mocker,
454588
):

0 commit comments

Comments
 (0)