@@ -33,6 +33,8 @@ const ENCRYPT_KMS_TYPE = 'encrypt.kms.type'
3333const ENCRYPT_DEK_ALGORITHM = 'encrypt.dek.algorithm'
3434// EncryptDekExpiryDays represents dek expiry days
3535const ENCRYPT_DEK_EXPIRY_DAYS = 'encrypt.dek.expiry.days'
36+ // EncryptAlternateKmsKeyIds represents alternate kms key IDs
37+ const ENCRYPT_ALTERNATE_KMS_KEY_IDS = 'encrypt.alternate.kms.key.ids'
3638
3739// MillisInDay represents number of milliseconds in a day
3840const MILLIS_IN_DAY = 24 * 60 * 60 * 1000
@@ -387,7 +389,7 @@ export class EncryptionExecutorTransform {
387389 }
388390 let encryptedDek : Buffer | null = null
389391 if ( ! kek . shared ) {
390- kmsClient = getKmsClient ( this . executor . config ! , kek )
392+ kmsClient = new KmsClientWrapper ( this . executor . config ! , kek )
391393 // Generate new dek
392394 const rawDek = this . cryptor . generateKey ( )
393395 encryptedDek = await kmsClient . encrypt ( rawDek )
@@ -407,7 +409,7 @@ export class EncryptionExecutorTransform {
407409 const keyMaterialBytes = await this . executor . client ! . getDekKeyMaterialBytes ( dek )
408410 if ( keyMaterialBytes == null ) {
409411 if ( kmsClient == null ) {
410- kmsClient = getKmsClient ( this . executor . config ! , kek )
412+ kmsClient = new KmsClientWrapper ( this . executor . config ! , kek )
411413 }
412414 const encryptedKeyMaterialBytes = await this . executor . client ! . getDekEncryptedKeyMaterialBytes ( dek )
413415 const rawDek = await kmsClient . decrypt ( encryptedKeyMaterialBytes ! )
@@ -579,8 +581,8 @@ export class EncryptionExecutorTransform {
579581 }
580582}
581583
582- function getKmsClient ( config : Map < string , string > , kek : Kek ) : KmsClient {
583- let keyUrl = kek . kmsType + '://' + kek . kmsKeyId
584+ function getKmsClient ( config : Map < string , string > , kmsType : string , kmsKeyId : string ) : KmsClient {
585+ let keyUrl = kmsType + '://' + kmsKeyId
584586 let kmsClient = Registry . getKmsClient ( keyUrl )
585587 if ( kmsClient == null ) {
586588 let kmsDriver = Registry . getKmsDriver ( keyUrl )
@@ -641,3 +643,64 @@ export class FieldEncryptionExecutorTransform implements FieldTransform {
641643 }
642644}
643645
646+ export class KmsClientWrapper implements KmsClient {
647+ private config : Map < string , string >
648+ private kek : Kek
649+ private kekId : string
650+ private kmsKeyIds : string [ ]
651+
652+ constructor ( config : Map < string , string > , kek : Kek ) {
653+ this . config = config
654+ this . kek = kek
655+ this . kekId = kek . kmsType + '://' + kek . kmsKeyId
656+ this . kmsKeyIds = this . getKmsKeyIds ( )
657+ }
658+
659+ getKmsKeyIds ( ) : string [ ] {
660+ let kmsKeyIds = [ this . kek . kmsKeyId ! ]
661+ let alternateKmsKeyIds : string | undefined
662+ if ( this . kek . kmsProps != null ) {
663+ alternateKmsKeyIds = this . kek . kmsProps [ ENCRYPT_ALTERNATE_KMS_KEY_IDS ]
664+ }
665+ if ( alternateKmsKeyIds == null ) {
666+ alternateKmsKeyIds = this . config . get ( ENCRYPT_ALTERNATE_KMS_KEY_IDS )
667+ }
668+ if ( alternateKmsKeyIds != null ) {
669+ kmsKeyIds = kmsKeyIds . concat ( alternateKmsKeyIds . split ( ',' ) . map ( id => id . trim ( ) ) )
670+ }
671+ return kmsKeyIds
672+ }
673+
674+ supported ( keyUri : string ) : boolean {
675+ return this . kekId === keyUri
676+ }
677+
678+ async encrypt ( rawKey : Buffer ) : Promise < Buffer > {
679+ for ( let i = 0 ; i < this . kmsKeyIds . length ; i ++ ) {
680+ try {
681+ let kmsClient = getKmsClient ( this . config , this . kek . kmsType ! , this . kmsKeyIds [ i ] )
682+ return await kmsClient . encrypt ( rawKey )
683+ } catch ( e ) {
684+ if ( i === this . kmsKeyIds . length - 1 ) {
685+ throw new RuleError ( `failed to encrypt key with all KMS keys: ${ e } ` )
686+ }
687+ }
688+ }
689+ throw new RuleError ( 'no KEK found for encryption' )
690+ }
691+
692+ async decrypt ( encryptedKey : Buffer ) : Promise < Buffer > {
693+ for ( let i = 0 ; i < this . kmsKeyIds . length ; i ++ ) {
694+ try {
695+ let kmsClient = getKmsClient ( this . config , this . kek . kmsType ! , this . kmsKeyIds [ i ] )
696+ return await kmsClient . decrypt ( encryptedKey )
697+ } catch ( e ) {
698+ if ( i === this . kmsKeyIds . length - 1 ) {
699+ throw new RuleError ( `failed to decrypt key with all KMS keys: ${ e } ` )
700+ }
701+ }
702+ }
703+ throw new RuleError ( 'no KEK found for decryption' )
704+ }
705+ }
706+
0 commit comments