diff --git a/examples/docker_aws_lambda_example/go.sum b/examples/docker_aws_lambda_example/go.sum index 9be073518..db85de32c 100644 --- a/examples/docker_aws_lambda_example/go.sum +++ b/examples/docker_aws_lambda_example/go.sum @@ -53,6 +53,7 @@ github.com/confluentinc/confluent-kafka-go/v2 v2.4.0 h1:NbOku86JJlsRJPJKE0snNsz6 github.com/confluentinc/confluent-kafka-go/v2 v2.4.0/go.mod h1:E1dEQy50ZLfqs7T9luxz0rLxaeFZJZE92XvApJOr/Rk= github.com/confluentinc/confluent-kafka-go/v2 v2.5.0/go.mod h1:Hyo+IIQ/tmsfkOcRP8T6VlSeOW3T33v0Me8Xvq4u90Y= github.com/confluentinc/confluent-kafka-go/v2 v2.5.3/go.mod h1:QxYLPRKR1MVlkXCCjzjjrpXb0VyFNfVaZXi0obZykJ0= +github.com/confluentinc/confluent-kafka-go/v2 v2.11.0/go.mod h1:hScqtFIGUI1wqHIgM3mjoqEou4VweGGGX7dMpcUKves= github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARubLw= github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= github.com/containerd/containerd v1.7.12 h1:+KQsnv4VnzyxWcfO9mlxxELaoztsDEjOuCMPAuPqgU0= diff --git a/schemaregistry/rules/encryption/encrypt_executor.go b/schemaregistry/rules/encryption/encrypt_executor.go index 4c34ee920..996260e7a 100644 --- a/schemaregistry/rules/encryption/encrypt_executor.go +++ b/schemaregistry/rules/encryption/encrypt_executor.go @@ -77,6 +77,8 @@ const ( EncryptDekAlgorithm = "encrypt.dek.algorithm" // EncryptDekExpiryDays represents dek expiry days EncryptDekExpiryDays = "encrypt.dek.expiry.days" + // EncryptAlternateKmsKeyIDs represents alternate kms key ids + EncryptAlternateKmsKeyIDs = "encrypt.alternate.kms.key.ids" // Aes128Gcm represents AES128_GCM algorithm Aes128Gcm = "AES128_GCM" @@ -394,10 +396,7 @@ func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int) } var encryptedDek []byte if !f.Kek.Shared { - primitive, err = getAead(f.Executor.Config, f.Kek) - if err != nil { - return nil, err - } + primitive = &AeadWrapper{f.Executor.Config, f.Kek, getKmsKeyIDs(f.Executor.Config, f.Kek)} // Generate new dek keyData, err := registry.NewKeyData(f.Cryptor.KeyTemplate) if err != nil { @@ -431,10 +430,7 @@ func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int) } if keyBytes == nil { if primitive == nil { - primitive, err = getAead(f.Executor.Config, f.Kek) - if err != nil { - return nil, err - } + primitive = &AeadWrapper{f.Executor.Config, f.Kek, getKmsKeyIDs(f.Executor.Config, f.Kek)} } encryptedDek, err := f.Executor.Client.GetDekEncryptedKeyMaterialBytes(dek) if err != nil { @@ -629,8 +625,79 @@ func extractVersion(ciphertext []byte) (int, error) { return int(version), nil } -func getAead(config map[string]string, kek deks.Kek) (tink.AEAD, error) { - kekURL := kek.KmsType + "://" + kek.KmsKeyID +func getKmsKeyIDs(config map[string]string, kek deks.Kek) []string { + kmsKeyIDs := []string{kek.KmsKeyID} + var alternateKmsKeyIDs []string + if kek.KmsProps != nil { + if ids, ok := kek.KmsProps[EncryptAlternateKmsKeyIDs]; ok { + alternateKmsKeyIDs = strings.Split(ids, ",") + } + } + if alternateKmsKeyIDs == nil { + if ids, ok := config[EncryptAlternateKmsKeyIDs]; ok { + alternateKmsKeyIDs = strings.Split(ids, ",") + } + } + if alternateKmsKeyIDs != nil { + for _, id := range alternateKmsKeyIDs { + id = strings.TrimSpace(id) + if len(id) > 0 { + kmsKeyIDs = append(kmsKeyIDs, id) + } + } + } + return kmsKeyIDs +} + +// AeadWrapper is a wrapper for AEAD +type AeadWrapper struct { + Config map[string]string + Kek deks.Kek + KmsKeyIds []string +} + +// Encrypt encrypts plaintext with associatedData as associated data. +func (a *AeadWrapper) Encrypt(plaintext, associatedData []byte) ([]byte, error) { + var aead tink.AEAD + var err error + var ciphertext []byte + for _, kmsKeyID := range a.KmsKeyIds { + aead, err = getAead(a.Config, a.Kek.KmsType, kmsKeyID) + if err != nil { + log.Printf("WARN: failed to get AEAD with %s: %v\n", kmsKeyID, err) + continue + } + ciphertext, err = aead.Encrypt(plaintext, associatedData) + if err == nil { + return ciphertext, nil + } + log.Printf("WARN: failed to encrypt with %s: %v\n", kmsKeyID, err) + } + return nil, err +} + +// Decrypt decrypts ciphertext with associatedData as associated data. +func (a *AeadWrapper) Decrypt(ciphertext, associatedData []byte) ([]byte, error) { + var aead tink.AEAD + var err error + var plaintext []byte + for _, kmsKeyID := range a.KmsKeyIds { + aead, err = getAead(a.Config, a.Kek.KmsType, kmsKeyID) + if err != nil { + log.Printf("WARN: failed to get AEAD with %s: %v\n", kmsKeyID, err) + continue + } + plaintext, err = aead.Decrypt(ciphertext, associatedData) + if err == nil { + return plaintext, nil + } + log.Printf("WARN: failed to decrypt with %s: %v\n", kmsKeyID, err) + } + return nil, err +} + +func getAead(config map[string]string, kmsType string, kmsKeyID string) (tink.AEAD, error) { + kekURL := kmsType + "://" + kmsKeyID kmsClient, err := getKMSClient(config, kekURL) if err != nil { return nil, err diff --git a/schemaregistry/serde/avrov2/avro_test.go b/schemaregistry/serde/avrov2/avro_test.go index fae815a17..c4f3f5cd2 100644 --- a/schemaregistry/serde/avrov2/avro_test.go +++ b/schemaregistry/serde/avrov2/avro_test.go @@ -1654,6 +1654,76 @@ func TestAvroSerdePayloadEncryption(t *testing.T) { serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) } +func TestAvroSerdeEncryptionAlternateKeks(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + serConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + "encrypt.alternate.kms.key.ids": "mykey2,mykey3", + } + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-encrypt", + Kind: "TRANSFORM", + Mode: "WRITEREAD", + Type: "ENCRYPT_PAYLOAD", + Params: map[string]string{ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey", + }, + OnFailure: "ERROR,NONE", + } + ruleSet := schemaregistry.RuleSet{ + EncodingRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: demoSchema, + SchemaType: "AVRO", + RuleSet: &ruleSet, + } + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + obj := DemoSchema{} + obj.IntField = 123 + obj.DoubleField = 45.67 + obj.StringField = "hi" + obj.BoolField = true + obj.BytesField = []byte{1, 2} + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + deserConfig := NewDeserializerConfig() + deserConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + deser.MessageFactory = testMessageFactory + + newobj, err := deser.Deserialize("topic1", bytes) + serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) +} + func TestAvroSerdeEncryptionDeterministic(t *testing.T) { serde.MaybeFail = serde.InitFailFunc(t) var err error