Skip to content

Commit 437cfe0

Browse files
committed
Add support for EncryptedEmbeddedModelArrayField
1 parent 720918f commit 437cfe0

File tree

5 files changed

+80
-6
lines changed

5 files changed

+80
-6
lines changed

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
EncryptedDecimalField,
1414
EncryptedDurationField,
1515
EncryptedEmailField,
16+
EncryptedEmbeddedModelArrayField,
1617
EncryptedEmbeddedModelField,
1718
EncryptedFieldMixin,
1819
EncryptedFloatField,
@@ -44,6 +45,7 @@
4445
"EncryptedDecimalField",
4546
"EncryptedDurationField",
4647
"EncryptedEmailField",
48+
"EncryptedEmbeddedModelArrayField",
4749
"EncryptedEmbeddedModelField",
4850
"EncryptedFieldMixin",
4951
"EncryptedFloatField",

django_mongodb_backend/fields/encryption.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from django.db import models
22

3-
from django_mongodb_backend.fields import EmbeddedModelField
3+
from django_mongodb_backend.fields import EmbeddedModelArrayField, EmbeddedModelField
4+
5+
6+
class EncryptedEmbeddedModelArrayField(EmbeddedModelArrayField):
7+
encrypted = True
48

59

610
class EncryptedEmbeddedModelField(EmbeddedModelField):

django_mongodb_backend/schema.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from django_mongodb_backend.indexes import SearchIndex
1010

11-
from .fields import EmbeddedModelField
11+
from .fields import EmbeddedModelArrayField, EmbeddedModelField
1212
from .gis.schema import GISSchemaEditor
1313
from .query import wrap_database_errors
1414
from .utils import OperationCollector, model_has_encrypted_fields
@@ -551,6 +551,7 @@ def _get_encrypted_fields(
551551
new_key_alt_name = f"{key_alt_name}.{field.column}"
552552
path = f"{path_prefix}.{field.column}" if path_prefix else field.column
553553

554+
# --- Embedded Single Document ---
554555
if isinstance(field, EmbeddedModelField):
555556
if getattr(field, "encrypted", False):
556557
# Entire embedded object encrypted
@@ -582,7 +583,39 @@ def _get_encrypted_fields(
582583
field_list.extend(embedded_result["fields"])
583584
continue
584585

585-
# Leaf encrypted field
586+
# --- Array of Embedded Documents ---
587+
if isinstance(field, EmbeddedModelArrayField):
588+
if getattr(field, "encrypted", False):
589+
# Entire array contents encrypted - flat entry
590+
data_key = self._get_data_key(
591+
client_encryption,
592+
key_vault_collection,
593+
create_data_keys,
594+
kms_provider,
595+
master_key,
596+
new_key_alt_name,
597+
)
598+
field_dict = {
599+
"bsonType": "array",
600+
"path": path,
601+
"keyId": data_key,
602+
}
603+
if getattr(field, "queries", False):
604+
field_dict["queries"] = field.queries
605+
field_list.append(field_dict)
606+
else:
607+
# Recurse into embedded model for fields inside array elements
608+
embedded_result = self._get_encrypted_fields(
609+
field.embedded_model,
610+
create_data_keys=create_data_keys,
611+
key_alt_name=new_key_alt_name,
612+
path_prefix=path, # array prefix in path
613+
)
614+
if embedded_result and embedded_result.get("fields"):
615+
field_list.extend(embedded_result["fields"])
616+
continue
617+
618+
# --- Leaf encrypted field ---
586619
if getattr(field, "encrypted", False):
587620
data_key = self._get_data_key(
588621
client_encryption,

tests/encryption_/models.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
EncryptedDecimalField,
1212
EncryptedDurationField,
1313
EncryptedEmailField,
14+
EncryptedEmbeddedModelArrayField,
1415
EncryptedEmbeddedModelField,
1516
EncryptedFloatField,
1617
EncryptedGenericIPAddressField,
@@ -32,6 +33,7 @@ class Meta:
3233
required_db_features = {"supports_queryable_encryption"}
3334

3435

36+
# Embedded models
3537
class Patient(EncryptedTestModel):
3638
patient_name = models.CharField(max_length=255)
3739
patient_id = models.BigIntegerField()
@@ -52,7 +54,23 @@ class Billing(EmbeddedModel):
5254
cc_number = models.CharField(max_length=20)
5355

5456

55-
# Equality-queryable fields
57+
# Embedded array models
58+
class Actor(EmbeddedModel):
59+
name = models.CharField(max_length=100)
60+
61+
62+
class Movie(EncryptedTestModel):
63+
title = models.CharField(max_length=200)
64+
plot = models.TextField(blank=True)
65+
runtime = models.IntegerField(default=0)
66+
released = models.DateTimeField("release date", null=True, blank=True)
67+
cast = EncryptedEmbeddedModelArrayField(Actor, null=True, blank=True)
68+
69+
def __str__(self):
70+
return self.title
71+
72+
73+
# Equality-queryable field models
5674
class BinaryModel(EncryptedTestModel):
5775
value = EncryptedBinaryField(queries={"queryType": "equality"})
5876

@@ -81,7 +99,7 @@ class URLModel(EncryptedTestModel):
8199
value = EncryptedURLField(max_length=500, queries={"queryType": "equality"})
82100

83101

84-
# Range-queryable fields (also support equality)
102+
# Range-queryable field models
85103
class BigIntegerModel(EncryptedTestModel):
86104
value = EncryptedBigIntegerField(queries={"queryType": "range"})
87105

tests/encryption_/test_fields.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from django_mongodb_backend.fields import EncryptedCharField
55

66
from .models import (
7+
Actor,
78
BigIntegerModel,
89
Billing,
910
BinaryModel,
@@ -17,6 +18,7 @@
1718
FloatModel,
1819
GenericIPAddressModel,
1920
IntegerModel,
21+
Movie,
2022
Patient,
2123
PatientRecord,
2224
PositiveBigIntegerModel,
@@ -30,7 +32,7 @@
3032
from .test_base import EncryptionTestCase
3133

3234

33-
class PatientModelTests(EncryptionTestCase):
35+
class EncryptedEmbeddedModelTests(EncryptionTestCase):
3436
def setUp(self):
3537
self.billing = Billing(cc_type="Visa", cc_number="4111111111111111")
3638
self.patient_record = PatientRecord(ssn="123-45-6789", billing=self.billing)
@@ -45,6 +47,21 @@ def test_patient(self):
4547
self.assertEqual(patient.patient_record.billing.cc_number, "4111111111111111")
4648

4749

50+
class EncryptedEmbeddedModelArrayTests(EncryptionTestCase):
51+
def setUp(self):
52+
self.actor1 = Actor(name="Actor One")
53+
self.actor2 = Actor(name="Actor Two")
54+
self.movie = Movie.objects.create(
55+
title="Sample Movie",
56+
cast=[self.actor1, self.actor2],
57+
)
58+
59+
def test_movie_actors(self):
60+
self.assertEqual(len(self.movie.cast), 2)
61+
self.assertEqual(self.movie.cast[0].name, "Actor One")
62+
self.assertEqual(self.movie.cast[1].name, "Actor Two")
63+
64+
4865
class EncryptedFieldTests(EncryptionTestCase):
4966
def assertEquality(self, model_cls, val):
5067
model_cls.objects.create(value=val)

0 commit comments

Comments
 (0)