1
1
package com .introproventures .graphql .jpa .query .autoconfigure ;
2
2
3
+ import static graphql .Assert .assertTrue ;
4
+ import static graphql .schema .FieldCoordinates .coordinates ;
5
+ import static graphql .util .TraversalControl .CONTINUE ;
6
+
7
+ import java .util .Collection ;
3
8
import java .util .List ;
4
9
import java .util .Objects ;
10
+ import java .util .Optional ;
5
11
import java .util .stream .Collectors ;
6
12
import java .util .stream .Stream ;
7
13
8
14
import org .springframework .beans .factory .config .AbstractFactoryBean ;
9
15
16
+ import graphql .Internal ;
17
+ import graphql .schema .DataFetcher ;
18
+ import graphql .schema .FieldCoordinates ;
19
+ import graphql .schema .GraphQLCodeRegistry ;
10
20
import graphql .schema .GraphQLFieldDefinition ;
21
+ import graphql .schema .GraphQLFieldsContainer ;
22
+ import graphql .schema .GraphQLInterfaceType ;
11
23
import graphql .schema .GraphQLObjectType ;
12
24
import graphql .schema .GraphQLSchema ;
25
+ import graphql .schema .GraphQLType ;
26
+ import graphql .schema .GraphQLTypeVisitorStub ;
27
+ import graphql .schema .GraphQLUnionType ;
28
+ import graphql .schema .PropertyDataFetcher ;
29
+ import graphql .schema .TypeResolver ;
30
+ import graphql .schema .TypeTraverser ;
31
+ import graphql .util .TraversalControl ;
32
+ import graphql .util .TraverserContext ;
13
33
14
34
public class GraphQLSchemaFactoryBean extends AbstractFactoryBean <GraphQLSchema >{
15
-
16
- private static final String QUERY_NAME = "Query" ;
35
+
36
+ private static final String QUERY_NAME = "Query" ;
17
37
private static final String QUERY_DESCRIPTION = "" ;
18
38
private static final String SUBSCRIPTION_NAME = "Subscription" ;
19
39
private static final String SUBSCRIPTION_DESCRIPTION = "" ;
@@ -22,85 +42,118 @@ public class GraphQLSchemaFactoryBean extends AbstractFactoryBean<GraphQLSchema>
22
42
23
43
24
44
private final GraphQLSchema [] managedGraphQLSchemas ;
25
-
26
- private String queryName = QUERY_NAME ;
27
- private String queryDescription = QUERY_DESCRIPTION ;
45
+
46
+ private String queryName = QUERY_NAME ;
47
+ private String queryDescription = QUERY_DESCRIPTION ;
28
48
29
49
private String subscriptionName = SUBSCRIPTION_NAME ;
30
50
private String subscriptionDescription = SUBSCRIPTION_DESCRIPTION ;
31
51
32
52
private String mutationName = MUTATION_NAME ;
33
53
private String mutationDescription = MUTATION_DESCRIPTION ;
34
54
35
-
36
- public GraphQLSchemaFactoryBean (GraphQLSchema [] managedGraphQLSchemas ) {
37
- this .managedGraphQLSchemas = managedGraphQLSchemas ;
38
- }
39
-
40
- @ Override
41
- protected GraphQLSchema createInstance () throws Exception {
42
-
43
- GraphQLSchema .Builder schemaBuilder = GraphQLSchema .newSchema ();
44
-
45
- List <GraphQLFieldDefinition > mutations = Stream .of (managedGraphQLSchemas )
46
- .map (GraphQLSchema ::getMutationType )
47
- .filter (Objects ::nonNull )
48
- .map (GraphQLObjectType ::getFieldDefinitions )
49
- .flatMap (children -> children .stream ())
50
- .collect (Collectors .toList ());
51
-
52
- List <GraphQLFieldDefinition > queries = Stream .of (managedGraphQLSchemas )
53
- .map (GraphQLSchema ::getQueryType )
54
- .filter (Objects ::nonNull )
55
- .filter (it -> !it .getName ().equals ("null" )) // filter out null placeholders
56
- .map (GraphQLObjectType ::getFieldDefinitions )
57
- .flatMap (children -> children .stream ())
58
- .collect (Collectors .toList ());
59
-
60
- List <GraphQLFieldDefinition > subscriptions = Stream .of (managedGraphQLSchemas )
61
- .map (GraphQLSchema ::getSubscriptionType )
62
- .filter (Objects ::nonNull )
63
- .map (GraphQLObjectType ::getFieldDefinitions )
64
- .flatMap (children -> children .stream ())
65
- .collect (Collectors .toList ());
66
-
67
- if (!mutations .isEmpty ())
68
- schemaBuilder .mutation (GraphQLObjectType .newObject ()
55
+
56
+ public GraphQLSchemaFactoryBean (GraphQLSchema [] managedGraphQLSchemas ) {
57
+ this .managedGraphQLSchemas = managedGraphQLSchemas ;
58
+ }
59
+
60
+ @ Override
61
+ protected GraphQLSchema createInstance () throws Exception {
62
+
63
+ GraphQLSchema .Builder schemaBuilder = GraphQLSchema .newSchema ();
64
+ GraphQLCodeRegistry .Builder codeRegistryBuilder = GraphQLCodeRegistry .newCodeRegistry ();
65
+ TypeTraverser typeTraverser = new TypeTraverser ();
66
+
67
+ List <GraphQLFieldDefinition > mutations = Stream .of (managedGraphQLSchemas )
68
+ .filter (it -> it .getMutationType () != null )
69
+ .peek (schema -> {
70
+ schema .getCodeRegistry ().transform (builderConsumer -> {
71
+ typeTraverser .depthFirst (new CodeRegistryVisitor (builderConsumer ,
72
+ codeRegistryBuilder ,
73
+ schema .getMutationType (),
74
+ mutationName ),
75
+ schema .getMutationType ());
76
+ });
77
+ })
78
+ .map (GraphQLSchema ::getMutationType )
79
+ .filter (Objects ::nonNull )
80
+ .map (GraphQLObjectType ::getFieldDefinitions )
81
+ .flatMap (Collection ::stream )
82
+ .collect (Collectors .toList ());
83
+
84
+ List <GraphQLFieldDefinition > queries = Stream .of (managedGraphQLSchemas )
85
+ .filter (it -> Optional .ofNullable (it .getQueryType ())
86
+ .map (GraphQLType ::getName )
87
+ .filter (name -> !"null" .equals (name )) // filter out null placeholders
88
+ .isPresent ())
89
+ .peek (schema -> {
90
+ schema .getCodeRegistry ().transform (builderConsumer -> {
91
+ typeTraverser .depthFirst (new CodeRegistryVisitor (builderConsumer ,
92
+ codeRegistryBuilder ,
93
+ schema .getQueryType (),
94
+ queryName ),
95
+ schema .getQueryType ());
96
+ });
97
+ })
98
+ .map (GraphQLSchema ::getQueryType )
99
+ .map (GraphQLObjectType ::getFieldDefinitions )
100
+ .flatMap (Collection ::stream )
101
+ .collect (Collectors .toList ());
102
+
103
+ List <GraphQLFieldDefinition > subscriptions = Stream .of (managedGraphQLSchemas )
104
+ .filter (it -> it .getSubscriptionType () != null )
105
+ .peek (schema -> {
106
+ schema .getCodeRegistry ().transform (builderConsumer -> {
107
+ typeTraverser .depthFirst (new CodeRegistryVisitor (builderConsumer ,
108
+ codeRegistryBuilder ,
109
+ schema .getSubscriptionType (),
110
+ subscriptionName ),
111
+ schema .getSubscriptionType ());
112
+ });
113
+ })
114
+ .map (GraphQLSchema ::getSubscriptionType )
115
+ .map (GraphQLObjectType ::getFieldDefinitions )
116
+ .flatMap (Collection ::stream )
117
+ .collect (Collectors .toList ());
118
+
119
+ if (!mutations .isEmpty ())
120
+ schemaBuilder .mutation (GraphQLObjectType .newObject ()
69
121
.name (this .mutationName )
70
122
.description (this .mutationDescription )
71
- .fields (mutations ));
123
+ .fields (mutations ));
72
124
73
- if (!queries .isEmpty ())
74
- schemaBuilder .query (GraphQLObjectType .newObject ()
75
- .name (this .queryName )
76
- .description (this .queryDescription )
77
- .fields (queries ));
125
+ if (!queries .isEmpty ())
126
+ schemaBuilder .query (GraphQLObjectType .newObject ()
127
+ .name (this .queryName )
128
+ .description (this .queryDescription )
129
+ .fields (queries ));
78
130
79
- if (!subscriptions .isEmpty ())
80
- schemaBuilder .subscription (GraphQLObjectType .newObject ()
131
+ if (!subscriptions .isEmpty ())
132
+ schemaBuilder .subscription (GraphQLObjectType .newObject ()
81
133
.name (this .subscriptionName )
82
134
.description (this .subscriptionDescription )
83
- .fields (subscriptions ));
84
-
85
- return schemaBuilder .build ();
86
- }
87
-
88
- @ Override
89
- public Class <?> getObjectType () {
90
- return GraphQLSchema .class ;
91
- }
92
-
93
- public GraphQLSchemaFactoryBean setQueryName (String name ) {
94
- this .queryName = name ;
95
-
96
- return this ;
97
- }
98
-
99
- public GraphQLSchemaFactoryBean setQueryDescription (String description ) {
100
- this .queryDescription = description ;
101
-
102
- return this ;
103
- }
135
+ .fields (subscriptions ));
136
+
137
+ return schemaBuilder .codeRegistry (codeRegistryBuilder .build ())
138
+ .build ();
139
+ }
140
+
141
+ @ Override
142
+ public Class <?> getObjectType () {
143
+ return GraphQLSchema .class ;
144
+ }
145
+
146
+ public GraphQLSchemaFactoryBean setQueryName (String name ) {
147
+ this .queryName = name ;
148
+
149
+ return this ;
150
+ }
151
+
152
+ public GraphQLSchemaFactoryBean setQueryDescription (String description ) {
153
+ this .queryDescription = description ;
154
+
155
+ return this ;
156
+ }
104
157
105
158
public GraphQLSchemaFactoryBean setSubscriptionName (String subscriptionName ) {
106
159
this .subscriptionName = subscriptionName ;
@@ -125,5 +178,65 @@ public GraphQLSchemaFactoryBean setMutationDescription(String mutationDescriptio
125
178
126
179
return this ;
127
180
}
128
-
181
+
182
+ /**
183
+ * This ensure that all fields have data fetchers and that unions and interfaces have type resolvers
184
+ */
185
+ @ Internal
186
+ class CodeRegistryVisitor extends GraphQLTypeVisitorStub {
187
+ private final GraphQLCodeRegistry .Builder source ;
188
+ private final GraphQLCodeRegistry .Builder codeRegistry ;
189
+ private final GraphQLFieldsContainer containerType ;
190
+ private final String typeName ;
191
+
192
+ CodeRegistryVisitor (GraphQLCodeRegistry .Builder context ,
193
+ GraphQLCodeRegistry .Builder codeRegistry ,
194
+ GraphQLFieldsContainer containerType ,
195
+ String typeName ) {
196
+ this .source = context ;
197
+ this .codeRegistry = codeRegistry ;
198
+ this .containerType = containerType ;
199
+ this .typeName = typeName ;
200
+ }
201
+
202
+ @ Override
203
+ public TraversalControl visitGraphQLFieldDefinition (GraphQLFieldDefinition node , TraverserContext <GraphQLType > context ) {
204
+ GraphQLFieldsContainer parentContainerType = (GraphQLFieldsContainer ) context .getParentContext ().thisNode ();
205
+ FieldCoordinates coordinates = parentContainerType .equals (containerType ) ? coordinates (typeName , node .getName ())
206
+ : coordinates (parentContainerType , node );
207
+
208
+ DataFetcher <?> dataFetcher = source .getDataFetcher (parentContainerType ,
209
+ node );
210
+ if (dataFetcher == null ) {
211
+ dataFetcher = new PropertyDataFetcher <>(node .getName ());
212
+ }
213
+
214
+ codeRegistry .dataFetcherIfAbsent (coordinates ,
215
+ dataFetcher );
216
+ return CONTINUE ;
217
+ }
218
+
219
+ @ Override
220
+ public TraversalControl visitGraphQLInterfaceType (GraphQLInterfaceType node , TraverserContext <GraphQLType > context ) {
221
+ TypeResolver typeResolver = codeRegistry .getTypeResolver (node );
222
+
223
+ if (typeResolver != null ) {
224
+ codeRegistry .typeResolverIfAbsent (node , typeResolver );
225
+ }
226
+ assertTrue (codeRegistry .getTypeResolver (node ) != null , "You MUST provide a type resolver for the interface type '" + node .getName () + "'" );
227
+ return CONTINUE ;
228
+ }
229
+
230
+ @ Override
231
+ public TraversalControl visitGraphQLUnionType (GraphQLUnionType node , TraverserContext <GraphQLType > context ) {
232
+ TypeResolver typeResolver = codeRegistry .getTypeResolver (node );
233
+ if (typeResolver != null ) {
234
+ codeRegistry .typeResolverIfAbsent (node , typeResolver );
235
+ }
236
+ assertTrue (codeRegistry .getTypeResolver (node ) != null , "You MUST provide a type resolver for the union type '" + node .getName () + "'" );
237
+ return CONTINUE ;
238
+ }
239
+ }
240
+
241
+
129
242
}
0 commit comments