Skip to content

Commit 2a197ce

Browse files
authored
Refactor encryptor (#1452)
1 parent 9a57b3b commit 2a197ce

File tree

11 files changed

+439
-66
lines changed

11 files changed

+439
-66
lines changed

schemaregistry/rules/encryption/encrypt_executor.go

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ func init() {
4444
// Register registers the encryption rule executor
4545
func Register() {
4646
serde.RegisterRuleExecutor(NewExecutor())
47+
serde.RegisterRuleExecutor(NewFieldExecutor())
4748
}
4849

49-
// RegisterWithClock registers the encryption rule executor with a given clock
50-
func RegisterWithClock(c Clock) *FieldEncryptionExecutor {
50+
// RegisterExecutorWithClock registers the encryption rule executor with a given clock
51+
func RegisterExecutorWithClock(c Clock) *Executor {
5152
f := NewExecutorWithClock(c)
5253
serde.RegisterRuleExecutor(f)
5354
return f
@@ -60,10 +61,8 @@ func NewExecutor() serde.RuleExecutor {
6061
}
6162

6263
// NewExecutorWithClock creates a new encryption rule executor with a given clock
63-
func NewExecutorWithClock(c Clock) *FieldEncryptionExecutor {
64-
a := &serde.AbstractFieldRuleExecutor{}
65-
f := &FieldEncryptionExecutor{*a, nil, nil, c}
66-
f.FieldRuleExecutor = f
64+
func NewExecutorWithClock(c Clock) *Executor {
65+
f := &Executor{nil, nil, c}
6766
return f
6867
}
6968

@@ -101,16 +100,15 @@ func (*clock) NowUnixMilli() int64 {
101100
return time.Now().UnixMilli()
102101
}
103102

104-
// FieldEncryptionExecutor is a field encryption executor
105-
type FieldEncryptionExecutor struct {
106-
serde.AbstractFieldRuleExecutor
103+
// Executor is an encryption executor
104+
type Executor struct {
107105
Config map[string]string
108106
Client deks.Client
109107
Clock Clock
110108
}
111109

112110
// Configure configures the executor
113-
func (f *FieldEncryptionExecutor) Configure(clientConfig *schemaregistry.Config, config map[string]string) error {
111+
func (f *Executor) Configure(clientConfig *schemaregistry.Config, config map[string]string) error {
114112
if f.Client != nil {
115113
if !schemaregistry.ConfigsEqual(f.Client.Config(), clientConfig) {
116114
return errors.New("executor already configured")
@@ -143,12 +141,21 @@ func (f *FieldEncryptionExecutor) Configure(clientConfig *schemaregistry.Config,
143141
}
144142

145143
// Type returns the type of the executor
146-
func (f *FieldEncryptionExecutor) Type() string {
147-
return "ENCRYPT"
144+
func (f *Executor) Type() string {
145+
return "ENCRYPT_PAYLOAD"
146+
}
147+
148+
// Transform transforms the message using the rule
149+
func (f *Executor) Transform(ctx serde.RuleContext, msg interface{}) (interface{}, error) {
150+
transform, err := f.NewTransform(ctx)
151+
if err != nil {
152+
return nil, err
153+
}
154+
return transform.Transform(ctx, serde.TypeBytes, msg)
148155
}
149156

150157
// NewTransform creates a new transform
151-
func (f *FieldEncryptionExecutor) NewTransform(ctx serde.RuleContext) (serde.FieldTransform, error) {
158+
func (f *Executor) NewTransform(ctx serde.RuleContext) (*ExecutorTransform, error) {
152159
kekName, err := getKekName(ctx)
153160
if err != nil {
154161
return nil, err
@@ -157,7 +164,7 @@ func (f *FieldEncryptionExecutor) NewTransform(ctx serde.RuleContext) (serde.Fie
157164
if err != nil {
158165
return nil, err
159166
}
160-
transform := FieldEncryptionExecutorTransform{
167+
transform := ExecutorTransform{
161168
Executor: *f,
162169
Cryptor: getCryptor(ctx),
163170
KekName: kekName,
@@ -172,13 +179,13 @@ func (f *FieldEncryptionExecutor) NewTransform(ctx serde.RuleContext) (serde.Fie
172179
}
173180

174181
// Close closes the executor
175-
func (f *FieldEncryptionExecutor) Close() error {
182+
func (f *Executor) Close() error {
176183
return f.Client.Close()
177184
}
178185

179-
// FieldEncryptionExecutorTransform is a field encryption executor transform
180-
type FieldEncryptionExecutorTransform struct {
181-
Executor FieldEncryptionExecutor
186+
// ExecutorTransform is a field encryption executor transform
187+
type ExecutorTransform struct {
188+
Executor Executor
182189
Cryptor Cryptor
183190
KekName string
184191
Kek deks.Kek
@@ -290,11 +297,11 @@ func getDekExpiryDays(ctx serde.RuleContext) (int, error) {
290297
return i, nil
291298
}
292299

293-
func (f *FieldEncryptionExecutorTransform) isDekRotated() bool {
300+
func (f *ExecutorTransform) isDekRotated() bool {
294301
return f.DekExpiryDays > 0
295302
}
296303

297-
func (f *FieldEncryptionExecutorTransform) getOrCreateKek(ctx serde.RuleContext) (*deks.Kek, error) {
304+
func (f *ExecutorTransform) getOrCreateKek(ctx serde.RuleContext) (*deks.Kek, error) {
298305
isRead := ctx.RuleMode == schemaregistry.Read
299306
kekID := deks.KekID{
300307
Name: f.KekName,
@@ -334,7 +341,7 @@ func (f *FieldEncryptionExecutorTransform) getOrCreateKek(ctx serde.RuleContext)
334341
return kek, nil
335342
}
336343

337-
func (f *FieldEncryptionExecutorTransform) retrieveKekFromRegistry(key deks.KekID) (*deks.Kek, error) {
344+
func (f *ExecutorTransform) retrieveKekFromRegistry(key deks.KekID) (*deks.Kek, error) {
338345
kek, err := f.Executor.Client.GetKek(key.Name, key.Deleted)
339346
if err != nil {
340347
var restErr *rest.Error
@@ -348,7 +355,7 @@ func (f *FieldEncryptionExecutorTransform) retrieveKekFromRegistry(key deks.KekI
348355
return &kek, nil
349356
}
350357

351-
func (f *FieldEncryptionExecutorTransform) storeKekToRegistry(key deks.KekID, kmsType string, kmsKeyID string, shared bool) (*deks.Kek, error) {
358+
func (f *ExecutorTransform) storeKekToRegistry(key deks.KekID, kmsType string, kmsKeyID string, shared bool) (*deks.Kek, error) {
352359
kek, err := f.Executor.Client.RegisterKek(key.Name, kmsType, kmsKeyID, nil, "", shared)
353360
if err != nil {
354361
var restErr *rest.Error
@@ -362,7 +369,7 @@ func (f *FieldEncryptionExecutorTransform) storeKekToRegistry(key deks.KekID, km
362369
return &kek, nil
363370
}
364371

365-
func (f *FieldEncryptionExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int) (*deks.Dek, error) {
372+
func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int) (*deks.Dek, error) {
366373
isRead := ctx.RuleMode == schemaregistry.Read
367374
ver := 1
368375
if version != nil {
@@ -442,7 +449,7 @@ func (f *FieldEncryptionExecutorTransform) getOrCreateDek(ctx serde.RuleContext,
442449
return dek, nil
443450
}
444451

445-
func (f *FieldEncryptionExecutorTransform) createDek(dekID deks.DekID, newVersion int, encryptedDek []byte) (*deks.Dek, error) {
452+
func (f *ExecutorTransform) createDek(dekID deks.DekID, newVersion int, encryptedDek []byte) (*deks.Dek, error) {
446453
newDekID := deks.DekID{
447454
KekName: dekID.KekName,
448455
Subject: dekID.Subject,
@@ -466,7 +473,7 @@ func (f *FieldEncryptionExecutorTransform) createDek(dekID deks.DekID, newVersio
466473
return dek, nil
467474
}
468475

469-
func (f *FieldEncryptionExecutorTransform) retrieveDekFromRegistry(key deks.DekID) (*deks.Dek, error) {
476+
func (f *ExecutorTransform) retrieveDekFromRegistry(key deks.DekID) (*deks.Dek, error) {
470477
var dek deks.Dek
471478
var err error
472479
if key.Version != 0 {
@@ -486,7 +493,7 @@ func (f *FieldEncryptionExecutorTransform) retrieveDekFromRegistry(key deks.DekI
486493
return &dek, nil
487494
}
488495

489-
func (f *FieldEncryptionExecutorTransform) storeDekToRegistry(key deks.DekID, encryptedDek []byte) (*deks.Dek, error) {
496+
func (f *ExecutorTransform) storeDekToRegistry(key deks.DekID, encryptedDek []byte) (*deks.Dek, error) {
490497
var encryptedDekStr string
491498
if encryptedDek != nil {
492499
encryptedDekStr = base64.StdEncoding.EncodeToString(encryptedDek)
@@ -510,7 +517,7 @@ func (f *FieldEncryptionExecutorTransform) storeDekToRegistry(key deks.DekID, en
510517
return &dek, nil
511518
}
512519

513-
func (f *FieldEncryptionExecutorTransform) isExpired(ctx serde.RuleContext, dek *deks.Dek) bool {
520+
func (f *ExecutorTransform) isExpired(ctx serde.RuleContext, dek *deks.Dek) bool {
514521
now := f.Executor.Clock.NowUnixMilli()
515522
return ctx.RuleMode != schemaregistry.Read &&
516523
f.DekExpiryDays > 0 &&
@@ -519,15 +526,15 @@ func (f *FieldEncryptionExecutorTransform) isExpired(ctx serde.RuleContext, dek
519526
}
520527

521528
// Transform transforms the field value using the rule
522-
func (f *FieldEncryptionExecutorTransform) Transform(ctx serde.RuleContext, fieldCtx serde.FieldContext, fieldValue interface{}) (interface{}, error) {
529+
func (f *ExecutorTransform) Transform(ctx serde.RuleContext, fieldType serde.FieldType, fieldValue interface{}) (interface{}, error) {
523530
if fieldValue == nil {
524531
return nil, nil
525532
}
526533
switch ctx.RuleMode {
527534
case schemaregistry.Write:
528-
plaintext := toBytes(fieldCtx.Type, fieldValue)
535+
plaintext := toBytes(fieldType, fieldValue)
529536
if plaintext == nil {
530-
return nil, fmt.Errorf("type '%v' not supported for encryption", fieldCtx.Type)
537+
return nil, fmt.Errorf("type '%v' not supported for encryption", fieldType)
531538
}
532539
var version *int
533540
if f.isDekRotated() {
@@ -552,16 +559,16 @@ func (f *FieldEncryptionExecutorTransform) Transform(ctx serde.RuleContext, fiel
552559
return nil, err
553560
}
554561
}
555-
if fieldCtx.Type == serde.TypeString {
562+
if fieldType == serde.TypeString {
556563
return base64.StdEncoding.EncodeToString(ciphertext), nil
557564
}
558565
return ciphertext, nil
559566
case schemaregistry.Read:
560-
ciphertext := toBytes(fieldCtx.Type, fieldValue)
567+
ciphertext := toBytes(fieldType, fieldValue)
561568
if ciphertext == nil {
562569
return fieldValue, nil
563570
}
564-
if fieldCtx.Type == serde.TypeString {
571+
if fieldType == serde.TypeString {
565572
var err error
566573
ciphertext, err = base64.StdEncoding.DecodeString(string(ciphertext))
567574
if err != nil {
@@ -589,7 +596,7 @@ func (f *FieldEncryptionExecutorTransform) Transform(ctx serde.RuleContext, fiel
589596
if err != nil {
590597
return nil, err
591598
}
592-
return toObject(fieldCtx.Type, plaintext), nil
599+
return toObject(fieldType, plaintext), nil
593600
default:
594601
return nil, fmt.Errorf("unsupported rule mode %v", ctx.RuleMode)
595602
}

schemaregistry/rules/encryption/encrypt_executor_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import (
2525
"github.com/confluentinc/confluent-kafka-go/v2/schemaregistry"
2626
)
2727

28-
func TestFieldEncryptionExecutor_Configure(t *testing.T) {
28+
func TestEncryptionExecutor_Configure(t *testing.T) {
2929
maybeFail = initFailFunc(t)
3030

3131
executor := NewExecutor()
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/**
2+
* Copyright 2024 Confluent Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package encryption
18+
19+
import (
20+
"github.com/confluentinc/confluent-kafka-go/v2/schemaregistry"
21+
"github.com/confluentinc/confluent-kafka-go/v2/schemaregistry/serde"
22+
)
23+
24+
// RegisterFieldExecutorWithClock registers the encryption rule executor with a given clock
25+
func RegisterFieldExecutorWithClock(c Clock) *FieldEncryptionExecutor {
26+
f := NewFieldExecutorWithClock(c)
27+
serde.RegisterRuleExecutor(f)
28+
return f
29+
}
30+
31+
// NewFieldExecutor creates a new encryption rule executor
32+
func NewFieldExecutor() serde.RuleExecutor {
33+
c := clock{}
34+
return NewFieldExecutorWithClock(&c)
35+
}
36+
37+
// NewFieldExecutorWithClock creates a new encryption rule executor with a given clock
38+
func NewFieldExecutorWithClock(c Clock) *FieldEncryptionExecutor {
39+
a := &serde.AbstractFieldRuleExecutor{}
40+
f := &FieldEncryptionExecutor{*a, *NewExecutorWithClock(c)}
41+
f.FieldRuleExecutor = f
42+
return f
43+
}
44+
45+
// FieldEncryptionExecutor is a field encryption executor
46+
type FieldEncryptionExecutor struct {
47+
serde.AbstractFieldRuleExecutor
48+
Executor Executor
49+
}
50+
51+
// Configure configures the executor
52+
func (f *FieldEncryptionExecutor) Configure(clientConfig *schemaregistry.Config, config map[string]string) error {
53+
return f.Executor.Configure(clientConfig, config)
54+
}
55+
56+
// Type returns the type of the executor
57+
func (f *FieldEncryptionExecutor) Type() string {
58+
return "ENCRYPT"
59+
}
60+
61+
// NewTransform creates a new transform
62+
func (f *FieldEncryptionExecutor) NewTransform(ctx serde.RuleContext) (serde.FieldTransform, error) {
63+
executorTransform, err := f.Executor.NewTransform(ctx)
64+
if err != nil {
65+
return nil, err
66+
}
67+
transform := FieldEncryptionExecutorTransform{
68+
ExecutorTransform: *executorTransform,
69+
}
70+
return &transform, nil
71+
}
72+
73+
// Close closes the executor
74+
func (f *FieldEncryptionExecutor) Close() error {
75+
return f.Executor.Close()
76+
}
77+
78+
// FieldEncryptionExecutorTransform is a field encryption executor transform
79+
type FieldEncryptionExecutorTransform struct {
80+
ExecutorTransform ExecutorTransform
81+
}
82+
83+
// Transform transforms the field value using the rule
84+
func (f *FieldEncryptionExecutorTransform) Transform(ctx serde.RuleContext, fieldCtx serde.FieldContext, fieldValue interface{}) (interface{}, error) {
85+
return f.ExecutorTransform.Transform(ctx, fieldCtx.Type, fieldValue)
86+
}

schemaregistry/schemaregistry_client.go

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ type Rule struct {
101101
Disabled bool `json:"disabled,omitempty"`
102102
}
103103

104+
// RulePhase represents the rule phase
105+
type RulePhase = int
106+
107+
const (
108+
// MigrationPhase denotes migration phase
109+
MigrationPhase = 1
110+
// DomainPhase denotes domain phase
111+
DomainPhase = 2
112+
// EncodingPhase denotes encoding phase
113+
EncodingPhase = 3
114+
)
115+
104116
// RuleMode represents the rule mode
105117
type RuleMode = int
106118

@@ -138,25 +150,35 @@ func ParseMode(mode string) (RuleMode, bool) {
138150
type RuleSet struct {
139151
MigrationRules []Rule `json:"migrationRules,omitempty"`
140152
DomainRules []Rule `json:"domainRules,omitempty"`
153+
EncodingRules []Rule `json:"encodingRules,omitempty"`
141154
}
142155

143156
// HasRules checks if the ruleset has rules for the given mode
144-
func (r *RuleSet) HasRules(mode RuleMode) bool {
157+
func (r *RuleSet) HasRules(phase RulePhase, mode RuleMode) bool {
158+
var rules []Rule
159+
switch phase {
160+
case MigrationPhase:
161+
rules = r.MigrationRules
162+
case DomainPhase:
163+
rules = r.DomainRules
164+
case EncodingPhase:
165+
rules = r.EncodingRules
166+
}
145167
switch mode {
146168
case Upgrade, Downgrade:
147-
return r.hasRules(r.MigrationRules, func(ruleMode RuleMode) bool {
169+
return r.hasRules(rules, func(ruleMode RuleMode) bool {
148170
return ruleMode == mode || ruleMode == UpDown
149171
})
150172
case UpDown:
151-
return r.hasRules(r.MigrationRules, func(ruleMode RuleMode) bool {
173+
return r.hasRules(rules, func(ruleMode RuleMode) bool {
152174
return ruleMode == mode
153175
})
154176
case Write, Read:
155-
return r.hasRules(r.DomainRules, func(ruleMode RuleMode) bool {
177+
return r.hasRules(rules, func(ruleMode RuleMode) bool {
156178
return ruleMode == mode || ruleMode == WriteRead
157179
})
158180
case WriteRead:
159-
return r.hasRules(r.DomainRules, func(ruleMode RuleMode) bool {
181+
return r.hasRules(rules, func(ruleMode RuleMode) bool {
160182
return ruleMode == mode
161183
})
162184
}

0 commit comments

Comments
 (0)