Skip to content

Commit b6c88ee

Browse files
committed
fix coverage and add overloads
1 parent d11d00c commit b6c88ee

File tree

2 files changed

+60
-12
lines changed

2 files changed

+60
-12
lines changed

gql/dsl.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,19 @@
77
import re
88
from abc import ABC, abstractmethod
99
from math import isfinite
10-
from typing import Any, Dict, Iterable, Mapping, Optional, Set, Tuple, Union, cast
10+
from typing import (
11+
Any,
12+
Dict,
13+
Iterable,
14+
Literal,
15+
Mapping,
16+
Optional,
17+
Set,
18+
Tuple,
19+
Union,
20+
cast,
21+
overload,
22+
)
1123

1224
from graphql import (
1325
ArgumentNode,
@@ -297,28 +309,48 @@ def __init__(self, schema: GraphQLSchema):
297309

298310
self._schema: GraphQLSchema = schema
299311

300-
def __call__(self, name: str) -> "DSLDirective":
312+
@overload
313+
def __call__(
314+
self, shortcut: Literal["__typename", "__schema", "__type"]
315+
) -> "DSLMetaField": ... # pragma: no cover
316+
317+
@overload
318+
def __call__(
319+
self, shortcut: Literal["..."]
320+
) -> "DSLInlineFragment": ... # pragma: no cover
321+
322+
@overload
323+
def __call__(
324+
self, shortcut: Literal["fragment"], name: str
325+
) -> "DSLFragment": ... # pragma: no cover
326+
327+
@overload
328+
def __call__(self, shortcut: Any) -> "DSLDirective": ... # pragma: no cover
329+
330+
def __call__(
331+
self, shortcut: str, name: Optional[str] = None
332+
) -> Union["DSLMetaField", "DSLInlineFragment", "DSLFragment", "DSLDirective"]:
301333
"""Factory method for creating DSL objects.
302334
303335
Currently, supports creating DSLDirective instances when name starts with '@'.
304336
Future support planned for meta-fields (__typename), inline fragments (...),
305337
and fragment definitions (fragment).
306338
307-
:param name: the name of the object to create
308-
:type name: str
339+
:param shortcut: the name of the object to create
340+
:type shortcut: LiteralString
309341
310342
:return: DSLDirective instance
311343
312-
:raises ValueError: if name format is not supported
344+
:raises ValueError: if shortcut format is not supported
313345
"""
314-
if name.startswith("@"):
315-
return DSLDirective(name=name[1:], dsl_schema=self)
346+
if shortcut.startswith("@"):
347+
return DSLDirective(name=shortcut[1:], dsl_schema=self)
316348
# Future support:
317349
# if name.startswith("__"): return DSLMetaField(name)
318350
# if name == "...": return DSLInlineFragment()
319351
# if name.startswith("fragment "): return DSLFragment(name[9:])
320352

321-
raise ValueError(f"Unsupported name: {name}")
353+
raise ValueError(f"Unsupported shortcut: {shortcut}")
322354

323355
def __getattr__(self, name: str) -> "DSLType":
324356

@@ -549,7 +581,7 @@ def is_valid_directive(self, directive: "DSLDirective") -> bool:
549581
"""
550582
raise NotImplementedError(
551583
"Any DSLDirectable concrete class must have an is_valid_directive method"
552-
)
584+
) # pragma: no cover
553585

554586
def directives(self, *directives: DSLDirective) -> Any:
555587
r"""Add directives to this DSL element.

tests/starwars/test_dsl.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1297,6 +1297,11 @@ def test_legacy_fragment_with_variables(ds):
12971297
assert print_ast(query.document) == expected
12981298

12991299

1300+
def test_dsl_schema_call_validation(ds):
1301+
with pytest.raises(ValueError, match="(?i)unsupported shortcut"):
1302+
ds("foo")
1303+
1304+
13001305
def test_executable_directives(ds, var):
13011306
"""Test ALL executable directive locations and types in one document"""
13021307

@@ -1424,14 +1429,26 @@ def test_directive_error_handling(ds):
14241429
with pytest.raises(TypeError, match="Expected DSLDirective"):
14251430
ds.Query.hero.directives(123)
14261431

1427-
# Invalid directive name
1432+
# Invalid directive name from `__call__
14281433
with pytest.raises(GraphQLError, match="Directive '@nonexistent' not found"):
14291434
ds("@nonexistent")
14301435

14311436
# Invalid directive argument
14321437
with pytest.raises(GraphQLError, match="Argument 'invalid' does not exist"):
14331438
ds("@include")(invalid=True)
14341439

1440+
# Tried to set arguments twice
1441+
with pytest.raises(
1442+
AttributeError, match="Arguments for directive @field already set."
1443+
):
1444+
ds("@field").args(value="foo").args(value="bar")
1445+
1446+
with pytest.raises(
1447+
GraphQLError,
1448+
match="(?i)Directive '@deprecated' is not a valid request executable directive",
1449+
):
1450+
ds("@deprecated")
1451+
14351452
with pytest.raises(GraphQLError, match="unexpected variable"):
14361453
# variable definitions must be static, literal values defined in the query!
14371454
var = DSLVariableDefinitions()
@@ -1442,8 +1459,7 @@ def test_directive_error_handling(ds):
14421459
ds("@variableDefinition").args(value=var.nonStatic),
14431460
)
14441461
query.variable_definitions = var
1445-
invalid = print_ast(dsl_gql(query).document)
1446-
print(invalid)
1462+
_ = dsl_gql(query).document
14471463

14481464

14491465
# Parametrized tests for comprehensive directive location validation

0 commit comments

Comments
 (0)