Skip to content

Commit 7d6af33

Browse files
committed
Factor _get_data_key from _get_encrypted_fields
1 parent d16aa89 commit 7d6af33

File tree

1 file changed

+48
-52
lines changed

1 file changed

+48
-52
lines changed

django_mongodb_backend/schema.py

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -490,15 +490,42 @@ def _create_collection(self, model):
490490
# Unencrypted path
491491
db.create_collection(db_table)
492492

493+
def _get_data_key(
494+
self,
495+
client_encryption,
496+
key_vault_collection,
497+
create_data_keys,
498+
kms_provider,
499+
master_key,
500+
key_alt_name,
501+
):
502+
"""Return an existing or newly-created data key ID for a field."""
503+
if create_data_keys:
504+
if not client_encryption:
505+
raise ImproperlyConfigured("client_encryption is not configured.")
506+
return client_encryption.create_data_key(
507+
kms_provider=kms_provider,
508+
master_key=master_key,
509+
key_alt_names=[key_alt_name],
510+
)
511+
if key_vault_collection is None:
512+
raise ImproperlyConfigured(
513+
f"Encrypted field {key_alt_name} detected but no key vault configured"
514+
)
515+
key_doc = key_vault_collection.find_one({"keyAltNames": key_alt_name})
516+
if not key_doc:
517+
raise ValueError(
518+
f"No key found in keyvault for keyAltName={key_alt_name}. "
519+
"Run with '--create-data-keys' to create missing keys."
520+
)
521+
return key_doc["_id"]
522+
493523
def _get_encrypted_fields(
494524
self, model, create_data_keys=False, key_alt_name=None, path_prefix=None
495525
):
496526
"""
497527
Recursively collect encryption schema data for only encrypted fields in a model.
498528
Returns None if no encrypted fields are found anywhere in the model hierarchy.
499-
500-
key_alt_name is the base path used for keyAltNames.
501-
path_prefix is the dot-notated path inside the document for schema mapping.
502529
"""
503530
connection = self.connection
504531
client = connection.connection
@@ -524,40 +551,24 @@ def _get_encrypted_fields(
524551
new_key_alt_name = f"{key_alt_name}.{field.column}"
525552
path = f"{path_prefix}.{field.column}" if path_prefix else field.column
526553

527-
# --- EmbeddedModelField ---
528554
if isinstance(field, EmbeddedModelField):
529555
if getattr(field, "encrypted", False):
530-
# Entire sub-object encrypted
531-
if create_data_keys:
532-
if not client_encryption:
533-
raise ImproperlyConfigured("client_encryption is not configured.")
534-
data_key = client_encryption.create_data_key(
535-
kms_provider=kms_provider,
536-
master_key=master_key,
537-
key_alt_names=[new_key_alt_name],
538-
)
539-
else:
540-
if key_vault_collection is None:
541-
raise ImproperlyConfigured(
542-
f"Encrypted field {new_key_alt_name} detected "
543-
"but no key vault configured"
544-
)
545-
key_doc = key_vault_collection.find_one({"keyAltNames": new_key_alt_name})
546-
if not key_doc:
547-
raise ValueError(
548-
f"No key found in keyvault for keyAltName={new_key_alt_name}. "
549-
"Run with '--create-data-keys' to create missing keys."
550-
)
551-
data_key = key_doc["_id"]
552-
556+
# Entire embedded object encrypted
557+
data_key = self._get_data_key(
558+
client_encryption,
559+
key_vault_collection,
560+
create_data_keys,
561+
kms_provider,
562+
master_key,
563+
new_key_alt_name,
564+
)
553565
field_dict = {
554566
"bsonType": "object",
555567
"path": path,
556568
"keyId": data_key,
557569
}
558570
if getattr(field, "queries", False):
559571
field_dict["queries"] = field.queries
560-
561572
field_list.append(field_dict)
562573
else:
563574
# Recurse into embedded model
@@ -571,38 +582,23 @@ def _get_encrypted_fields(
571582
field_list.extend(embedded_result["fields"])
572583
continue
573584

574-
# --- Leaf encrypted field ---
585+
# Leaf encrypted field
575586
if getattr(field, "encrypted", False):
576-
if create_data_keys:
577-
if not client_encryption:
578-
raise ImproperlyConfigured("client_encryption is not configured.")
579-
data_key = client_encryption.create_data_key(
580-
kms_provider=kms_provider,
581-
master_key=master_key,
582-
key_alt_names=[new_key_alt_name],
583-
)
584-
else:
585-
if key_vault_collection is None:
586-
raise ImproperlyConfigured(
587-
f"Encrypted field {new_key_alt_name} detected "
588-
"but no key vault configured"
589-
)
590-
key_doc = key_vault_collection.find_one({"keyAltNames": new_key_alt_name})
591-
if not key_doc:
592-
raise ValueError(
593-
f"No key found in keyvault for keyAltName={new_key_alt_name}. "
594-
"Run with '--create-data-keys' to create missing keys."
595-
)
596-
data_key = key_doc["_id"]
597-
587+
data_key = self._get_data_key(
588+
client_encryption,
589+
key_vault_collection,
590+
create_data_keys,
591+
kms_provider,
592+
master_key,
593+
new_key_alt_name,
594+
)
598595
field_dict = {
599596
"bsonType": field.db_type(connection),
600597
"path": path,
601598
"keyId": data_key,
602599
}
603600
if getattr(field, "queries", False):
604601
field_dict["queries"] = field.queries
605-
606602
field_list.append(field_dict)
607603

608604
return {"fields": field_list} if field_list else None

0 commit comments

Comments
 (0)