@@ -545,34 +545,42 @@ def _get_encrypted_fields(
545
545
master_key = connection .settings_dict .get ("KMS_CREDENTIALS" , {}).get (kms_provider )
546
546
client_encryption = getattr (self .connection , "client_encryption" , None )
547
547
548
+ def _field_dict (bson_type , path , new_key_alt_name , queries = None ):
549
+ """Helper to generate a dictionary for an encrypted field.
550
+ Included in parent function's scope to avoid passing parameters.
551
+ """
552
+ data_key = self ._get_data_key (
553
+ client_encryption ,
554
+ key_vault_collection ,
555
+ create_data_keys ,
556
+ kms_provider ,
557
+ master_key ,
558
+ new_key_alt_name ,
559
+ )
560
+ field_dict = {
561
+ "bsonType" : bson_type ,
562
+ "path" : path ,
563
+ "keyId" : data_key ,
564
+ }
565
+ if queries :
566
+ field_dict ["queries" ] = queries
567
+ return field_dict
568
+
548
569
field_list = []
549
570
550
571
for field in fields :
551
572
new_key_alt_name = f"{ key_alt_name } .{ field .column } "
552
573
path = f"{ path_prefix } .{ field .column } " if path_prefix else field .column
553
574
554
- # --- Embedded Single Document ---
555
- if isinstance (field , EmbeddedModelField ):
575
+ if isinstance (field , (EmbeddedModelField , EmbeddedModelArrayField )):
556
576
if getattr (field , "encrypted" , False ):
557
- # Entire embedded object encrypted
558
- data_key = self ._get_data_key (
559
- client_encryption ,
560
- key_vault_collection ,
561
- create_data_keys ,
562
- kms_provider ,
563
- master_key ,
564
- new_key_alt_name ,
577
+ bson_type = "object" if isinstance (field , EmbeddedModelField ) else "array"
578
+ field_list .append (
579
+ _field_dict (
580
+ bson_type , path , new_key_alt_name , getattr (field , "queries" , None )
581
+ )
565
582
)
566
- field_dict = {
567
- "bsonType" : "object" ,
568
- "path" : path ,
569
- "keyId" : data_key ,
570
- }
571
- if getattr (field , "queries" , False ):
572
- field_dict ["queries" ] = field .queries
573
- field_list .append (field_dict )
574
583
else :
575
- # Recurse into embedded model
576
584
embedded_result = self ._get_encrypted_fields (
577
585
field .embedded_model ,
578
586
create_data_keys = create_data_keys ,
@@ -581,58 +589,11 @@ def _get_encrypted_fields(
581
589
)
582
590
if embedded_result and embedded_result .get ("fields" ):
583
591
field_list .extend (embedded_result ["fields" ])
584
- continue
585
-
586
- # --- Array of Embedded Documents ---
587
- if isinstance (field , EmbeddedModelArrayField ):
588
- if getattr (field , "encrypted" , False ):
589
- # Entire array contents encrypted - flat entry
590
- data_key = self ._get_data_key (
591
- client_encryption ,
592
- key_vault_collection ,
593
- create_data_keys ,
594
- kms_provider ,
595
- master_key ,
596
- new_key_alt_name ,
597
- )
598
- field_dict = {
599
- "bsonType" : "array" ,
600
- "path" : path ,
601
- "keyId" : data_key ,
602
- }
603
- if getattr (field , "queries" , False ):
604
- field_dict ["queries" ] = field .queries
605
- field_list .append (field_dict )
606
- else :
607
- # Recurse into embedded model for fields inside array elements
608
- embedded_result = self ._get_encrypted_fields (
609
- field .embedded_model ,
610
- create_data_keys = create_data_keys ,
611
- key_alt_name = new_key_alt_name ,
612
- path_prefix = path , # array prefix in path
613
- )
614
- if embedded_result and embedded_result .get ("fields" ):
615
- field_list .extend (embedded_result ["fields" ])
616
- continue
617
-
618
- # --- Leaf encrypted field ---
619
- if getattr (field , "encrypted" , False ):
620
- data_key = self ._get_data_key (
621
- client_encryption ,
622
- key_vault_collection ,
623
- create_data_keys ,
624
- kms_provider ,
625
- master_key ,
626
- new_key_alt_name ,
592
+ elif getattr (field , "encrypted" , False ):
593
+ bson_type = field .db_type (connection )
594
+ field_list .append (
595
+ _field_dict (bson_type , path , new_key_alt_name , getattr (field , "queries" , None ))
627
596
)
628
- field_dict = {
629
- "bsonType" : field .db_type (connection ),
630
- "path" : path ,
631
- "keyId" : data_key ,
632
- }
633
- if getattr (field , "queries" , False ):
634
- field_dict ["queries" ] = field .queries
635
- field_list .append (field_dict )
636
597
637
598
return {"fields" : field_list } if field_list else None
638
599
0 commit comments