Skip to content

Commit 2ebf233

Browse files
authored
Merge pull request #54 from LiUGraphQL/directives
Directives handling for #47, #48, #49, #50 (followed up in #61)
2 parents d39d381 + 4fc9707 commit 2ebf233

File tree

5 files changed

+502
-78
lines changed

5 files changed

+502
-78
lines changed

graphql-api-generator/generator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,20 @@ def cmd(args):
3636
with open(file, 'r') as f:
3737
schema_string += f.read() + '\n'
3838
schema = build_schema(schema_string)
39-
39+
4040
# run
4141
schema = run(schema, config)
4242

4343
# write to file or stdout
4444
if args.output:
4545
with open(args.output, 'w') as out:
46-
out.write(print_schema(schema))
46+
out.write(print_schema_with_directives(schema))
4747
else:
48-
print(print_schema(schema))
48+
print(print_schema_with_directives(schema))
4949

5050

5151
def run(schema: GraphQLSchema, config: dict):
52+
5253
# validate
5354
if config.get('validate'):
5455
validate_names(schema, config.get('validate'))
@@ -268,7 +269,8 @@ def datetime_control(schema):
268269
if not is_scalar_type(schema.type_map['DateTime']):
269270
raise Exception('DateTime exists but is not scalar type: ' + schema.type_map['DateTime'])
270271
else:
271-
schema.type_map['DateTime'] = GraphQLScalarType('DateTime')
272+
# ast_node definition ensures that DateTime appears as a user-defined scalar
273+
schema.type_map['DateTime'] = GraphQLScalarType('DateTime', ast_node=ScalarTypeDefinitionNode())
272274
if not is_scalar_type(schema.type_map['DateTime']):
273275
raise Exception('DateTime could not be added as scalar!')
274276

graphql-api-generator/utils/utils.py

Lines changed: 295 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,27 @@ def add_reverse_edges(schema: GraphQLSchema):
202202
# Reverse edge
203203
edge_from = get_named_type(field_type.type)
204204
edge_name = f'_{field_name}From{_type.name}'
205-
edge_to = GraphQLList(_type)
205+
206+
directives = {}
207+
directive_to_add = ''
208+
209+
if hasattr(field_type, 'ast_node') and field_type.ast_node is not None:
210+
directives = {directive.name.value: directive for directive in field_type.ast_node.directives}
211+
212+
if 'requiredForTarget' in directives:
213+
directive_to_add = '@required'
214+
215+
if 'uniqueForTarget' in directives:
216+
edge_to = _type
217+
else:
218+
edge_to = GraphQLList(_type)
206219

207220
if is_interface_type(edge_from):
208-
make += 'extend interface {0} {{ {1}: {2} }}\n'.format(edge_from, edge_name, edge_to)
221+
make += 'extend interface {0} {{ {1}: {2} {3} }}\n'.format(edge_from, edge_name, edge_to, directive_to_add)
209222
for implementing_type in schema.get_possible_types(edge_from):
210223
make += 'extend type {0} {{ {1}: {2} }}\n'.format(implementing_type, edge_name, edge_to)
211224
else:
212-
make += 'extend type {0} {{ {1}: {2} }}\n'.format(edge_from, edge_name, edge_to)
225+
make += 'extend type {0} {{ {1}: {2} {3} }}\n'.format(edge_from, edge_name, edge_to, directive_to_add)
213226
schema = add_to_schema(schema, make)
214227

215228
return schema
@@ -234,19 +247,21 @@ def add_input_to_create(schema: GraphQLSchema):
234247
for _type in schema.type_map.values():
235248
if not is_db_schema_defined_type(_type) or is_interface_type(_type):
236249
continue
237-
make += f'\nextend input _InputToCreate{_type.name} {{ '
250+
make += f'\nextend input _InputToCreate{_type.name} {{\n'
238251
for field_name, field in _type.fields.items():
239252
if field_name == 'id' or field_name[0] == '_':
240253
continue
254+
241255
inner_field_type = get_named_type(field.type)
256+
242257
if is_enum_or_scalar(inner_field_type):
243-
make += f'{field_name}: {field.type} '
258+
make += f' {field_name}: {field.type} \n'
244259
else:
245260
schema = extend_connect(schema, _type, inner_field_type, field_name)
246261
connect_name = f'_InputToConnect{capitalize(field_name)}Of{_type.name}'
247262
connect = copy_wrapper_structure(schema.type_map[connect_name], field.type)
248-
make += f' {field_name}: {connect} '
249-
make += '} '
263+
make += f' {field_name}: {connect} \n'
264+
make += '}\n'
250265
schema = add_to_schema(schema, make)
251266
return schema
252267

@@ -373,12 +388,12 @@ def add_input_update(schema: GraphQLSchema):
373388
inner_field_type = get_named_type(f_type)
374389

375390
if is_enum_or_scalar(inner_field_type):
376-
make += f'extend input {update_name} {{ {field_name}: {f_type} }} '
391+
make += f'extend input {update_name} {{ {field_name}: {f_type} }} \n'
377392
else:
378393
# add create or connect field
379394
connect_name = f'_InputToConnect{capitalize(field_name)}Of{_type.name}'
380395
connect = copy_wrapper_structure(schema.get_type(connect_name), f_type)
381-
make += f'extend input {update_name} {{ {field_name}: {connect} }} '
396+
make += f'extend input {update_name} {{ {field_name}: {connect} }} \n'
382397
schema = add_to_schema(schema, make)
383398
return schema
384399

@@ -751,3 +766,274 @@ def add_delete_mutations(schema: GraphQLSchema):
751766
make += f'extend type Mutation {{ {delete}(id: ID!): {_type.name} }} '
752767
schema = add_to_schema(schema, make)
753768
return schema
769+
770+
771+
def ast_type_to_string(_type: GraphQLType):
772+
"""
773+
Print the ast_type properly
774+
:param _type:
775+
:return:
776+
"""
777+
778+
# ast_nodes types behavies differently than other types (as they are NodeTypes)
779+
# So we can't use the normal functions
780+
781+
782+
_post_str = ''
783+
_pre_str = ''
784+
# A, A!, [A!], [A]!, [A!]!
785+
wrappers = []
786+
if isinstance(_type, NonNullTypeNode):
787+
_post_str = '!'
788+
_type = _type.type
789+
if isinstance(_type, ListTypeNode):
790+
_post_str = ']' + _post_str
791+
_pre_str = '['
792+
_type = _type.type
793+
if isinstance(_type, NonNullTypeNode):
794+
_post_str = '!' + _post_str
795+
_type = _type.type
796+
797+
# Dig down to find the actual named node, should be the first one actually
798+
name = _type
799+
while not isinstance(name, NamedTypeNode):
800+
name = name.type
801+
name = name.name.value
802+
803+
return _pre_str + name + _post_str
804+
805+
806+
def directive_from_interface(directive, interface_name):
807+
"""
808+
Return the correct directive string from directives inhertied from interfaces
809+
:param directive:
810+
:param interface_name:
811+
:return string:
812+
"""
813+
directive_string = directive.name.value
814+
815+
# The only two cases who needs special attention is @requiredForTarget and @uniqueForTarget
816+
if directive_string == 'requiredForTarget':
817+
directive_string = '_requiredForTarget_AccordingToInterface(interface: "' + interface_name + '")'
818+
elif directive_string == 'uniqueForTarget':
819+
directive_string = '_uniqueForTarget_AccordingToInterface(interface: "' + interface_name + '")'
820+
else:
821+
directive_string += get_directive_arguments(directive)
822+
823+
return directive_string
824+
825+
826+
def get_directive_arguments(directive):
827+
"""
828+
Get the arguments of the given directive as string
829+
:param directive:
830+
:return string:
831+
"""
832+
833+
output = ''
834+
if directive.arguments:
835+
output+= '('
836+
for arg in directive.arguments:
837+
output+= arg.name.value + ':'
838+
if isinstance(arg.value, ListValueNode):
839+
# List
840+
output+= '['
841+
for V in arg.value.values:
842+
if isinstance(V, StringValueNode):
843+
output+='"' + V.value + '", '
844+
else:
845+
output+= V.value + ', '
846+
847+
output = output[:-2] + ']'
848+
849+
else:
850+
# Non-list
851+
if isinstance(arg.value, StringValueNode):
852+
output+='"' + arg.value.value + '", '
853+
else:
854+
output+= arg.value.value + ', '
855+
856+
output += ', '
857+
858+
output = output[:-2] + ')'
859+
860+
return output
861+
862+
863+
def get_field_directives(field_name, _type, schema):
864+
"""
865+
Get the directives of given field, and return them as string
866+
:param field:
867+
:param field_name:
868+
:param _type:
869+
:param schema:
870+
:return string:
871+
"""
872+
873+
output = ''
874+
875+
# Used to make sure we don't add the same directive multiple times to the same field
876+
directives_set = set()
877+
878+
if is_input_type(_type):
879+
# Get the target type instead (unless it is a filter or delete input, then we dont care)
880+
# We also ignore @required directives for inputs
881+
if _type.name[:14] == '_InputToUpdate':
882+
directives_set.add('required')
883+
_type = schema.get_type(_type.name[14:])
884+
885+
elif _type.name[:14] == '_InputToCreate':
886+
_type = schema.get_type(_type.name[14:])
887+
directives_set.add('required')
888+
889+
else:
890+
return ''
891+
892+
# We got type without fields, just return empty
893+
if not hasattr(_type, 'fields'):
894+
return ''
895+
896+
# Get the field from the correct type
897+
field = _type.fields[field_name]
898+
899+
# Get all directives directly on field
900+
for directive in field.ast_node.directives:
901+
if not directive.name.value in directives_set:
902+
output+= ' @' + directive.name.value
903+
directives_set.add(directive.name.value)
904+
output += get_directive_arguments(directive)
905+
906+
907+
if hasattr(_type, 'interfaces'):
908+
# Get all inherited directives
909+
for interface in _type.interfaces:
910+
if field_name in interface.fields:
911+
for directive in interface.fields[field_name].ast_node.directives:
912+
directive_str = directive_from_interface(directive, interface.name)
913+
if not directive_str in directives_set:
914+
output+= ' @' + directive_str
915+
directives_set.add(directive_str)
916+
917+
return output
918+
919+
920+
def get_type_directives(_type, schema):
921+
"""
922+
Get the directives of given type, or target type if create- or update-input
923+
:param type:
924+
:return string:
925+
"""
926+
927+
output = ''
928+
929+
if is_input_type(_type):
930+
# Get the target type instead (unless it is a filter or delete input, then we dont care)
931+
if _type.name[:14] == '_InputToUpdate':
932+
_type = schema.get_type(_type.name[14:])
933+
934+
elif _type.name[:14] == '_InputToCreate':
935+
_type = schema.get_type(_type.name[14:])
936+
else:
937+
return ''
938+
939+
if hasattr(_type, 'ast_node') and _type.ast_node is not None:
940+
# Get directives on type
941+
for directive in _type.ast_node.directives:
942+
output+= ' @' + directive.name.value
943+
output += get_directive_arguments(directive)
944+
945+
return output
946+
947+
948+
def print_schema_with_directives(schema):
949+
"""
950+
Outputs the given schema as string, in the format we want it.
951+
Types and fields will all contain directives
952+
:param schema:
953+
:return string:
954+
"""
955+
manual_directives = {
956+
'required': 'directive @required on FIELD_DEFINITION',
957+
'key': 'directive @key(fields: [String!]!) on OBJECT | INPUT_OBJECT',
958+
'distinct': 'directive @distinct on FIELD_DEFINITION | INPUT_FIELD_DEFINITION',
959+
'noloops': 'directive @noloops on FIELD_DEFINITION | INPUT_FIELD_DEFINITION',
960+
'requiredForTarget': 'directive @requiredForTarget on FIELD_DEFINITION | INPUT_FIELD_DEFINITION',
961+
'uniqueForTarget': 'directive @uniqueForTarget on FIELD_DEFINITION | INPUT_FIELD_DEFINITION',
962+
'_requiredForTarget_AccordingToInterface': 'directive @_requiredForTarget_AccordingToInterface(interface: String!) on FIELD_DEFINITION | INPUT_FIELD_DEFINITION',
963+
'_uniqueForTarget_AccordingToInterface': 'directive @_uniqueForTarget_AccordingToInterface(interface: String!) on FIELD_DEFINITION | INPUT_FIELD_DEFINITION'
964+
}
965+
output = ''
966+
# Add directives
967+
for _dir in schema.directives:
968+
# Skip non-user defined directives
969+
if _dir.ast_node is None or _dir.name in manual_directives.keys():
970+
continue
971+
972+
output += f'directive @{_dir.name}'
973+
if _dir.ast_node.arguments:
974+
args = ', '.join([f'{arg.name.value}: {ast_type_to_string(arg.type)}' for arg in _dir.ast_node.arguments])
975+
output += f'({args})'
976+
977+
output += ' on ' + ' | '.join([loc.name for loc in _dir.locations])
978+
output += '\n\n'
979+
980+
# Manually handled directives
981+
for _dir in manual_directives.values():
982+
output += _dir + '\n\n'
983+
984+
# For each type, and output the types sorted by name
985+
for _type in sorted(schema.type_map.values(), key=lambda x: x.name):
986+
# Internal type
987+
if _type.name.startswith('__'):
988+
continue
989+
990+
if is_interface_type(_type):
991+
output += 'interface ' + _type.name
992+
elif is_enum_type(_type):
993+
output += 'enum ' + _type.name
994+
elif is_scalar_type(_type):
995+
# Skip non-user defined directives
996+
if _type.ast_node is not None:
997+
output += 'scalar ' + _type.name
998+
elif is_input_type(_type):
999+
output += 'input ' + _type.name
1000+
else:
1001+
output += 'type ' + _type.name
1002+
if hasattr(_type, 'interfaces') and _type.interfaces:
1003+
output += ' implements '
1004+
output += ' & '.join([interface.name for interface in _type.interfaces])
1005+
1006+
if is_enum_type(_type):
1007+
# For enums we can get the values directly and add them
1008+
output += ' {\n'
1009+
for value in _type.values:
1010+
output += ' ' + value + '\n'
1011+
output += '}'
1012+
1013+
elif not is_enum_or_scalar(_type):
1014+
# This should be a type, or an interface
1015+
# Get directives on type
1016+
output += get_type_directives(_type, schema)
1017+
output += ' {\n'
1018+
1019+
# Get fields
1020+
for field_name, field in _type.fields.items():
1021+
output += ' ' + field_name
1022+
1023+
# Get arguments for field
1024+
if hasattr(field, 'args') and field.args:
1025+
args = ', '.join([f'{arg_name}: {arg.type}' for arg_name, arg in field.args.items()])
1026+
output += f'({args})'
1027+
1028+
output += ': ' + str(field.type)
1029+
1030+
# Add directives
1031+
output += get_field_directives(field_name, _type, schema)
1032+
output += '\n'
1033+
1034+
output += '}'
1035+
1036+
if _type.ast_node is not None:
1037+
output += '\n\n'
1038+
1039+
return output

0 commit comments

Comments
 (0)