Skip to content

Commit 0b7d678

Browse files
committed
Composite primary key handling
- introduce serializer attribute to SQLAlchemyBase to allow customisation
1 parent 679a051 commit 0b7d678

File tree

1 file changed

+37
-8
lines changed

1 file changed

+37
-8
lines changed

graphene_sqlalchemy/types.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import inspect
2-
import json
32
import logging
43
import warnings
54
from collections import OrderedDict
@@ -419,6 +418,28 @@ def construct_fields_and_filters(
419418
return fields, filters
420419

421420

421+
class SQLAlchemyPrimaryKeySerializer(object):
422+
"""
423+
Serializes/unserializes primary keys
424+
"""
425+
426+
DEFAULT = None
427+
428+
def __init__(self, serialize, deserialize):
429+
self.serialize = serialize
430+
self.deserialize = deserialize
431+
432+
@classmethod
433+
def default(cls):
434+
if cls.DEFAULT is None:
435+
cls.DEFAULT = cls(
436+
serialize=lambda keys: str(tuple(keys)) if len(keys) > 1 else keys[0],
437+
deserialize=lambda id: id,
438+
)
439+
440+
return cls.DEFAULT
441+
442+
422443
class SQLAlchemyBase(BaseType):
423444
"""
424445
This class contains initialization code that is common to both ObjectTypes
@@ -442,6 +463,7 @@ def __init_subclass_with_meta__(
442463
connection_field_factory=None,
443464
_meta=None,
444465
create_filters=True,
466+
serializer=None,
445467
**options,
446468
):
447469
# We always want to bypass this hook unless we're defining a concrete
@@ -531,6 +553,12 @@ def __init_subclass_with_meta__(
531553

532554
cls.connection = connection # Public way to get the connection
533555

556+
if serializer is None:
557+
cls.serializer = SQLAlchemyPrimaryKeySerializer.default()
558+
559+
else:
560+
cls.serializer = serializer
561+
534562
super(SQLAlchemyBase, cls).__init_subclass_with_meta__(
535563
_meta=_meta, interfaces=interfaces, **options
536564
)
@@ -558,11 +586,7 @@ def get_query(cls, info):
558586

559587
@classmethod
560588
def get_node(cls, info, id):
561-
try:
562-
key = json.loads(id)
563-
564-
except json.decoder.JSONDecodeError:
565-
return None
589+
key = cls.serializer.deserialize(id)
566590

567591
if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
568592
try:
@@ -574,7 +598,7 @@ def get_node(cls, info, id):
574598
if isinstance(session, AsyncSession):
575599

576600
async def get_result() -> Any:
577-
return await session.get(cls._meta.model, id)
601+
return await session.get(cls._meta.model, key)
578602

579603
return get_result()
580604
try:
@@ -585,7 +609,12 @@ async def get_result() -> Any:
585609
def resolve_id(self, info):
586610
# graphene_type = info.parent_type.graphene_type
587611
keys = self.__mapper__.primary_key_from_instance(self)
588-
return json.dumps(keys if len(keys) > 1 else keys[0])
612+
613+
try:
614+
return self.serializer.serialize(keys if len(keys) > 1 else keys[0])
615+
616+
except Exception as e:
617+
raise ValueError(f"Non-serializable primary key: {e}") from e
589618

590619
@classmethod
591620
def enum_for_field(cls, field_name):

0 commit comments

Comments
 (0)