Skip to content

Commit 8c9661d

Browse files
authored
Merge pull request #166 from arunsureshkumar/support-point-field-in-args
Support point field in args
2 parents b23cbe3 + 1b4913f commit 8c9661d

File tree

4 files changed

+46
-26
lines changed

4 files changed

+46
-26
lines changed

graphene_mongo/advanced_types.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44

55
class FileFieldType(graphene.ObjectType):
6-
76
content_type = graphene.String()
87
md5 = graphene.String()
98
chunk_size = graphene.Int()
@@ -36,7 +35,6 @@ def resolve_data(self, info):
3635

3736

3837
class _CoordinatesTypeField(graphene.ObjectType):
39-
4038
type = graphene.String()
4139

4240
def resolve_type(self, info):
@@ -47,17 +45,19 @@ def resolve_coordinates(self, info):
4745

4846

4947
class PointFieldType(_CoordinatesTypeField):
50-
5148
coordinates = graphene.List(graphene.Float)
5249

5350

54-
class PolygonFieldType(_CoordinatesTypeField):
51+
class PointFieldInputType(graphene.InputObjectType):
52+
type = graphene.String(default_value="Point")
53+
coordinates = graphene.List(graphene.Float, required=True)
5554

55+
56+
class PolygonFieldType(_CoordinatesTypeField):
5657
coordinates = graphene.List(graphene.List(graphene.List(graphene.Float)))
5758

5859

5960
class MultiPolygonFieldType(_CoordinatesTypeField):
60-
6161
coordinates = graphene.List(
6262
graphene.List(graphene.List(graphene.List(graphene.Float)))
6363
)

graphene_mongo/converter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,22 +82,22 @@ def convert_field_to_jsonstring(field, registry=None):
8282

8383
@convert_mongoengine_field.register(mongoengine.PointField)
8484
def convert_point_to_field(field, registry=None):
85-
return graphene.Field(advanced_types.PointFieldType)
85+
return graphene.Field(advanced_types.PointFieldType, required=field.required)
8686

8787

8888
@convert_mongoengine_field.register(mongoengine.PolygonField)
8989
def convert_polygon_to_field(field, registry=None):
90-
return graphene.Field(advanced_types.PolygonFieldType)
90+
return graphene.Field(advanced_types.PolygonFieldType, required=field.required)
9191

9292

9393
@convert_mongoengine_field.register(mongoengine.MultiPolygonField)
9494
def convert_multipolygon_to_field(field, register=None):
95-
return graphene.Field(advanced_types.MultiPolygonFieldType)
95+
return graphene.Field(advanced_types.MultiPolygonFieldType, required=field.required)
9696

9797

9898
@convert_mongoengine_field.register(mongoengine.FileField)
9999
def convert_file_to_field(field, registry=None):
100-
return graphene.Field(advanced_types.FileFieldType)
100+
return graphene.Field(advanced_types.FileFieldType, required=field.required)
101101

102102

103103
@convert_mongoengine_field.register(mongoengine.ListField)
@@ -360,7 +360,7 @@ def dynamic_type():
360360
return None
361361
if isinstance(field, mongoengine.EmbeddedDocumentField):
362362
return graphene.Field(_type,
363-
description=get_field_description(field, registry))
363+
description=get_field_description(field, registry), required=field.required)
364364
field_resolver = None
365365
required = False
366366
if field.db_field is not None:

graphene_mongo/fields.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
FileFieldType,
2525
PointFieldType,
2626
MultiPolygonFieldType,
27-
PolygonFieldType,
27+
PolygonFieldType, PointFieldInputType,
2828
)
2929
from .converter import convert_mongoengine_field, MongoEngineConversionError
3030
from .registry import get_global_registry
@@ -79,7 +79,7 @@ def registry(self):
7979
def args(self):
8080
return to_arguments(
8181
self._base_args or OrderedDict(),
82-
dict(dict(self.field_args, **self.reference_args), **self.filter_args),
82+
dict(dict(self.field_args, **self.advance_args), **self.filter_args),
8383
)
8484

8585
@args.setter
@@ -149,35 +149,42 @@ def filter_args(self):
149149
if self._type._meta.filter_fields:
150150
for field, filter_collection in self._type._meta.filter_fields.items():
151151
for each in filter_collection:
152-
filter_type = getattr(
153-
graphene,
154-
str(self._type._meta.fields[field].type).replace("!", ""),
155-
)
156-
152+
if str(self._type._meta.fields[field].type) == 'PointFieldType':
153+
if each == 'max_distance':
154+
filter_type = graphene.Int
155+
else:
156+
filter_type = PointFieldInputType
157+
else:
158+
filter_type = getattr(
159+
graphene,
160+
str(self._type._meta.fields[field].type).replace("!", ""),
161+
)
157162
# handle special cases
158163
advanced_filter_types = {
159164
"in": graphene.List(filter_type),
160165
"nin": graphene.List(filter_type),
161166
"all": graphene.List(filter_type),
162167
}
163-
164168
filter_type = advanced_filter_types.get(each, filter_type)
165169
filter_args[field + "__" + each] = graphene.Argument(
166170
type=filter_type
167171
)
168-
169172
return filter_args
170173

171174
@property
172-
def reference_args(self):
173-
def get_reference_field(r, kv):
175+
def advance_args(self):
176+
def get_advance_field(r, kv):
174177
field = kv[1]
175178
mongo_field = getattr(self.model, kv[0], None)
179+
if isinstance(mongo_field, mongoengine.PointField):
180+
r.update({kv[0]: graphene.Argument(PointFieldInputType)})
181+
return r
176182
if isinstance(
177183
mongo_field,
178-
(mongoengine.LazyReferenceField, mongoengine.ReferenceField),
184+
(mongoengine.LazyReferenceField, mongoengine.ReferenceField, mongoengine.GenericReferenceField),
179185
):
180-
field = convert_mongoengine_field(mongo_field, self.registry)
186+
r.update({kv[0]: graphene.ID()})
187+
return r
181188
if isinstance(mongo_field, mongoengine.GenericReferenceField):
182189
r.update({kv[0]: graphene.ID()})
183190
return r
@@ -192,7 +199,7 @@ def get_reference_field(r, kv):
192199

193200
return r
194201

195-
return reduce(get_reference_field, self.fields.items(), {})
202+
return reduce(get_advance_field, self.fields.items(), {})
196203

197204
@property
198205
def fields(self):
@@ -220,6 +227,12 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
220227
reference_obj = get_document(arg["_cls"])(
221228
pk=arg["_ref"].id)
222229
hydrated_references[arg_name] = reference_obj
230+
elif '__near' in arg_name and isinstance(getattr(self.model, arg_name.split('__')[0]),
231+
mongoengine.fields.PointField):
232+
location = args.pop(arg_name, None)
233+
hydrated_references[arg_name] = location["coordinates"]
234+
if (arg_name.split('__')[0] + "__max_distance") not in args:
235+
hydrated_references[arg_name.split('__')[0] + "__max_distance"] = 10000
223236
elif arg_name == "id":
224237
hydrated_references["id"] = from_global_id(args.pop("id", None))[1]
225238
args.update(hydrated_references)
@@ -381,10 +394,17 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
381394
self.filter_args.keys()):
382395
args_copy.pop(arg_name)
383396
if arg_name == '_id' and isinstance(arg, dict):
384-
args_copy['pk__in'] = arg['$in']
397+
operation = list(arg.keys())[0]
398+
args_copy['pk' + operation.replace('$', '__')] = arg[operation]
385399
if '.' in arg_name:
386400
operation = list(arg.keys())[0]
387401
args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[operation]
402+
else:
403+
operations = ["$lte", "$gte", "$ne", "$in"]
404+
if isinstance(arg, dict) and any(op in arg for op in operations):
405+
operation = list(arg.keys())[0]
406+
args_copy[arg_name + operation.replace('$', '__')] = arg[operation]
407+
del args_copy[arg_name]
388408
return self.default_resolver(root, info, required_fields, **args_copy)
389409
else:
390410
return resolved

graphene_mongo/tests/test_fields.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def test_article_field_args():
99
assert set(field.field_args.keys()) == field_args
1010

1111
reference_args = {"editor", "reporter"}
12-
assert set(field.reference_args.keys()) == reference_args
12+
assert all(item in set(field.advance_args.keys()) for item in reference_args)
1313

1414
default_args = {"after", "last", "first", "before"}
1515
args = field_args | reference_args | default_args

0 commit comments

Comments
 (0)