Skip to content

Commit 92c40db

Browse files
Move middleware to its package
1 parent da082a4 commit 92c40db

File tree

3 files changed

+261
-278
lines changed

3 files changed

+261
-278
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package dbesdkmiddleware
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"github.com/aws/aws-database-encryption-sdk-dynamodb/awscryptographydbencryptionsdkdynamodbsmithygeneratedtypes"
8+
"github.com/aws/aws-database-encryption-sdk-dynamodb/awscryptographydbencryptionsdkdynamodbtransformssmithygenerated"
9+
"github.com/aws/aws-database-encryption-sdk-dynamodb/awscryptographydbencryptionsdkdynamodbtransformssmithygeneratedtypes"
10+
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
11+
"github.com/aws/smithy-go/middleware"
12+
)
13+
14+
type DBEsdkMiddleware struct {
15+
client *awscryptographydbencryptionsdkdynamodbtransformssmithygenerated.Client
16+
}
17+
18+
func NewDBEsdkMiddleware(config awscryptographydbencryptionsdkdynamodbsmithygeneratedtypes.DynamoDbTablesEncryptionConfig) (*DBEsdkMiddleware, error) {
19+
client, err := awscryptographydbencryptionsdkdynamodbtransformssmithygenerated.NewClient(config)
20+
if err != nil {
21+
return nil, err
22+
}
23+
return &DBEsdkMiddleware{
24+
client: client,
25+
}, nil
26+
}
27+
28+
func (m DBEsdkMiddleware) CreateMiddleware() func(options *dynamodb.Options) {
29+
return func(options *dynamodb.Options) {
30+
options.APIOptions = append(options.APIOptions, func(stack *middleware.Stack) error {
31+
// Add request interceptor at the beginning of Initialize step
32+
requestIntercetor := m.createRequestInterceptor()
33+
if err := stack.Initialize.Add(requestIntercetor, middleware.Before); err != nil {
34+
return err
35+
}
36+
// Add response interceptor at the end of Finalize step
37+
return stack.Finalize.Add(m.createResponseInterceptor(), middleware.After)
38+
})
39+
}
40+
}
41+
42+
func (m DBEsdkMiddleware) createRequestInterceptor() middleware.InitializeMiddleware {
43+
return middleware.InitializeMiddlewareFunc("RequestInterceptor", func(
44+
ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler,
45+
) (
46+
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
47+
) {
48+
ctx = m.handleRequestInterception(ctx, in.Parameters)
49+
return next.HandleInitialize(ctx, in)
50+
})
51+
}
52+
53+
// handleRequestInterception handles the interception logic before the DynamoDB operation
54+
func (m DBEsdkMiddleware) handleRequestInterception(ctx context.Context, params interface{}) context.Context {
55+
switch v := params.(type) {
56+
case *dynamodb.PutItemInput:
57+
ctx = middleware.WithStackValue(ctx, "originalInput", *deepCopyPutItemInput(v))
58+
transformedRequest, err := m.client.PutItemInputTransform(context.TODO(), awscryptographydbencryptionsdkdynamodbtransformssmithygeneratedtypes.PutItemInputTransformInput{
59+
SdkInput: *v,
60+
})
61+
if err != nil {
62+
fmt.Println(err)
63+
}
64+
*v = transformedRequest.TransformedInput
65+
case *dynamodb.GetItemInput:
66+
ctx = middleware.WithStackValue(ctx, "originalInput", *deepCopyGetItemInput(v))
67+
transformedRequest, err := m.client.GetItemInputTransform(context.TODO(), awscryptographydbencryptionsdkdynamodbtransformssmithygeneratedtypes.GetItemInputTransformInput{
68+
SdkInput: *v,
69+
})
70+
if err != nil {
71+
fmt.Println(err)
72+
}
73+
*v = transformedRequest.TransformedInput
74+
// case *dynamodb.BatchExecuteStatementInput:
75+
// m.originalRequests["BatchExecuteStatementInput"] = *DeepCopyBatchExecuteStatementInput(v)
76+
}
77+
return ctx
78+
}
79+
80+
// createResponseInterceptor creates and returns the middleware interceptor for responses
81+
func (m DBEsdkMiddleware) createResponseInterceptor() middleware.FinalizeMiddleware {
82+
return middleware.FinalizeMiddlewareFunc("ResponseInterceptor", func(
83+
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
84+
) (
85+
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
86+
) {
87+
// First let the request complete
88+
result, metadata, err := next.HandleFinalize(ctx, in)
89+
if err != nil {
90+
return result, metadata, err
91+
}
92+
// Then intercept the response
93+
m.handleResponseInterception(ctx, result.Result)
94+
return result, metadata, err
95+
})
96+
}
97+
98+
// handleResponseInterception handles the interception logic after the DynamoDB operation
99+
func (m DBEsdkMiddleware) handleResponseInterception(ctx context.Context, response interface{}) {
100+
switch v := response.(type) {
101+
case *dynamodb.PutItemOutput:
102+
transformedRequest, err := m.client.PutItemOutputTransform(context.TODO(), awscryptographydbencryptionsdkdynamodbtransformssmithygeneratedtypes.PutItemOutputTransformInput{
103+
OriginalInput: middleware.GetStackValue(ctx, "originalInput").(dynamodb.PutItemInput),
104+
SdkOutput: *v,
105+
})
106+
if err != nil {
107+
fmt.Println(err)
108+
}
109+
*v = transformedRequest.TransformedOutput
110+
case *dynamodb.GetItemOutput:
111+
transformedRequest, err := m.client.GetItemOutputTransform(context.TODO(), awscryptographydbencryptionsdkdynamodbtransformssmithygeneratedtypes.GetItemOutputTransformInput{
112+
OriginalInput: middleware.GetStackValue(ctx, "originalInput").(dynamodb.GetItemInput),
113+
SdkOutput: *v,
114+
})
115+
if err != nil {
116+
fmt.Println(err)
117+
}
118+
*v = transformedRequest.TransformedOutput
119+
}
120+
}
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package dbesdkmiddleware
2+
3+
import (
4+
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
5+
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
6+
)
7+
8+
// deepCopyPutItemInput performs a deep copy of a PutItemInput struct.
9+
func deepCopyPutItemInput(input *dynamodb.PutItemInput) *dynamodb.PutItemInput {
10+
if input == nil {
11+
return nil
12+
}
13+
copyItem := make(map[string]types.AttributeValue, len(input.Item))
14+
for k, v := range input.Item {
15+
copyItem[k] = deepCopyAttributeValue(v)
16+
}
17+
copyExpected := make(map[string]types.ExpectedAttributeValue, len(input.Expected))
18+
for k, v := range input.Expected {
19+
copyExpected[k] = v
20+
}
21+
copyExprNames := make(map[string]string, len(input.ExpressionAttributeNames))
22+
for k, v := range input.ExpressionAttributeNames {
23+
copyExprNames[k] = v
24+
}
25+
copyExprValues := make(map[string]types.AttributeValue, len(input.ExpressionAttributeValues))
26+
for k, v := range input.ExpressionAttributeValues {
27+
copyExprValues[k] = deepCopyAttributeValue(v)
28+
}
29+
// Copying string pointers
30+
var tableName *string
31+
if input.TableName != nil {
32+
t := *input.TableName
33+
tableName = &t
34+
}
35+
var conditionExpression *string
36+
if input.ConditionExpression != nil {
37+
ce := *input.ConditionExpression
38+
conditionExpression = &ce
39+
}
40+
return &dynamodb.PutItemInput{
41+
Item: copyItem,
42+
TableName: tableName,
43+
ConditionExpression: conditionExpression,
44+
ConditionalOperator: input.ConditionalOperator,
45+
Expected: copyExpected,
46+
ExpressionAttributeNames: copyExprNames,
47+
ExpressionAttributeValues: copyExprValues,
48+
ReturnConsumedCapacity: input.ReturnConsumedCapacity,
49+
ReturnItemCollectionMetrics: input.ReturnItemCollectionMetrics,
50+
ReturnValues: input.ReturnValues,
51+
ReturnValuesOnConditionCheckFailure: input.ReturnValuesOnConditionCheckFailure,
52+
}
53+
}
54+
55+
// deepCopyGetItemInput performs a deep copy of a GetItemInput struct.
56+
func deepCopyGetItemInput(input *dynamodb.GetItemInput) *dynamodb.GetItemInput {
57+
if input == nil {
58+
return nil
59+
}
60+
copyKey := make(map[string]types.AttributeValue, len(input.Key))
61+
for k, v := range input.Key {
62+
copyKey[k] = deepCopyAttributeValue(v)
63+
}
64+
copyExprNames := make(map[string]string, len(input.ExpressionAttributeNames))
65+
for k, v := range input.ExpressionAttributeNames {
66+
copyExprNames[k] = v
67+
}
68+
copyAttributesToGet := make([]string, len(input.AttributesToGet))
69+
copy(copyAttributesToGet, input.AttributesToGet)
70+
var tableName *string
71+
if input.TableName != nil {
72+
t := *input.TableName
73+
tableName = &t
74+
}
75+
var projectionExpression *string
76+
if input.ProjectionExpression != nil {
77+
pe := *input.ProjectionExpression
78+
projectionExpression = &pe
79+
}
80+
var consistentRead *bool
81+
if input.ConsistentRead != nil {
82+
cr := *input.ConsistentRead
83+
consistentRead = &cr
84+
}
85+
return &dynamodb.GetItemInput{
86+
Key: copyKey,
87+
TableName: tableName,
88+
AttributesToGet: copyAttributesToGet,
89+
ConsistentRead: consistentRead,
90+
ExpressionAttributeNames: copyExprNames,
91+
ProjectionExpression: projectionExpression,
92+
ReturnConsumedCapacity: input.ReturnConsumedCapacity,
93+
}
94+
}
95+
96+
// deepCopyAttributeValue performs a deep copy of AttributeValue.
97+
func deepCopyAttributeValue(attr types.AttributeValue) types.AttributeValue {
98+
switch v := attr.(type) {
99+
case *types.AttributeValueMemberS:
100+
return &types.AttributeValueMemberS{Value: v.Value}
101+
case *types.AttributeValueMemberN:
102+
return &types.AttributeValueMemberN{Value: v.Value}
103+
case *types.AttributeValueMemberB:
104+
b := make([]byte, len(v.Value))
105+
copy(b, v.Value)
106+
return &types.AttributeValueMemberB{Value: b}
107+
case *types.AttributeValueMemberBOOL:
108+
return &types.AttributeValueMemberBOOL{Value: v.Value}
109+
case *types.AttributeValueMemberNULL:
110+
return &types.AttributeValueMemberNULL{Value: v.Value}
111+
case *types.AttributeValueMemberM:
112+
newMap := make(map[string]types.AttributeValue, len(v.Value))
113+
for key, value := range v.Value {
114+
newMap[key] = deepCopyAttributeValue(value)
115+
}
116+
return &types.AttributeValueMemberM{Value: newMap}
117+
case *types.AttributeValueMemberL:
118+
newList := make([]types.AttributeValue, len(v.Value))
119+
for i, value := range v.Value {
120+
newList[i] = deepCopyAttributeValue(value)
121+
}
122+
return &types.AttributeValueMemberL{Value: newList}
123+
case *types.AttributeValueMemberSS:
124+
newSS := make([]string, len(v.Value))
125+
copy(newSS, v.Value)
126+
return &types.AttributeValueMemberSS{Value: newSS}
127+
case *types.AttributeValueMemberNS:
128+
newNS := make([]string, len(v.Value))
129+
copy(newNS, v.Value)
130+
return &types.AttributeValueMemberNS{Value: newNS}
131+
case *types.AttributeValueMemberBS:
132+
newBS := make([][]byte, len(v.Value))
133+
for i, b := range v.Value {
134+
newBS[i] = make([]byte, len(b))
135+
copy(newBS[i], b)
136+
}
137+
return &types.AttributeValueMemberBS{Value: newBS}
138+
default:
139+
panic("Unknown AttributeValue type.")
140+
}
141+
}

0 commit comments

Comments
 (0)