Skip to content

Commit 20c3b25

Browse files
author
Kévin Etienne
authored
Merge pull request #44 from peterfarrell/feature_custom_manager
Feature custom manager
2 parents a10ee8d + 6094ff7 commit 20c3b25

File tree

6 files changed

+165
-9
lines changed

6 files changed

+165
-9
lines changed

README.md

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,16 @@ INSTALLED_APPS = (
127127

128128
### Usage
129129

130-
#### Model definition
130+
#### Model / Manager definition
131131

132132
```python
133133
from django.db import models
134134

135-
from pgcrypto import fields
135+
from pgcrypto import fields, managers
136+
137+
class MyModelManager(managers.PGPManager):
138+
pass
139+
136140

137141
class MyModel(models.Model):
138142
digest_field = fields.TextDigestField()
@@ -144,6 +148,8 @@ class MyModel(models.Model):
144148
pgp_sym_field = fields.TextPGPSymmetricKeyField()
145149
date_pgp_sym_field = fields.DatePGPSymmetricKeyField()
146150
datetime_pgp_sym_field = fields.DateTimePGPSymmetricKeyField()
151+
152+
objects = MyModelManager()
147153
```
148154

149155
#### Encrypting
@@ -155,7 +161,36 @@ Example:
155161
>>> MyModel.objects.create(value='Value to be encrypted...')
156162
```
157163

158-
#### Decrypting
164+
#### Decryption using custom model managers
165+
166+
If you use the bundled `PGPManager` with your custom model manager, all encrypted
167+
fields will automatically decrypted for you (except for hash fields which are one
168+
way).
169+
170+
N.B. The bundled manager does not support decryption of fields from FK joins. For
171+
example if the `MyModel` class had a FK to to `AnotherModel` class, no encrypted
172+
fields be decrypted in the joined `AnotherModel`.
173+
174+
It is recommended that you use the bundled `PGPAdmin` class if using the custom
175+
model manager and the Django Admin. The Django Admin performance suffers when
176+
using the bundled custom manager. The `PGPAdmin` disables automatic decryption
177+
for all ORM calls for that admin class.
178+
179+
```python
180+
from django.contrib import admin
181+
182+
from pgcrypto.admin import PGPAdmin
183+
184+
185+
class MyModelAdmin(admin.ModelAdmin, PGPAdmin):
186+
# Your admin code here
187+
```
188+
189+
190+
#### Decrypting using aggregates
191+
192+
This is useful if you are not using the custom manager or need to decrypt fields
193+
coming from joined FK fields.
159194

160195
##### PGP fields
161196

pgcrypto/admin.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
class PGPAdmin(object):
2+
3+
def get_queryset(self, request):
4+
"""Skip any auto decryption when ORM calls are from the admin."""
5+
return self.model.objects.get_queryset(**{'skip_decrypt': True})

pgcrypto/managers.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from django.conf import settings
2+
from django.db import models
3+
4+
5+
class PGPManager(models.Manager):
6+
use_for_related_fields = True
7+
use_in_migrations = True
8+
9+
@staticmethod
10+
def _get_pgp_symmetric_decrypt_sql(field):
11+
"""Decrypt sql for symmetric fields using the cast sql if required."""
12+
sql = """pgp_sym_decrypt("{0}"."{1}", '{2}')"""
13+
if hasattr(field, 'cast_sql'):
14+
sql = field.cast_sql % sql
15+
16+
return sql.format(
17+
field.model._meta.db_table,
18+
field.name,
19+
settings.PGCRYPTO_KEY,
20+
)
21+
22+
@staticmethod
23+
def _get_pgp_public_key_decrypt_sql(field):
24+
"""Decrypt sql for public key fields using the cast sql if required."""
25+
sql = """pgp_pub_decrypt("{0}"."{1}", dearmor('{2}'))"""
26+
27+
if hasattr(field, 'cast_sql'):
28+
sql = field.cast_sql % sql
29+
30+
return sql.format(
31+
field.model._meta.db_table,
32+
field.name,
33+
settings.PRIVATE_PGP_KEY,
34+
)
35+
36+
def get_queryset(self, *args, **kwargs):
37+
"""Decryption in queryset through meta programming."""
38+
# importing here otherwise there's a circular reference issue
39+
from pgcrypto.mixins import PGPSymmetricKeyFieldMixin, PGPPublicKeyFieldMixin
40+
41+
skip_decrypt = kwargs.pop('skip_decrypt', None)
42+
43+
qs = super().get_queryset(*args, **kwargs)
44+
45+
# The Django admin skips this process because it's extremely slow
46+
if not skip_decrypt:
47+
select_sql = {}
48+
encrypted_fields = []
49+
50+
for field in self.model._meta.get_fields():
51+
if isinstance(field, PGPSymmetricKeyFieldMixin):
52+
select_sql[field.name] = self._get_pgp_symmetric_decrypt_sql(field)
53+
encrypted_fields.append(field.name)
54+
elif isinstance(field, PGPPublicKeyFieldMixin):
55+
select_sql[field.name] = self._get_pgp_public_key_decrypt_sql(field)
56+
encrypted_fields.append(field.name)
57+
58+
# Django queryset.extra() is used here to add decryption sql to query.
59+
qs = qs.defer(
60+
*encrypted_fields
61+
).extra(
62+
select=select_sql
63+
)
64+
65+
return qs

pgcrypto/mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class PGPPublicKeyFieldMixin(PGPMixin):
8383

8484

8585
class PGPSymmetricKeyFieldMixin(PGPMixin):
86-
"""PGP symmetric key encrypted field mixin for postgred."""
86+
"""PGP symmetric key encrypted field mixin for postgres."""
8787
aggregate = PGPSymmetricKeyAggregate
8888

8989

tests/models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from django.db import models
22

3-
from pgcrypto import fields
3+
from pgcrypto import fields, managers
4+
5+
6+
class EncryptedModelManager(managers.PGPManager):
7+
pass
48

59

610
class EncryptedModel(models.Model):
@@ -20,3 +24,11 @@ class EncryptedModel(models.Model):
2024

2125
class Meta:
2226
app_label = 'tests'
27+
28+
29+
class EncryptedModelWithManager(EncryptedModel):
30+
31+
objects = EncryptedModelManager()
32+
33+
class Meta:
34+
proxy = True

tests/test_fields.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
from django.test import TestCase
55
from incuna_test_utils.utils import field_names
66

7-
from pgcrypto import aggregates, proxy
8-
from pgcrypto import fields
7+
from pgcrypto import aggregates, fields, proxy
98
from .factories import EncryptedModelFactory
109
from .forms import EncryptedForm
11-
from .models import EncryptedModel
12-
10+
from .models import EncryptedModel, EncryptedModelWithManager
1311

1412
KEYED_FIELDS = (fields.TextDigestField, fields.TextHMACField)
1513
EMAIL_PGP_FIELDS = (fields.EmailPGPPublicKeyField, fields.EmailPGPSymmetricKeyField)
@@ -408,3 +406,44 @@ def test_null(self):
408406
for field in fields:
409407
with self.subTest(field=field):
410408
self.assertEqual(getattr(instance, field), None)
409+
410+
411+
class TestPGPManager(TestCase):
412+
"""Test `PGPManager` can be integrated in a `Django` model."""
413+
model = EncryptedModelWithManager
414+
415+
def test_auto_decryption(self):
416+
"""Assert auto decryption via manager."""
417+
expected_string = 'bonjour'
418+
expected_date = date(2016, 9, 1)
419+
expected_datetime = datetime(2016, 9, 1, 0, 0, 0)
420+
421+
EncryptedModelFactory.create(
422+
digest_field=expected_string,
423+
hmac_field=expected_string,
424+
pgp_pub_field=expected_string,
425+
pgp_sym_field=expected_string,
426+
date_pgp_sym_field=expected_date, # Tests cast sql
427+
datetime_pgp_sym_field=expected_datetime, # Tests cast sql
428+
)
429+
430+
instance = self.model.objects.get()
431+
432+
# Using `__dict__` bypasses "on the fly" decryption that normally occurs
433+
# if accessing a field that is not yet decrypted.
434+
# If decryption is not working, we get references to <In_Memory> classes
435+
self.assertEqual(instance.__dict__['pgp_pub_field'], expected_string)
436+
self.assertEqual(instance.__dict__['pgp_sym_field'], expected_string)
437+
self.assertEqual(instance.__dict__['date_pgp_sym_field'], expected_date)
438+
self.assertEqual(instance.__dict__['datetime_pgp_sym_field'], expected_datetime)
439+
440+
# Ensure digest / hmac fields are unaffected
441+
count = self.model.objects.filter(
442+
digest_field__hash_of=expected_string
443+
).count()
444+
self.assertEqual(count, 1)
445+
446+
count = self.model.objects.filter(
447+
hmac_field__hash_of=expected_string
448+
).count()
449+
self.assertEqual(count, 1)

0 commit comments

Comments
 (0)