Skip to content

Commit e398a3d

Browse files
committed
First draft multiple models in a collection; no tests
1 parent 8a57a06 commit e398a3d

File tree

3 files changed

+151
-2
lines changed

3 files changed

+151
-2
lines changed

django_mongodb_backend/managers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from django.db import NotSupportedError
22
from django.db.models.manager import BaseManager
33

4-
from .queryset import MongoQuerySet
4+
from .queryset import MongoQuerySet, MultiMongoQuerySet
55

66

77
class MongoManager(BaseManager.from_queryset(MongoQuerySet)):
88
pass
99

1010

11+
class MultiMongoManager(BaseManager.from_queryset(MultiMongoQuerySet)):
12+
def get_queryset(self):
13+
return super().get_queryset().filter(_t__in=self.model.subclasses())
14+
15+
1116
class EmbeddedModelManager(BaseManager):
1217
"""
1318
Prevent all queryset operations on embedded models since they don't have

django_mongodb_backend/models.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from itertools import chain
2+
3+
from django.core.exceptions import FieldError
14
from django.db import NotSupportedError, models
25

3-
from .managers import EmbeddedModelManager
6+
from .managers import EmbeddedModelManager, MultiMongoManager
47

58

69
class EmbeddedModel(models.Model):
@@ -14,3 +17,136 @@ def delete(self, *args, **kwargs):
1417

1518
def save(self, *args, **kwargs):
1619
raise NotSupportedError("EmbeddedModels cannot be saved.")
20+
21+
22+
class ModelBaseOverride(models.base.ModelBase):
23+
__excluded_fieldnames = ("_t", "id")
24+
25+
def __new__(cls, name, bases, attrs, **kwargs):
26+
"""An override to the ModelBase which inspects inherited Model
27+
definitions and passes down the field names and table reference
28+
from parent to child model.
29+
** REMAINING TODO
30+
- Handle Index Creation
31+
- Tests
32+
"""
33+
parents = [b for b in bases if isinstance(b, models.base.ModelBase)]
34+
35+
# if no ModelBase instances found, this is the first inherited MultiModel
36+
if not parents:
37+
return super().__new__(cls, name, bases, attrs, **kwargs)
38+
39+
# Recursively fetch all fields of a class.
40+
# Only conclude the loop when we get the MultiModel class
41+
# We cannot explicitly pass a reference to the MultiModel class
42+
# because this builds a circluar dependency
43+
holder = bases
44+
traverse = holder[0]
45+
if traverse.__name__ != "MultiModel" and hasattr(traverse, "_meta"):
46+
while traverse and traverse.__name__ != "MultiModel" and hasattr(traverse, "_meta"):
47+
traverse = traverse._meta._bases[0] if traverse._meta._bases else None
48+
holder = (traverse,)
49+
50+
parent_fields = []
51+
52+
# Set up managed + default db if not set
53+
if hasattr(parents[0], "_meta") and parents[0].__name__ != "MultiModel":
54+
if not attrs.get("Meta"):
55+
56+
class Meta:
57+
db_table = parents[0]._meta.db_table
58+
managed = False
59+
60+
attrs["Meta"] = Meta()
61+
62+
elif meta := attrs.get("Meta"):
63+
if not getattr(meta, "db_table", None):
64+
meta.db_table = parents[0]._meta.db_table
65+
if not getattr(meta, "managed", None):
66+
meta.managed = False
67+
parent_fields = set(parents[0]._meta.local_fields + parents[0]._meta.local_many_to_many)
68+
69+
# The parent class will not be passed to the __new__ construction
70+
# because we will leverage Django's multi-table inheritance
71+
# which would lead to more complications on field resolution
72+
new_attrs = {**attrs}
73+
74+
for field in parent_fields:
75+
if not models.base._has_contribute_to_class(field):
76+
if field.name in new_attrs:
77+
raise FieldError(
78+
f"Local field {field.name!r} in class {name!r} clashes with field of "
79+
f"the same name from base class {parents[0].__name__!r}."
80+
)
81+
new_attrs[field.name] = field
82+
83+
# Construct new class without passing the parent reference, but adding
84+
# every new (derived) attribute to the django class
85+
new_cls = super().__new__(cls, name, holder, new_attrs, **kwargs)
86+
87+
new_fields = chain(
88+
new_cls._meta.local_fields,
89+
new_cls._meta.local_many_to_many,
90+
new_cls._meta.private_fields,
91+
)
92+
field_names = {f.name for f in new_fields}
93+
94+
for field in parent_fields:
95+
if field.primary_key or field.name in ModelBaseOverride.__excluded_fieldnames:
96+
continue
97+
if models.base._has_contribute_to_class(field):
98+
if (
99+
field.name in field_names
100+
and field.name not in ModelBaseOverride.__excluded_fieldnames
101+
):
102+
raise FieldError(
103+
f"Local field {field.name!r} in class {name!r} clashes with field of "
104+
f"the same name from base class {parents[0].__name__!r}."
105+
)
106+
107+
# if not hasattr(new_cls, field.name):
108+
new_cls.add_to_class(field.name, field)
109+
# Add each value as a subclass to its parent MultiModel object
110+
for _base in parents:
111+
# equivalent of if _base is MultiModel
112+
if hasattr(_base, "_subclasses"):
113+
_base._subclasses.setdefault(_base, []).append(new_cls)
114+
115+
new_cls._meta._bases = parents
116+
new_cls._meta.parents = {}
117+
return new_cls
118+
119+
120+
class MultiModel(models.Model, metaclass=ModelBaseOverride):
121+
"""Manager handles tracking all inherited subclasses to be used in the MultiMongoManager query"""
122+
123+
_subclasses = {}
124+
125+
def __init_subclass__(cls, **kwargs):
126+
super().__init_subclass__(**kwargs)
127+
for _base in cls.__bases__:
128+
if issubclass(_base, MultiModel):
129+
MultiModel._subclasses.setdefault(_base, []).append(cls)
130+
131+
# Get all the subclasses for my model
132+
@classmethod
133+
def subclasses(cls):
134+
stack = [cls]
135+
acc = set()
136+
while stack:
137+
node = stack.pop()
138+
stack.extend(cls._subclasses.get(node, []))
139+
acc.add(node)
140+
return [obj.__name__ for obj in acc]
141+
142+
_t = models.CharField(max_length=255, editable=False)
143+
objects = MultiMongoManager()
144+
145+
# Save the classname as the _t before saving
146+
def save(self, *args, **kwargs):
147+
if not self._t:
148+
self._t = self.__class__.__name__
149+
super().save(*args, **kwargs)
150+
151+
class Meta:
152+
abstract = True

django_mongodb_backend/queryset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from itertools import chain
22

3+
from django.apps import apps
34
from django.core.exceptions import FieldDoesNotExist
45
from django.db import connections
56
from django.db.models import QuerySet
@@ -13,6 +14,13 @@ def raw_aggregate(self, pipeline, using=None):
1314
return RawQuerySet(pipeline, model=self.model, using=using)
1415

1516

17+
class MultiMongoQuerySet(MongoQuerySet):
18+
def __iter__(self, *args, **kwargs):
19+
for obj in super().__iter__(*args, **kwargs):
20+
model_class = apps.get_model(obj._meta.app_label, obj._t)
21+
yield model_class.objects.get(pk=obj.pk)
22+
23+
1624
class RawQuerySet(BaseRawQuerySet):
1725
def __init__(self, pipeline, model=None, using=None):
1826
super().__init__(pipeline, model=model, using=using)

0 commit comments

Comments
 (0)