Skip to content

Commit 0cbca21

Browse files
committed
Refactor for DRY in _get_encrypted_fields method
1 parent 437cfe0 commit 0cbca21

File tree

1 file changed

+31
-70
lines changed

1 file changed

+31
-70
lines changed

django_mongodb_backend/schema.py

Lines changed: 31 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -545,34 +545,42 @@ def _get_encrypted_fields(
545545
master_key = connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider)
546546
client_encryption = getattr(self.connection, "client_encryption", None)
547547

548+
def _field_dict(bson_type, path, new_key_alt_name, queries=None):
549+
"""Helper to generate a dictionary for an encrypted field.
550+
Included in parent function's scope to avoid passing parameters.
551+
"""
552+
data_key = self._get_data_key(
553+
client_encryption,
554+
key_vault_collection,
555+
create_data_keys,
556+
kms_provider,
557+
master_key,
558+
new_key_alt_name,
559+
)
560+
field_dict = {
561+
"bsonType": bson_type,
562+
"path": path,
563+
"keyId": data_key,
564+
}
565+
if queries:
566+
field_dict["queries"] = queries
567+
return field_dict
568+
548569
field_list = []
549570

550571
for field in fields:
551572
new_key_alt_name = f"{key_alt_name}.{field.column}"
552573
path = f"{path_prefix}.{field.column}" if path_prefix else field.column
553574

554-
# --- Embedded Single Document ---
555-
if isinstance(field, EmbeddedModelField):
575+
if isinstance(field, (EmbeddedModelField, EmbeddedModelArrayField)):
556576
if getattr(field, "encrypted", False):
557-
# Entire embedded object encrypted
558-
data_key = self._get_data_key(
559-
client_encryption,
560-
key_vault_collection,
561-
create_data_keys,
562-
kms_provider,
563-
master_key,
564-
new_key_alt_name,
577+
bson_type = "object" if isinstance(field, EmbeddedModelField) else "array"
578+
field_list.append(
579+
_field_dict(
580+
bson_type, path, new_key_alt_name, getattr(field, "queries", None)
581+
)
565582
)
566-
field_dict = {
567-
"bsonType": "object",
568-
"path": path,
569-
"keyId": data_key,
570-
}
571-
if getattr(field, "queries", False):
572-
field_dict["queries"] = field.queries
573-
field_list.append(field_dict)
574583
else:
575-
# Recurse into embedded model
576584
embedded_result = self._get_encrypted_fields(
577585
field.embedded_model,
578586
create_data_keys=create_data_keys,
@@ -581,58 +589,11 @@ def _get_encrypted_fields(
581589
)
582590
if embedded_result and embedded_result.get("fields"):
583591
field_list.extend(embedded_result["fields"])
584-
continue
585-
586-
# --- Array of Embedded Documents ---
587-
if isinstance(field, EmbeddedModelArrayField):
588-
if getattr(field, "encrypted", False):
589-
# Entire array contents encrypted - flat entry
590-
data_key = self._get_data_key(
591-
client_encryption,
592-
key_vault_collection,
593-
create_data_keys,
594-
kms_provider,
595-
master_key,
596-
new_key_alt_name,
597-
)
598-
field_dict = {
599-
"bsonType": "array",
600-
"path": path,
601-
"keyId": data_key,
602-
}
603-
if getattr(field, "queries", False):
604-
field_dict["queries"] = field.queries
605-
field_list.append(field_dict)
606-
else:
607-
# Recurse into embedded model for fields inside array elements
608-
embedded_result = self._get_encrypted_fields(
609-
field.embedded_model,
610-
create_data_keys=create_data_keys,
611-
key_alt_name=new_key_alt_name,
612-
path_prefix=path, # array prefix in path
613-
)
614-
if embedded_result and embedded_result.get("fields"):
615-
field_list.extend(embedded_result["fields"])
616-
continue
617-
618-
# --- Leaf encrypted field ---
619-
if getattr(field, "encrypted", False):
620-
data_key = self._get_data_key(
621-
client_encryption,
622-
key_vault_collection,
623-
create_data_keys,
624-
kms_provider,
625-
master_key,
626-
new_key_alt_name,
592+
elif getattr(field, "encrypted", False):
593+
bson_type = field.db_type(connection)
594+
field_list.append(
595+
_field_dict(bson_type, path, new_key_alt_name, getattr(field, "queries", None))
627596
)
628-
field_dict = {
629-
"bsonType": field.db_type(connection),
630-
"path": path,
631-
"keyId": data_key,
632-
}
633-
if getattr(field, "queries", False):
634-
field_dict["queries"] = field.queries
635-
field_list.append(field_dict)
636597

637598
return {"fields": field_list} if field_list else None
638599

0 commit comments

Comments
 (0)