Skip to content

Commit 96156c8

Browse files
committed
Check for encrypted fields from unencrypted conn
Re-adding test removed in 46ca9dc as assertEncrypted class method.
1 parent a0cd197 commit 96156c8

File tree

1 file changed

+79
-6
lines changed

1 file changed

+79
-6
lines changed

tests/encryption_/test_fields.py

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import datetime
22
from decimal import Decimal
33

4+
import pymongo
5+
from bson.binary import Binary
6+
from django.conf import settings
7+
from django.db import connections
8+
from django.db.models import Model
9+
410
from django_mongodb_backend.fields import EncryptedCharField
511

612
from .models import (
@@ -32,22 +38,69 @@
3238
from .test_base import EncryptionTestCase
3339

3440

35-
class EncryptedEmbeddedModelTests(EncryptionTestCase):
41+
class EncryptedFieldTests(EncryptionTestCase):
42+
def assertEncrypted(self, model_or_instance, field_name):
43+
"""
44+
Check if the field value in the database is stored as Binary.
45+
Works with either a Django model instance or a model class.
46+
"""
47+
48+
conn_params = connections["encrypted"].get_connection_params()
49+
db_name = settings.DATABASES["encrypted"]["NAME"]
50+
51+
if conn_params.pop("auto_encryption_opts", False):
52+
with pymongo.MongoClient(**conn_params) as new_connection:
53+
if hasattr(model_or_instance, "_meta"):
54+
collection_name = model_or_instance._meta.db_table
55+
else:
56+
self.fail(f"Object {model_or_instance!r} is not a Django model or instance")
57+
58+
collection = new_connection[db_name][collection_name]
59+
60+
# If it's an instance of a Django model, narrow to that _id
61+
if isinstance(model_or_instance, Model):
62+
docs = collection.find(
63+
{"_id": model_or_instance.pk, field_name: {"$exists": True}}
64+
)
65+
else:
66+
# Otherwise it's a model class
67+
docs = collection.find({field_name: {"$exists": True}})
68+
69+
found = False
70+
for doc in docs:
71+
found = True
72+
value = doc.get(field_name)
73+
self.assertTrue(
74+
isinstance(value, Binary),
75+
msg=f"Field '{field_name}' in document {doc['_id']} is "
76+
"not encrypted (type={type(value)})",
77+
)
78+
79+
self.assertTrue(
80+
found,
81+
msg=f"No documents with field '{field_name}' found in '{{collection_name}}'",
82+
)
83+
84+
else:
85+
self.fail("auto_encryption_opts is not configured; encryption not enabled.")
86+
87+
88+
class EncryptedEmbeddedModelTests(EncryptedFieldTests):
3689
def setUp(self):
3790
self.billing = Billing(cc_type="Visa", cc_number="4111111111111111")
3891
self.patient_record = PatientRecord(ssn="123-45-6789", billing=self.billing)
3992
self.patient = Patient.objects.create(
4093
patient_name="John Doe", patient_id=123456789, patient_record=self.patient_record
4194
)
4295

43-
def test_patient(self):
96+
def test_object(self):
4497
patient = Patient.objects.get(id=self.patient.id)
4598
self.assertEqual(patient.patient_record.ssn, "123-45-6789")
4699
self.assertEqual(patient.patient_record.billing.cc_type, "Visa")
47100
self.assertEqual(patient.patient_record.billing.cc_number, "4111111111111111")
48101

49102

50-
class EncryptedEmbeddedModelArrayTests(EncryptionTestCase):
103+
class EncryptedEmbeddedModelArrayTests(EncryptedFieldTests):
51104
def setUp(self):
52105
self.actor1 = Actor(name="Actor One")
53106
self.actor2 = Actor(name="Actor Two")
@@ -56,13 +109,14 @@ def setUp(self):
56109
cast=[self.actor1, self.actor2],
57110
)
58111

59-
def test_movie_actors(self):
112+
def test_array(self):
60113
self.assertEqual(len(self.movie.cast), 2)
61114
self.assertEqual(self.movie.cast[0].name, "Actor One")
62115
self.assertEqual(self.movie.cast[1].name, "Actor Two")
116+
self.assertEncrypted(self.movie, "cast")
63117

64118

65-
class EncryptedFieldTests(EncryptionTestCase):
119+
class EncryptedFieldTests(EncryptedFieldTests):
66120
def assertEquality(self, model_cls, val):
67121
model_cls.objects.create(value=val)
68122
fetched = model_cls.objects.get(value=val)
@@ -80,28 +134,36 @@ def assertRange(self, model_cls, *, low, high, threshold):
80134
# Equality-only fields
81135
def test_binary(self):
82136
self.assertEquality(BinaryModel, b"\x00\x01\x02")
137+
self.assertEncrypted(BinaryModel, "value")
83138

84139
def test_boolean(self):
85140
self.assertEquality(BooleanModel, True)
141+
self.assertEncrypted(BooleanModel, "value")
86142

87143
def test_char(self):
88144
self.assertEquality(CharModel, "hello")
145+
self.assertEncrypted(CharModel, "value")
89146

90147
def test_email(self):
91148
self.assertEquality(EmailModel, "[email protected]")
149+
self.assertEncrypted(EmailModel, "value")
92150

93151
def test_ip(self):
94152
self.assertEquality(GenericIPAddressModel, "192.168.0.1")
153+
self.assertEncrypted(GenericIPAddressModel, "value")
95154

96155
def test_text(self):
97156
self.assertEquality(TextModel, "some text")
157+
self.assertEncrypted(TextModel, "value")
98158

99159
def test_url(self):
100160
self.assertEquality(URLModel, "https://example.com")
161+
self.assertEncrypted(URLModel, "value")
101162

102163
# Range fields
103164
def test_big_integer(self):
104165
self.assertRange(BigIntegerModel, low=100, high=200, threshold=150)
166+
self.assertEncrypted(BigIntegerModel, "value")
105167

106168
def test_date(self):
107169
self.assertRange(
@@ -110,6 +172,7 @@ def test_date(self):
110172
high=datetime.date(2024, 6, 10),
111173
threshold=datetime.date(2024, 6, 5),
112174
)
175+
self.assertEncrypted(DateModel, "value")
113176

114177
def test_datetime(self):
115178
self.assertRange(
@@ -118,6 +181,7 @@ def test_datetime(self):
118181
high=datetime.datetime(2024, 6, 2, 12, 0),
119182
threshold=datetime.datetime(2024, 6, 2, 0, 0),
120183
)
184+
self.assertEncrypted(DateTimeModel, "value")
121185

122186
def test_decimal(self):
123187
self.assertRange(
@@ -126,6 +190,7 @@ def test_decimal(self):
126190
high=Decimal("200.50"),
127191
threshold=Decimal("150"),
128192
)
193+
self.assertEncrypted(DecimalModel, "value")
129194

130195
def test_duration(self):
131196
self.assertRange(
@@ -134,24 +199,31 @@ def test_duration(self):
134199
high=datetime.timedelta(days=10),
135200
threshold=datetime.timedelta(days=5),
136201
)
202+
self.assertEncrypted(DurationModel, "value")
137203

138204
def test_float(self):
139205
self.assertRange(FloatModel, low=1.23, high=4.56, threshold=3.0)
206+
self.assertEncrypted(FloatModel, "value")
140207

141208
def test_integer(self):
142209
self.assertRange(IntegerModel, low=5, high=10, threshold=7)
210+
self.assertEncrypted(IntegerModel, "value")
143211

144212
def test_positive_big_integer(self):
145213
self.assertRange(PositiveBigIntegerModel, low=100, high=500, threshold=200)
214+
self.assertEncrypted(PositiveBigIntegerModel, "value")
146215

147216
def test_positive_integer(self):
148217
self.assertRange(PositiveIntegerModel, low=10, high=20, threshold=15)
218+
self.assertEncrypted(PositiveIntegerModel, "value")
149219

150220
def test_positive_small_integer(self):
151221
self.assertRange(PositiveSmallIntegerModel, low=5, high=8, threshold=6)
222+
self.assertEncrypted(PositiveSmallIntegerModel, "value")
152223

153224
def test_small_integer(self):
154225
self.assertRange(SmallIntegerModel, low=-5, high=2, threshold=0)
226+
self.assertEncrypted(SmallIntegerModel, "value")
155227

156228
def test_time(self):
157229
self.assertRange(
@@ -160,9 +232,10 @@ def test_time(self):
160232
high=datetime.time(15, 0),
161233
threshold=datetime.time(12, 0),
162234
)
235+
self.assertEncrypted(TimeModel, "value")
163236

164237

165-
class EncryptedFieldMixinTests(EncryptionTestCase):
238+
class EncryptedFieldMixinTests(EncryptedFieldTests):
166239
def test_null_true_raises_error(self):
167240
with self.assertRaisesMessage(
168241
ValueError, "'null=True' is not supported for encrypted fields."

0 commit comments

Comments
 (0)