Skip to content

Commit a49a698

Browse files
committed
Fix GraphQL interface MRO ordering for multiple inheritance
1 parent d1d6f4f commit a49a698

File tree

6 files changed

+255
-1
lines changed

6 files changed

+255
-1
lines changed

src/datamodel_code_generator/parser/graphql.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,50 @@ def parse_field(
421421
use_serialization_alias=self.use_serialization_alias,
422422
)
423423

424+
@staticmethod
425+
def _sort_interfaces_for_mro(
426+
interfaces: list[graphql.GraphQLInterfaceType],
427+
) -> list[graphql.GraphQLInterfaceType]:
428+
"""Sort interfaces so that subclasses come before their parent classes.
429+
430+
This ensures valid Python MRO (Method Resolution Order) when a class
431+
implements multiple interfaces where some interfaces extend others.
432+
433+
For example, if Notification implements Node, and a class implements
434+
both Node and Notification, the order should be [Notification, Node]
435+
not [Node, Notification].
436+
"""
437+
if len(interfaces) <= 1:
438+
return interfaces
439+
440+
# Build a set of all interface names for quick lookup
441+
interface_names = {i.name for i in interfaces}
442+
443+
# Get all ancestors for each interface (only considering interfaces in our list)
444+
def get_ancestors(iface: graphql.GraphQLInterfaceType) -> set[str]:
445+
"""Get all ancestor interface names that are in our interface list."""
446+
ancestors: set[str] = set()
447+
to_visit = list(getattr(iface, "interfaces", []))
448+
while to_visit:
449+
parent = to_visit.pop()
450+
if parent.name in interface_names and parent.name not in ancestors:
451+
ancestors.add(parent.name)
452+
to_visit.extend(getattr(parent, "interfaces", []))
453+
return ancestors
454+
455+
# Build ancestor map
456+
ancestor_map = {i.name: get_ancestors(i) for i in interfaces}
457+
458+
# Sort: interfaces with ancestors in the list should come before those ancestors
459+
# Use stable sort with custom key
460+
def sort_key(iface: graphql.GraphQLInterfaceType) -> tuple[int, str]:
461+
# Count how many other interfaces this one is an ancestor of
462+
# Interfaces that are ancestors of more others should come later
463+
ancestor_count = sum(1 for other in interfaces if iface.name in ancestor_map[other.name])
464+
return (ancestor_count, iface.name)
465+
466+
return sorted(interfaces, key=sort_key)
467+
424468
def parse_object_like(
425469
self,
426470
obj: graphql.GraphQLInterfaceType | graphql.GraphQLObjectType | graphql.GraphQLInputObjectType,
@@ -448,7 +492,8 @@ def parse_object_like(
448492

449493
base_classes = []
450494
if hasattr(obj, "interfaces"):
451-
base_classes = [self.references[i.name] for i in obj.interfaces] # ty: ignore
495+
sorted_interfaces = self._sort_interfaces_for_mro(list(obj.interfaces)) # ty: ignore
496+
base_classes = [self.references[i.name] for i in sorted_interfaces]
452497

453498
data_model_type = self._create_data_model(
454499
reference=self.references[obj.name],
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# generated by datamodel-codegen:
2+
# filename: interface_mro.graphql
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Annotated, Literal
8+
9+
from pydantic import BaseModel, Field
10+
from typing_extensions import TypeAliasType
11+
12+
Boolean = TypeAliasType("Boolean", bool)
13+
"""
14+
The `Boolean` scalar type represents `true` or `false`.
15+
"""
16+
17+
18+
ID = TypeAliasType("ID", str)
19+
"""
20+
The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID.
21+
"""
22+
23+
24+
String = TypeAliasType("String", str)
25+
"""
26+
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
27+
"""
28+
29+
30+
class Entity(BaseModel):
31+
id: ID
32+
typename__: Annotated[Literal['Entity'] | None, Field(alias='__typename')] = (
33+
'Entity'
34+
)
35+
36+
37+
class Node(BaseModel):
38+
id: ID
39+
typename__: Annotated[Literal['Node'] | None, Field(alias='__typename')] = 'Node'
40+
41+
42+
class Notification(Node):
43+
id: ID
44+
read_at: Annotated[String | None, Field(alias='readAt')] = None
45+
typename__: Annotated[Literal['Notification'] | None, Field(alias='__typename')] = (
46+
'Notification'
47+
)
48+
49+
50+
class CustomerNeedNotification(Entity, Notification, Node):
51+
customer: String | None = None
52+
id: ID
53+
read_at: Annotated[String | None, Field(alias='readAt')] = None
54+
typename__: Annotated[
55+
Literal['CustomerNeedNotification'] | None, Field(alias='__typename')
56+
] = 'CustomerNeedNotification'
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# generated by datamodel-codegen:
2+
# filename: union_class_name_prefix.graphql
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Annotated, Literal, Union
8+
9+
from pydantic import BaseModel, Field
10+
from typing_extensions import TypeAliasType
11+
12+
FooBoolean = TypeAliasType("FooBoolean", bool)
13+
"""
14+
The `Boolean` scalar type represents `true` or `false`.
15+
"""
16+
17+
18+
FooID = TypeAliasType("FooID", str)
19+
"""
20+
The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID.
21+
"""
22+
23+
24+
FooInt = TypeAliasType("FooInt", int)
25+
"""
26+
The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1.
27+
"""
28+
29+
30+
FooString = TypeAliasType("FooString", str)
31+
"""
32+
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
33+
"""
34+
35+
36+
class FooIResource(BaseModel):
37+
id: FooID
38+
typename__: Annotated[Literal['IResource'] | None, Field(alias='__typename')] = (
39+
'IResource'
40+
)
41+
42+
43+
class FooCar(FooIResource):
44+
id: FooID
45+
passenger_capacity: Annotated[FooInt, Field(alias='passengerCapacity')]
46+
typename__: Annotated[Literal['Car'] | None, Field(alias='__typename')] = 'Car'
47+
48+
49+
class FooEmployee(FooIResource):
50+
first_name: Annotated[FooString | None, Field(alias='firstName')] = None
51+
id: FooID
52+
last_name: Annotated[FooString | None, Field(alias='lastName')] = None
53+
typename__: Annotated[Literal['Employee'] | None, Field(alias='__typename')] = (
54+
'Employee'
55+
)
56+
57+
58+
FooResource = TypeAliasType(
59+
"FooResource",
60+
Union[
61+
'FooCar',
62+
'FooEmployee',
63+
],
64+
)
65+
66+
67+
FooTechnicalResource = TypeAliasType("FooTechnicalResource", FooCar)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
interface Entity {
2+
id: ID!
3+
}
4+
5+
interface Node {
6+
id: ID!
7+
}
8+
9+
interface Notification implements Node {
10+
id: ID!
11+
readAt: String
12+
}
13+
14+
type CustomerNeedNotification implements Entity & Node & Notification {
15+
id: ID!
16+
readAt: String
17+
customer: String
18+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
interface IResource {
2+
id: ID!
3+
}
4+
5+
type Employee implements IResource {
6+
id: ID!
7+
firstName: String
8+
lastName: String
9+
}
10+
11+
type Car implements IResource {
12+
id: ID!
13+
passengerCapacity: Int!
14+
}
15+
16+
union Resource = Employee | Car
17+
18+
union TechnicalResource = Car

tests/main/graphql/test_main_graphql.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,31 @@ def test_main_graphql_class_name_prefix(output_file: Path) -> None:
777777
)
778778

779779

780+
def test_main_graphql_union_class_name_prefix(output_file: Path) -> None:
781+
"""Test that union type members get class name prefix applied.
782+
783+
When using --class-name-prefix, the prefix should be applied to both
784+
the union type name and all its member type references.
785+
786+
This test verifies fix for issue #2939.
787+
"""
788+
run_main_and_assert(
789+
input_path=GRAPHQL_DATA_PATH / "union_class_name_prefix.graphql",
790+
output_path=output_file,
791+
input_file_type="graphql",
792+
assert_func=assert_file_content,
793+
expected_file="union_class_name_prefix.py",
794+
extra_args=[
795+
"--output-model-type",
796+
"pydantic_v2.BaseModel",
797+
"--class-name-prefix",
798+
"Foo",
799+
"--use-annotated",
800+
"--snake-case-field",
801+
],
802+
)
803+
804+
780805
def test_main_graphql_union_snake_case_field(output_file: Path) -> None:
781806
"""Test that union type references are not converted to snake_case."""
782807
run_main_and_assert(
@@ -789,6 +814,31 @@ def test_main_graphql_union_snake_case_field(output_file: Path) -> None:
789814
)
790815

791816

817+
def test_main_graphql_interface_mro(output_file: Path) -> None:
818+
"""Test that interface inheritance is ordered correctly for Python MRO.
819+
820+
When a class implements multiple interfaces where some interfaces extend others,
821+
the base classes must be ordered so that subclasses come before their parent classes.
822+
For example, if Notification implements Node, and a class implements both
823+
Node and Notification, the order should be (Notification, Node) not (Node, Notification).
824+
825+
This test verifies fix for issue #2938.
826+
"""
827+
run_main_and_assert(
828+
input_path=GRAPHQL_DATA_PATH / "interface_mro.graphql",
829+
output_path=output_file,
830+
input_file_type="graphql",
831+
assert_func=assert_file_content,
832+
expected_file="interface_mro.py",
833+
extra_args=[
834+
"--output-model-type",
835+
"pydantic_v2.BaseModel",
836+
"--use-annotated",
837+
"--snake-case-field",
838+
],
839+
)
840+
841+
792842
def test_main_graphql_split_graphql_schemas(output_file: Path) -> None:
793843
"""Test GraphQL code generation with multiple schema files in a directory."""
794844
run_main_and_assert(

0 commit comments

Comments
 (0)