|
| 1 | +from bson import json_util |
| 2 | +from django.apps import apps |
| 3 | +from django.core.management.base import BaseCommand |
| 4 | +from django.db import DEFAULT_DB_ALIAS, connections, router |
| 5 | +from pymongo.encryption import ClientEncryption |
| 6 | + |
| 7 | + |
| 8 | +class Command(BaseCommand): |
| 9 | + help = "Generate a `schema_map` of encrypted fields for all encrypted" |
| 10 | + " models in the database for use with `AutoEncryptionOpts` in" |
| 11 | + " client configuration." |
| 12 | + |
| 13 | + def add_arguments(self, parser): |
| 14 | + parser.add_argument( |
| 15 | + "--database", |
| 16 | + default=DEFAULT_DB_ALIAS, |
| 17 | + help="Specify the database to use for generating the encrypted" |
| 18 | + "fields map. Defaults to the 'default' database.", |
| 19 | + ) |
| 20 | + parser.add_argument( |
| 21 | + "--kms-provider", |
| 22 | + default="local", |
| 23 | + help="Specify the KMS provider to use for encryption. Defaults to 'local'.", |
| 24 | + ) |
| 25 | + |
| 26 | + def handle(self, *args, **options): |
| 27 | + db = options["database"] |
| 28 | + kms_provider = options["kms_provider"] |
| 29 | + connection = connections[db] |
| 30 | + schema_map = json_util.dumps( |
| 31 | + self.get_encrypted_fields_map(connection, kms_provider), indent=2 |
| 32 | + ) |
| 33 | + self.stdout.write(schema_map) |
| 34 | + |
| 35 | + def get_client_encryption(self, connection): |
| 36 | + client = connection.connection |
| 37 | + options = client._options.auto_encryption_opts |
| 38 | + key_vault_namespace = options._key_vault_namespace |
| 39 | + kms_providers = options._kms_providers |
| 40 | + return ClientEncryption(kms_providers, key_vault_namespace, client, client.codec_options) |
| 41 | + |
| 42 | + def get_encrypted_fields_map(self, connection, kms_provider): |
| 43 | + schema_map = {} |
| 44 | + for app_config in apps.get_app_configs(): |
| 45 | + for model in router.get_migratable_models( |
| 46 | + app_config, connection.settings_dict["NAME"], include_auto_created=False |
| 47 | + ): |
| 48 | + if getattr(model, "encrypted", False): |
| 49 | + fields = connection.schema_editor()._get_encrypted_fields_map(model) |
| 50 | + ce = self.get_client_encryption(connection) |
| 51 | + master_key = connection.settings_dict.get("KMS_CREDENTIALS").get(kms_provider) |
| 52 | + for field in fields["fields"]: |
| 53 | + data_key = ce.create_data_key( |
| 54 | + kms_provider=kms_provider, |
| 55 | + master_key=master_key, |
| 56 | + ) |
| 57 | + field["keyId"] = data_key |
| 58 | + schema_map[model._meta.db_table] = fields |
| 59 | + return schema_map |
0 commit comments