1
1
import datetime
2
2
from decimal import Decimal
3
3
4
+ import pymongo
5
+ from bson .binary import Binary
6
+ from django .conf import settings
7
+ from django .db import connections
8
+ from django .db .models import Model
9
+
4
10
from django_mongodb_backend .fields import EncryptedCharField
5
11
6
12
from .models import (
32
38
from .test_base import EncryptionTestCase
33
39
34
40
35
- class EncryptedEmbeddedModelTests (EncryptionTestCase ):
41
+ class EncryptedFieldTests (EncryptionTestCase ):
42
+ def assertEncrypted (self , model_or_instance , field_name ):
43
+ """
44
+ Check if the field value in the database is stored as Binary.
45
+ Works with either a Django model instance or a model class.
46
+ """
47
+
48
+ conn_params = connections ["encrypted" ].get_connection_params ()
49
+ db_name = settings .DATABASES ["encrypted" ]["NAME" ]
50
+
51
+ if conn_params .pop ("auto_encryption_opts" , False ):
52
+ with pymongo .MongoClient (** conn_params ) as new_connection :
53
+ if hasattr (model_or_instance , "_meta" ):
54
+ collection_name = model_or_instance ._meta .db_table
55
+ else :
56
+ self .fail (f"Object { model_or_instance !r} is not a Django model or instance" )
57
+
58
+ collection = new_connection [db_name ][collection_name ]
59
+
60
+ # If it's an instance of a Django model, narrow to that _id
61
+ if isinstance (model_or_instance , Model ):
62
+ docs = collection .find (
63
+ {"_id" : model_or_instance .pk , field_name : {"$exists" : True }}
64
+ )
65
+ else :
66
+ # Otherwise it's a model class
67
+ docs = collection .find ({field_name : {"$exists" : True }})
68
+
69
+ found = False
70
+ for doc in docs :
71
+ found = True
72
+ value = doc .get (field_name )
73
+ self .assertTrue (
74
+ isinstance (value , Binary ),
75
+ msg = f"Field '{ field_name } ' in document { doc ['_id' ]} is "
76
+ "not encrypted (type={type(value)})" ,
77
+ )
78
+
79
+ self .assertTrue (
80
+ found ,
81
+ msg = f"No documents with field '{ field_name } ' found in '{{collection_name}}'" ,
82
+ )
83
+
84
+ else :
85
+ self .fail ("auto_encryption_opts is not configured; encryption not enabled." )
86
+
87
+
88
+ class EncryptedEmbeddedModelTests (EncryptedFieldTests ):
36
89
def setUp (self ):
37
90
self .billing = Billing (cc_type = "Visa" , cc_number = "4111111111111111" )
38
91
self .patient_record = PatientRecord (ssn = "123-45-6789" , billing = self .billing )
39
92
self .patient = Patient .objects .create (
40
93
patient_name = "John Doe" , patient_id = 123456789 , patient_record = self .patient_record
41
94
)
42
95
43
- def test_patient (self ):
96
+ def test_object (self ):
44
97
patient = Patient .objects .get (id = self .patient .id )
45
98
self .assertEqual (patient .patient_record .ssn , "123-45-6789" )
46
99
self .assertEqual (patient .patient_record .billing .cc_type , "Visa" )
47
100
self .assertEqual (patient .patient_record .billing .cc_number , "4111111111111111" )
48
101
49
102
50
- class EncryptedEmbeddedModelArrayTests (EncryptionTestCase ):
103
+ class EncryptedEmbeddedModelArrayTests (EncryptedFieldTests ):
51
104
def setUp (self ):
52
105
self .actor1 = Actor (name = "Actor One" )
53
106
self .actor2 = Actor (name = "Actor Two" )
@@ -56,13 +109,14 @@ def setUp(self):
56
109
cast = [self .actor1 , self .actor2 ],
57
110
)
58
111
59
- def test_movie_actors (self ):
112
+ def test_array (self ):
60
113
self .assertEqual (len (self .movie .cast ), 2 )
61
114
self .assertEqual (self .movie .cast [0 ].name , "Actor One" )
62
115
self .assertEqual (self .movie .cast [1 ].name , "Actor Two" )
116
+ self .assertEncrypted (self .movie , "cast" )
63
117
64
118
65
- class EncryptedFieldTests (EncryptionTestCase ):
119
+ class EncryptedFieldTests (EncryptedFieldTests ):
66
120
def assertEquality (self , model_cls , val ):
67
121
model_cls .objects .create (value = val )
68
122
fetched = model_cls .objects .get (value = val )
@@ -80,28 +134,36 @@ def assertRange(self, model_cls, *, low, high, threshold):
80
134
# Equality-only fields
81
135
def test_binary (self ):
82
136
self .assertEquality (BinaryModel , b"\x00 \x01 \x02 " )
137
+ self .assertEncrypted (BinaryModel , "value" )
83
138
84
139
def test_boolean (self ):
85
140
self .assertEquality (BooleanModel , True )
141
+ self .assertEncrypted (BooleanModel , "value" )
86
142
87
143
def test_char (self ):
88
144
self .assertEquality (CharModel , "hello" )
145
+ self .assertEncrypted (CharModel , "value" )
89
146
90
147
def test_email (self ):
91
148
self .
assertEquality (
EmailModel ,
"[email protected] " )
149
+ self .assertEncrypted (EmailModel , "value" )
92
150
93
151
def test_ip (self ):
94
152
self .assertEquality (GenericIPAddressModel , "192.168.0.1" )
153
+ self .assertEncrypted (GenericIPAddressModel , "value" )
95
154
96
155
def test_text (self ):
97
156
self .assertEquality (TextModel , "some text" )
157
+ self .assertEncrypted (TextModel , "value" )
98
158
99
159
def test_url (self ):
100
160
self .assertEquality (URLModel , "https://example.com" )
161
+ self .assertEncrypted (URLModel , "value" )
101
162
102
163
# Range fields
103
164
def test_big_integer (self ):
104
165
self .assertRange (BigIntegerModel , low = 100 , high = 200 , threshold = 150 )
166
+ self .assertEncrypted (BigIntegerModel , "value" )
105
167
106
168
def test_date (self ):
107
169
self .assertRange (
@@ -110,6 +172,7 @@ def test_date(self):
110
172
high = datetime .date (2024 , 6 , 10 ),
111
173
threshold = datetime .date (2024 , 6 , 5 ),
112
174
)
175
+ self .assertEncrypted (DateModel , "value" )
113
176
114
177
def test_datetime (self ):
115
178
self .assertRange (
@@ -118,6 +181,7 @@ def test_datetime(self):
118
181
high = datetime .datetime (2024 , 6 , 2 , 12 , 0 ),
119
182
threshold = datetime .datetime (2024 , 6 , 2 , 0 , 0 ),
120
183
)
184
+ self .assertEncrypted (DateTimeModel , "value" )
121
185
122
186
def test_decimal (self ):
123
187
self .assertRange (
@@ -126,6 +190,7 @@ def test_decimal(self):
126
190
high = Decimal ("200.50" ),
127
191
threshold = Decimal ("150" ),
128
192
)
193
+ self .assertEncrypted (DecimalModel , "value" )
129
194
130
195
def test_duration (self ):
131
196
self .assertRange (
@@ -134,24 +199,31 @@ def test_duration(self):
134
199
high = datetime .timedelta (days = 10 ),
135
200
threshold = datetime .timedelta (days = 5 ),
136
201
)
202
+ self .assertEncrypted (DurationModel , "value" )
137
203
138
204
def test_float (self ):
139
205
self .assertRange (FloatModel , low = 1.23 , high = 4.56 , threshold = 3.0 )
206
+ self .assertEncrypted (FloatModel , "value" )
140
207
141
208
def test_integer (self ):
142
209
self .assertRange (IntegerModel , low = 5 , high = 10 , threshold = 7 )
210
+ self .assertEncrypted (IntegerModel , "value" )
143
211
144
212
def test_positive_big_integer (self ):
145
213
self .assertRange (PositiveBigIntegerModel , low = 100 , high = 500 , threshold = 200 )
214
+ self .assertEncrypted (PositiveBigIntegerModel , "value" )
146
215
147
216
def test_positive_integer (self ):
148
217
self .assertRange (PositiveIntegerModel , low = 10 , high = 20 , threshold = 15 )
218
+ self .assertEncrypted (PositiveIntegerModel , "value" )
149
219
150
220
def test_positive_small_integer (self ):
151
221
self .assertRange (PositiveSmallIntegerModel , low = 5 , high = 8 , threshold = 6 )
222
+ self .assertEncrypted (PositiveSmallIntegerModel , "value" )
152
223
153
224
def test_small_integer (self ):
154
225
self .assertRange (SmallIntegerModel , low = - 5 , high = 2 , threshold = 0 )
226
+ self .assertEncrypted (SmallIntegerModel , "value" )
155
227
156
228
def test_time (self ):
157
229
self .assertRange (
@@ -160,9 +232,10 @@ def test_time(self):
160
232
high = datetime .time (15 , 0 ),
161
233
threshold = datetime .time (12 , 0 ),
162
234
)
235
+ self .assertEncrypted (TimeModel , "value" )
163
236
164
237
165
- class EncryptedFieldMixinTests (EncryptionTestCase ):
238
+ class EncryptedFieldMixinTests (EncryptedFieldTests ):
166
239
def test_null_true_raises_error (self ):
167
240
with self .assertRaisesMessage (
168
241
ValueError , "'null=True' is not supported for encrypted fields."
0 commit comments