Skip to content

Commit c098e0c

Browse files
authored
fix: QUDT handling (#156)
* fix(qudt): add optional directory flag to units sync and remove automatic units import logic Signed-off-by: Ahmed Mohamed <ahmed.mohamed@motius.de> * feat(schema): implement and integrate schema correctness validation Signed-off-by: Ahmed Mohamed <ahmed.mohamed@motius.de> * fix(tests): add units_directory fixture and update tests to use it Signed-off-by: Ahmed Mohamed <ahmed.mohamed@motius.de> * style: apply formatting fixes Signed-off-by: Ahmed Mohamed <ahmed.mohamed@motius.de> * feat(cli): add schema correctness assertion in compose function Signed-off-by: Ahmed Mohamed <ahmed.mohamed@motius.de> * test(cli): simplify units_directory fixture to return test data path Signed-off-by: Ahmed Mohamed <ahmed.mohamed@motius.de> * fix(schema): fix pre-commit issues Signed-off-by: Ahmed Mohamed <ahmed.mohamed@motius.de> * refactor(schema): assert schema correctness in cli.py Signed-off-by: Ahmed Mohamed <ahmed.mohamed@motius.de> --------- Signed-off-by: Ahmed Mohamed <ahmed.mohamed@motius.de>
1 parent 502f50e commit c098e0c

File tree

5 files changed

+729
-80
lines changed

5 files changed

+729
-80
lines changed

src/s2dm/cli.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import rich_click as click
99
import yaml
10-
from graphql import build_schema, parse
10+
from graphql import GraphQLSchema, build_schema, parse
1111
from rich.traceback import install
1212

1313
from s2dm import __version__, log
@@ -21,6 +21,7 @@
2121
from s2dm.exporters.utils.graphql_type import is_builtin_scalar_type, is_introspection_type
2222
from s2dm.exporters.utils.schema import load_schema_with_naming, search_schema
2323
from s2dm.exporters.utils.schema_loader import (
24+
check_correct_schema,
2425
create_tempfile_to_composed_schema,
2526
load_schema,
2627
load_schema_as_str,
@@ -52,15 +53,6 @@ def process_value(self, ctx: click.Context, value: Any) -> list[Path] | None:
5253
if not value:
5354
return None
5455
paths = set(value)
55-
56-
# Include the default QUDT units directory if it exists, otherwise warn and don't include it
57-
if DEFAULT_QUDT_UNITS_DIR.exists():
58-
paths.add(DEFAULT_QUDT_UNITS_DIR)
59-
else:
60-
log.warning(
61-
f"No QUDT units directory found at {DEFAULT_QUDT_UNITS_DIR}. Please run 's2dm units sync' first."
62-
)
63-
6456
return resolve_graphql_files(list(paths))
6557

6658

@@ -102,6 +94,7 @@ def selection_query_option(required: bool = False) -> Callable[[Callable[..., An
10294
help="Output file",
10395
)
10496

97+
10598
optional_output_option = click.option(
10699
"--output",
107100
"-o",
@@ -111,6 +104,16 @@ def selection_query_option(required: bool = False) -> Callable[[Callable[..., An
111104
)
112105

113106

107+
# Reusable ``--directory`` / ``-d`` option (default: DEFAULT_QUDT_UNITS_DIR) applied to the ``units`` subcommands.
units_directory_option = click.option(
108+
"--directory",
109+
"-d",
110+
type=click.Path(file_okay=False, path_type=Path),
111+
default=DEFAULT_QUDT_UNITS_DIR,
112+
help="Directory for QUDT unit enums",
113+
show_default=True,
114+
)
115+
116+
114117
expanded_instances_option = click.option(
115118
"--expanded-instances",
116119
"-e",
@@ -140,6 +143,16 @@ def multiline_str_representer(obj: Any) -> Any:
140143
return {k: multiline_str_representer(v) for k, v in result.items()}
141144

142145

146+
def assert_correct_schema(schema: GraphQLSchema) -> None:
    """Abort the CLI when the given schema fails the correctness checks.

    Runs ``check_correct_schema`` on the schema. If it reports any problems, each one
    is logged as an error, followed by a summary line, and the process exits with
    status 1. A schema without problems returns normally.

    Args:
        schema: The GraphQL schema to check before exporting.
    """
    problems = check_correct_schema(schema)
    if not problems:
        return

    log.error("Schema validation failed:")
    for problem in problems:
        log.error(problem)

    problem_count = len(problems)
    log.error(f"Found {problem_count} validation error(s). Please fix the schema before exporting.")
    sys.exit(1)
154+
155+
143156
def validate_naming_config(config: dict[str, Any]) -> None:
144157
VALID_CASES = {
145158
"camelCase",
@@ -315,39 +328,48 @@ def units() -> None:
315328
"QUDT version tag (e.g., 3.1.6). Defaults to the latest tag, falls back to 'main' when tags are unavailable."
316329
),
317330
)
331+
@units_directory_option
318332
@click.option(
319333
"--dry-run",
320334
is_flag=True,
321335
help="Show what would be generated without actually writing files",
322336
)
323-
def units_sync(version: str | None, dry_run: bool) -> None:
324-
"""Fetch QUDT quantity kinds and generate GraphQL enums under the output directory."""
337+
def units_sync(version: str | None, directory: Path, dry_run: bool) -> None:
338+
"""Fetch QUDT quantity kinds and generate GraphQL enums under the specified directory.
339+
340+
Args:
341+
version: QUDT version tag. Defaults to the latest tag.
342+
directory: Output directory for generated QUDT unit enums (default: ~/.s2dm/units/qudt)
343+
dry_run: Show what would be generated without actually writing files
344+
"""
325345

326346
version_to_use = version or get_latest_qudt_version()
327347

328348
try:
329-
written = sync_qudt_units(DEFAULT_QUDT_UNITS_DIR, version_to_use, dry_run=dry_run)
349+
written = sync_qudt_units(directory, version_to_use, dry_run=dry_run)
330350
except UnitEnumError as e:
331351
log.error(f"Units sync failed: {e}")
332352
sys.exit(1)
333353

334354
if dry_run:
335-
log.info(f"Would generate {len(written)} enum files under {DEFAULT_QUDT_UNITS_DIR}")
355+
log.info(f"Would generate {len(written)} enum files under {directory}")
336356
log.print(f"Version: {version_to_use}")
337357
log.hint("Use without --dry-run to actually write files")
338358
else:
339-
log.success(f"Generated {len(written)} enum files under {DEFAULT_QUDT_UNITS_DIR}")
359+
log.success(f"Generated {len(written)} enum files under {directory}")
340360
log.print(f"Version: {version_to_use}")
341361

342362

343363
@units.command(name="check-version")
344-
def units_check_version() -> None:
364+
@units_directory_option
365+
def units_check_version(directory: Path) -> None:
345366
"""Compare local synced QUDT version with the latest remote version and print a message.
367+
346368
Args:
347-
qudt_units_dir: Directory containing generated QUDT unit enums (default: ~/.s2dm/units/qudt)
369+
directory: Directory containing generated QUDT unit enums (default: ~/.s2dm/units/qudt)
348370
"""
349371

350-
meta_path = DEFAULT_QUDT_UNITS_DIR / UNITS_META_FILENAME
372+
meta_path = directory / UNITS_META_FILENAME
351373
if not meta_path.exists():
352374
log.warning("No metadata.json found. Run 's2dm units sync' first.")
353375
sys.exit(1)
@@ -398,6 +420,7 @@ def compose(schemas: list[Path], root_type: str | None, selection_query: Path |
398420
composed_schema_str = load_schema_as_str(schemas, add_references=True)
399421

400422
graphql_schema = build_schema(composed_schema_str)
423+
assert_correct_schema(graphql_schema)
401424

402425
if selection_query:
403426
query_document = parse(selection_query.read_text())
@@ -487,6 +510,7 @@ def shacl(
487510
naming_config = ctx.obj.get("naming_config")
488511

489512
graphql_schema = load_schema_with_naming(schemas, naming_config)
513+
assert_correct_schema(graphql_schema)
490514

491515
if selection_query:
492516
query_document = parse(selection_query.read_text())
@@ -515,6 +539,7 @@ def vspec(ctx: click.Context, schemas: list[Path], selection_query: Path | None,
515539
"""Generate VSPEC from a given GraphQL schema."""
516540
naming_config = ctx.obj.get("naming_config")
517541
graphql_schema = load_schema_with_naming(schemas, naming_config)
542+
assert_correct_schema(graphql_schema)
518543

519544
if selection_query:
520545
query_document = parse(selection_query.read_text())
@@ -553,6 +578,7 @@ def jsonschema(
553578
"""Generate JSON Schema from a given GraphQL schema."""
554579
naming_config = ctx.obj.get("naming_config")
555580
graphql_schema = load_schema_with_naming(schemas, naming_config)
581+
assert_correct_schema(graphql_schema)
556582

557583
if selection_query:
558584
query_document = parse(selection_query.read_text())
@@ -597,6 +623,7 @@ def protobuf(
597623
"""Generate Protocol Buffers (.proto) file from GraphQL schema."""
598624
naming_config = ctx.obj.get("naming_config")
599625
graphql_schema = load_schema_with_naming(schemas, naming_config)
626+
assert_correct_schema(graphql_schema)
600627

601628
query_document = parse(selection_query.read_text())
602629
graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document)

src/s2dm/exporters/utils/schema_loader.py

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ariadne import load_schema_from_path
77
from graphql import (
88
DocumentNode,
9+
GraphQLEnumType,
910
GraphQLField,
1011
GraphQLInputObjectType,
1112
GraphQLInterfaceType,
@@ -17,6 +18,7 @@
1718
GraphQLString,
1819
GraphQLType,
1920
GraphQLUnionType,
21+
Undefined,
2022
build_schema,
2123
get_named_type,
2224
is_input_object_type,
@@ -26,9 +28,10 @@
2628
is_object_type,
2729
is_union_type,
2830
print_schema,
31+
validate_schema,
2932
)
3033
from graphql import validate as graphql_validate
31-
from graphql.language.ast import SelectionSetNode
34+
from graphql.language.ast import DirectiveNode, EnumValueNode, SelectionSetNode
3235

3336
from s2dm import log
3437
from s2dm.exporters.utils.directive import (
@@ -268,6 +271,150 @@ def create_tempfile_to_composed_schema(graphql_schema_paths: list[Path]) -> Path
268271
return Path(temp_path)
269272

270273

274+
def _check_directive_usage_on_node(schema: GraphQLSchema, directive_node: DirectiveNode, context: str) -> list[str]:
    """Check enum-valued arguments of a single directive usage against their enum type.

    Args:
        schema: Schema in which the directive definition is looked up.
        directive_node: AST node of the directive usage to inspect.
        context: Human-readable description of where the directive is applied; it is
            used as the prefix of every error message.

    Returns:
        One error message per argument whose enum literal is not a member of the
        argument's enum type. The list is empty when the directive is not defined in
        the schema or when no inspected argument is invalid.
    """
    errors: list[str] = []

    directive_name = directive_node.name.value
    directive_def = schema.get_directive(directive_name)
    if not directive_def:
        return errors

    for arg_node in directive_node.arguments or ():
        arg_name = arg_node.name.value

        # An argument the directive does not declare is not an enum-value problem.
        # Skip it (as unknown directives are skipped above) instead of raising KeyError.
        arg_def = directive_def.args.get(arg_name)
        if arg_def is None:
            continue

        named_type = get_named_type(arg_def.type)
        if not isinstance(named_type, GraphQLEnumType):
            continue

        # Only bare enum literals are inspected here; other literal kinds are ignored.
        if not isinstance(arg_node.value, EnumValueNode):
            continue

        enum_value = arg_node.value.value
        if enum_value not in named_type.values:
            errors.append(
                f"{context} uses directive '@{directive_name}({arg_name})' "
                f"with invalid enum value '{enum_value}'. Valid values are: {list(named_type.values)}"
            )

    return errors
301+
302+
303+
def _render_default_literal(default_node: object) -> str:
    """Return a printable form of a default-value AST node for use in error messages."""
    scalar = getattr(default_node, "value", None)
    if scalar is not None:
        return str(scalar)

    # Literals such as lists (``[A, B]``) or ``null`` expose no ``value`` attribute,
    # so fall back to printing the node as GraphQL source text.
    from graphql import print_ast
    from graphql.language.ast import Node

    return print_ast(default_node) if isinstance(default_node, Node) else repr(default_node)


def _enum_default_error(subject: str, type_: GraphQLType, ast_node: object, coerced_default: object) -> str | None:
    """Return an error message when an enum-typed default could not be coerced, else None.

    A default is reported when the element's named type is an enum, the SDL declares a
    default value, and the coerced default on the schema element is ``Undefined``.
    """
    named_type = get_named_type(type_)
    if not isinstance(named_type, GraphQLEnumType):
        return None

    default_node = getattr(ast_node, "default_value", None)
    if default_node is None or coerced_default is not Undefined:
        return None

    return (
        f"{subject} has invalid enum default value '{_render_default_literal(default_node)}'. "
        f"Valid values are: {list(named_type.values)}"
    )


def check_enum_defaults(schema: GraphQLSchema) -> list[str]:
    """Check that enum defaults and enum directive arguments use defined enum values.

    The following places are inspected:

    * directive usages on types and on fields (via ``_check_directive_usage_on_node``),
    * default values of input object fields,
    * default values of field arguments on object and interface types,
    * default values of directive definition arguments.

    Args:
        schema: The GraphQL schema to validate

    Returns:
        List of error messages, one per invalid enum value found; empty if none.
    """
    errors: list[str] = []

    for type_name, type_obj in schema.type_map.items():
        # Validate directive usage on types
        type_node = type_obj.ast_node
        if type_node and type_node.directives:
            for directive_node in type_node.directives:
                errors.extend(_check_directive_usage_on_node(schema, directive_node, f"Type '{type_name}'"))

        # Validate input object field defaults
        if isinstance(type_obj, GraphQLInputObjectType):
            for field_name, field in type_obj.fields.items():
                problem = _enum_default_error(
                    f"Input type '{type_name}.{field_name}'", field.type, field.ast_node, field.default_value
                )
                if problem:
                    errors.append(problem)

        if not isinstance(type_obj, GraphQLObjectType | GraphQLInterfaceType | GraphQLInputObjectType):
            continue

        for field_name, field in type_obj.fields.items():
            # Validate directive usage on fields
            field_node = field.ast_node
            if field_node and field_node.directives:
                for directive_node in field_node.directives:
                    errors.extend(
                        _check_directive_usage_on_node(schema, directive_node, f"Field '{type_name}.{field_name}'")
                    )

            # Only object and interface fields are checked for argument defaults.
            if isinstance(type_obj, GraphQLInputObjectType):
                continue

            # Validate field argument defaults
            for arg_name, arg in field.args.items():
                problem = _enum_default_error(
                    f"Field argument '{type_name}.{field_name}({arg_name})'", arg.type, arg.ast_node, arg.default_value
                )
                if problem:
                    errors.append(problem)

    # Validate directive definition defaults
    for directive in schema.directives:
        for arg_name, arg in directive.args.items():
            problem = _enum_default_error(
                f"Directive definition '@{directive.name}({arg_name})'", arg.type, arg.ast_node, arg.default_value
            )
            if problem:
                errors.append(problem)

    return errors
388+
389+
390+
def check_correct_schema(schema: GraphQLSchema) -> list[str]:
    """Collect GraphQL specification errors and invalid enum values for a schema.

    This function only gathers and returns messages; it does not log them or
    terminate the process. How to react to a non-empty result is up to the caller.

    Args:
        schema: The GraphQL schema to validate

    Returns:
        list[str]: One " - "-prefixed message per problem, with the errors from
        ``validate_schema`` first and the errors from ``check_enum_defaults`` after
        them. The list is empty when neither check reports a problem.
    """
    spec_messages = [f" - {spec_error.message}" for spec_error in validate_schema(schema)]
    enum_messages = [f" - {enum_error}" for enum_error in check_enum_defaults(schema)]
    return spec_messages + enum_messages
416+
417+
271418
def ensure_query(schema: GraphQLSchema) -> GraphQLSchema:
272419
"""
273420
Ensures that the provided GraphQL schema has a Query type. If the schema does not have a Query type,
@@ -377,7 +524,7 @@ def visit_field_type(field_type: GraphQLType) -> None:
377524
return referenced
378525

379526

380-
def validate_schema(schema: GraphQLSchema, document: DocumentNode) -> GraphQLSchema | None:
527+
def _validate_schema(schema: GraphQLSchema, document: DocumentNode) -> GraphQLSchema | None:
381528
log.debug("Validating schema against the provided document")
382529

383530
errors = graphql_validate(schema, document)
@@ -406,7 +553,7 @@ def prune_schema_using_query_selection(schema: GraphQLSchema, document: Document
406553
if not schema.query_type:
407554
raise ValueError("Schema has no query type defined")
408555

409-
if validate_schema(schema, document) is None:
556+
if _validate_schema(schema, document) is None:
410557
raise ValueError("Schema validation failed")
411558

412559
fields_to_keep: dict[str, set[str]] = {}

tests/conftest.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,6 @@ class TestSchemaData:
4949
BREAKING_SCHEMA = TESTS_DATA_DIR / "breaking.graphql"
5050

5151

52-
@pytest.fixture(autouse=True)
53-
def patch_default_units_dir(monkeypatch: pytest.MonkeyPatch) -> None:
54-
"""Patch DEFAULT_QUDT_UNITS_DIR to use tests/data/units for all tests.
55-
56-
This prevents the "No QUDT units directory found" warning during tests
57-
and provides the necessary unit enum definitions that test schemas reference.
58-
59-
Tests that use the units_sync_mocks fixture will have this overridden with
60-
their own tmp_path directory for isolation.
61-
"""
62-
monkeypatch.setattr("s2dm.cli.DEFAULT_QUDT_UNITS_DIR", TestSchemaData.UNITS_SCHEMA_PATH)
63-
64-
6552
# Placeholder helper: takes no arguments and always returns an empty string.
def parsed_console_output() -> str:
6653
"""Parse console output (placeholder function)."""
6754
return ""

0 commit comments

Comments
 (0)