Skip to content

Commit 01d5485

Browse files
committed
Move encrypted_fields_map to schema (2/x)
- Refactor tests
1 parent 10a361e commit 01d5485

File tree

2 files changed

+18
-20
lines changed

2 files changed

+18
-20
lines changed

django_mongodb_backend/schema.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -421,20 +421,26 @@ def _field_should_have_unique(self, field):
421421

422422
def _create_collection(self, model):
423423
"""
424-
Create a collection or encrypted collection for the model.
424+
If the model is not encrypted, create a normal collection otherwise
425+
create an encrypted collection with the encrypted fields map.
425426
"""
426427

427-
if hasattr(model, "encrypted"):
428+
if not hasattr(model, "encrypted"):
429+
self.get_database().create_collection(model._meta.db_table)
430+
else:
431+
# TODO: Route to the encrypted database connection.
428432
auto_encryption_opts = self.connection.settings_dict.get("OPTIONS", {}).get(
429433
"auto_encryption_opts"
430434
)
431435
client = self.connection.connection
436+
432437
client_encryption = get_client_encryption(auto_encryption_opts, client)
433438
client_encryption.create_encrypted_collection(
434439
client.database,
435440
model._meta.db_table,
436-
{"fields": []},
441+
self._get_encrypted_fields_map(model),
437442
"local", # TODO: KMS provider should be configurable
438443
)
439-
else:
440-
self.get_database().create_collection(model._meta.db_table)
444+
445+
def _get_encrypted_fields_map(self, model):
446+
return {"fields": []}

tests/encryption_/tests.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,24 @@
1+
from django.db import connection
12
from django.test import TestCase
23

34
from .models import Person
45

56

67
class EncryptedModelTests(TestCase):
7-
databases = ["encryption"]
8-
98
@classmethod
109
def setUpTestData(cls):
11-
cls.objs = [Person.objects.create()]
12-
13-
def test_encrypted_fields_map_on_class(self):
14-
expected = {
15-
"fields": {
16-
"ssn": "EncryptedCharField",
17-
}
18-
}
19-
self.assertEqual(Person.encrypted_fields_map, expected)
10+
cls.person = Person(ssn="123-45-6789")
2011

2112
def test_encrypted_fields_map_on_instance(self):
22-
instance = Person(ssn="123-45-6789")
2313
expected = {
2414
"fields": {
2515
"ssn": "EncryptedCharField",
2616
}
2717
}
28-
self.assertEqual(instance.encrypted_fields_map, expected)
18+
with connection.schema_editor() as editor:
19+
self.assertEqual(editor._get_encrypted_fields_map(self.person), expected)
2920

3021
def test_non_encrypted_fields_not_included(self):
31-
encrypted_field_names = Person.encrypted_fields_map.get("fields").keys()
32-
self.assertNotIn("name", encrypted_field_names)
22+
with connection.schema_editor() as editor:
23+
encrypted_field_names = editor._get_encrypted_fields_map(self.person).get("fields")
24+
self.assertNotIn("name", encrypted_field_names)

0 commit comments

Comments
 (0)