1
1
import inspect
2
- import json
3
2
import logging
4
3
import warnings
5
4
from collections import OrderedDict
@@ -419,6 +418,28 @@ def construct_fields_and_filters(
419
418
return fields , filters
420
419
421
420
421
+ class SQLAlchemyPrimaryKeySerializer (object ):
422
+ """
423
+ Serializes/unserializes primary keys
424
+ """
425
+
426
+ DEFAULT = None
427
+
428
+ def __init__ (self , serialize , deserialize ):
429
+ self .serialize = serialize
430
+ self .deserialize = deserialize
431
+
432
+ @classmethod
433
+ def default (cls ):
434
+ if cls .DEFAULT is None :
435
+ cls .DEFAULT = cls (
436
+ serialize = lambda keys : str (tuple (keys )) if len (keys ) > 1 else keys [0 ],
437
+ deserialize = lambda id : id ,
438
+ )
439
+
440
+ return cls .DEFAULT
441
+
442
+
422
443
class SQLAlchemyBase (BaseType ):
423
444
"""
424
445
This class contains initialization code that is common to both ObjectTypes
@@ -442,6 +463,7 @@ def __init_subclass_with_meta__(
442
463
connection_field_factory = None ,
443
464
_meta = None ,
444
465
create_filters = True ,
466
+ serializer = None ,
445
467
** options ,
446
468
):
447
469
# We always want to bypass this hook unless we're defining a concrete
@@ -531,6 +553,12 @@ def __init_subclass_with_meta__(
531
553
532
554
cls .connection = connection # Public way to get the connection
533
555
556
+ if serializer is None :
557
+ cls .serializer = SQLAlchemyPrimaryKeySerializer .default ()
558
+
559
+ else :
560
+ cls .serializer = serializer
561
+
534
562
super (SQLAlchemyBase , cls ).__init_subclass_with_meta__ (
535
563
_meta = _meta , interfaces = interfaces , ** options
536
564
)
@@ -558,11 +586,7 @@ def get_query(cls, info):
558
586
559
587
@classmethod
560
588
def get_node (cls , info , id ):
561
- try :
562
- key = json .loads (id )
563
-
564
- except json .decoder .JSONDecodeError :
565
- return None
589
+ key = cls .serializer .deserialize (id )
566
590
567
591
if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4 :
568
592
try :
@@ -574,7 +598,7 @@ def get_node(cls, info, id):
574
598
if isinstance (session , AsyncSession ):
575
599
576
600
async def get_result () -> Any :
577
- return await session .get (cls ._meta .model , id )
601
+ return await session .get (cls ._meta .model , key )
578
602
579
603
return get_result ()
580
604
try :
@@ -585,7 +609,12 @@ async def get_result() -> Any:
585
609
def resolve_id (self , info ):
586
610
# graphene_type = info.parent_type.graphene_type
587
611
keys = self .__mapper__ .primary_key_from_instance (self )
588
- return json .dumps (keys if len (keys ) > 1 else keys [0 ])
612
+
613
+ try :
614
+ return self .serializer .serialize (keys if len (keys ) > 1 else keys [0 ])
615
+
616
+ except Exception as e :
617
+ raise ValueError (f"Non-serializable primary key: { e } " ) from e
589
618
590
619
@classmethod
591
620
def enum_for_field (cls , field_name ):
0 commit comments