Skip to content

Commit e4a8a3e

Browse files
distklocJoshVanLcicoyleyaron2
authored
Ensure proper handling of binary data in DynamoDB state store (#3658)
Signed-off-by: distkloc <[email protected]> Co-authored-by: Josh van Leeuwen <[email protected]> Co-authored-by: Cassie Coyle <[email protected]> Co-authored-by: Yaron Schneider <[email protected]>
1 parent 849f139 commit e4a8a3e

File tree

2 files changed

+128
-45
lines changed

2 files changed

+128
-45
lines changed

state/aws/dynamodb/dynamodb.go

Lines changed: 84 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ type dynamoDBMetadata struct {
6161
PartitionKey string `json:"partitionKey"`
6262
}
6363

64+
type putData struct {
65+
ConditionExpression *string
66+
ExpressionAttributeValues map[string]*dynamodb.AttributeValue
67+
Item map[string]*dynamodb.AttributeValue
68+
TableName *string
69+
}
70+
6471
const (
6572
defaultPartitionKeyName = "key"
6673
metadataPartitionKey = "partitionKey"
@@ -164,9 +171,9 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get
164171
return &state.GetResponse{}, nil
165172
}
166173

167-
var output string
168-
if err = dynamodbattribute.Unmarshal(result.Item["value"], &output); err != nil {
169-
return nil, err
174+
data, err := unmarshalValue(result.Item["value"])
175+
if err != nil {
176+
return nil, fmt.Errorf("dynamodb error: failed to unmarshal value for key %s: %w", req.Key, err)
170177
}
171178

172179
var metadata map[string]string
@@ -187,7 +194,7 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get
187194
}
188195

189196
resp := &state.GetResponse{
190-
Data: []byte(output),
197+
Data: data,
191198
Metadata: metadata,
192199
}
193200

@@ -205,29 +212,12 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get
205212

206213
// Set saves a dynamoDB item.
207214
func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
208-
item, err := d.getItemFromReq(req)
215+
pd, err := d.createPutData(req)
209216
if err != nil {
210217
return err
211218
}
212219

213-
input := &dynamodb.PutItemInput{
214-
Item: item,
215-
TableName: &d.table,
216-
}
217-
218-
if req.HasETag() {
219-
condExpr := "etag = :etag"
220-
input.ConditionExpression = &condExpr
221-
exprAttrValues := make(map[string]*dynamodb.AttributeValue)
222-
exprAttrValues[":etag"] = &dynamodb.AttributeValue{
223-
S: req.ETag,
224-
}
225-
input.ExpressionAttributeValues = exprAttrValues
226-
} else if req.Options.Concurrency == state.FirstWrite {
227-
condExpr := "attribute_not_exists(etag)"
228-
input.ConditionExpression = &condExpr
229-
}
230-
_, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, input)
220+
_, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, pd.ToPutItemInput())
231221
if err != nil && req.HasETag() {
232222
switch cErr := err.(type) {
233223
case *dynamodb.ConditionalCheckFailedException:
@@ -292,9 +282,55 @@ func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata
292282
return &m, err
293283
}
294284

295-
// getItemFromReq converts a dapr state.SetRequest into an dynamodb item
296-
func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb.AttributeValue, error) {
297-
value, err := d.marshalToString(req.Value)
285+
// createPutData creates a DynamoDB put request data from a SetRequest.
286+
func (d *StateStore) createPutData(req *state.SetRequest) (putData, error) {
287+
item, err := d.createItem(req)
288+
if err != nil {
289+
return putData{}, err
290+
}
291+
292+
pd := putData{
293+
Item: item,
294+
TableName: ptr.Of(d.table),
295+
}
296+
297+
if req.HasETag() {
298+
condExpr := "etag = :etag"
299+
pd.ConditionExpression = &condExpr
300+
exprAttrValues := make(map[string]*dynamodb.AttributeValue)
301+
exprAttrValues[":etag"] = &dynamodb.AttributeValue{
302+
S: req.ETag,
303+
}
304+
pd.ExpressionAttributeValues = exprAttrValues
305+
} else if req.Options.Concurrency == state.FirstWrite {
306+
condExpr := "attribute_not_exists(etag)"
307+
pd.ConditionExpression = &condExpr
308+
}
309+
310+
return pd, nil
311+
}
312+
313+
func (d putData) ToPutItemInput() *dynamodb.PutItemInput {
314+
return &dynamodb.PutItemInput{
315+
ConditionExpression: d.ConditionExpression,
316+
ExpressionAttributeValues: d.ExpressionAttributeValues,
317+
Item: d.Item,
318+
TableName: d.TableName,
319+
}
320+
}
321+
322+
func (d putData) ToPut() *dynamodb.Put {
323+
return &dynamodb.Put{
324+
ConditionExpression: d.ConditionExpression,
325+
ExpressionAttributeValues: d.ExpressionAttributeValues,
326+
Item: d.Item,
327+
TableName: d.TableName,
328+
}
329+
}
330+
331+
// createItem creates a DynamoDB item from a SetRequest.
332+
func (d *StateStore) createItem(req *state.SetRequest) (map[string]*dynamodb.AttributeValue, error) {
333+
value, err := marshalValue(req.Value)
298334
if err != nil {
299335
return nil, fmt.Errorf("dynamodb error: failed to marshal value for key %s: %w", req.Key, err)
300336
}
@@ -313,9 +349,7 @@ func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb
313349
d.partitionKey: {
314350
S: ptr.Of(req.Key),
315351
},
316-
"value": {
317-
S: ptr.Of(value),
318-
},
352+
"value": value,
319353
"etag": {
320354
S: ptr.Of(strconv.FormatUint(newEtag, 16)),
321355
},
@@ -340,12 +374,27 @@ func getRand64() (uint64, error) {
340374
return binary.LittleEndian.Uint64(randBuf), nil
341375
}
342376

343-
func (d *StateStore) marshalToString(v interface{}) (string, error) {
344-
if buf, ok := v.([]byte); ok {
345-
return string(buf), nil
377+
func marshalValue(v interface{}) (*dynamodb.AttributeValue, error) {
378+
if bt, ok := v.([]byte); ok {
379+
return &dynamodb.AttributeValue{B: bt}, nil
346380
}
347381

348-
return jsoniterator.ConfigFastest.MarshalToString(v)
382+
str, err := jsoniterator.ConfigFastest.MarshalToString(v)
383+
if err != nil {
384+
return nil, err
385+
}
386+
return &dynamodb.AttributeValue{S: ptr.Of(str)}, nil
387+
}
388+
389+
func unmarshalValue(value *dynamodb.AttributeValue) ([]byte, error) {
390+
if value == nil {
391+
return []byte(nil), nil
392+
}
393+
394+
if value.B != nil {
395+
return value.B, nil
396+
}
397+
return []byte(*value.S), nil
349398
}
350399

351400
// Parse and process ttlInSeconds.
@@ -404,21 +453,11 @@ func (d *StateStore) Multi(ctx context.Context, request *state.TransactionalStat
404453
twi := &dynamodb.TransactWriteItem{}
405454
switch req := o.(type) {
406455
case state.SetRequest:
407-
value, err := d.marshalToString(req.Value)
456+
pd, err := d.createPutData(&req)
408457
if err != nil {
409458
return fmt.Errorf("dynamodb error: failed to marshal value for key %s: %w", req.Key, err)
410459
}
411-
twi.Put = &dynamodb.Put{
412-
TableName: ptr.Of(d.table),
413-
Item: map[string]*dynamodb.AttributeValue{
414-
d.partitionKey: {
415-
S: ptr.Of(req.Key),
416-
},
417-
"value": {
418-
S: ptr.Of(value),
419-
},
420-
},
421-
}
460+
twi.Put = pd.ToPut()
422461

423462
case state.DeleteRequest:
424463
twi.Delete = &dynamodb.Delete{

state/aws/dynamodb/dynamodb_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,50 @@ func TestSet(t *testing.T) {
452452
require.NoError(t, err)
453453
})
454454

455+
t.Run("Successfully set item with binary value", func(t *testing.T) {
456+
mockedDB := &awsAuth.MockDynamoDB{
457+
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
458+
assert.Equal(t, dynamodb.AttributeValue{
459+
S: aws.String("key"),
460+
}, *input.Item["key"])
461+
assert.Equal(t, dynamodb.AttributeValue{
462+
B: []byte("value"),
463+
}, *input.Item["value"])
464+
assert.Len(t, input.Item, 3)
465+
466+
return &dynamodb.PutItemOutput{
467+
Attributes: map[string]*dynamodb.AttributeValue{
468+
"key": {
469+
S: aws.String("value"),
470+
},
471+
},
472+
}, nil
473+
},
474+
}
475+
476+
dynamo := awsAuth.DynamoDBClients{
477+
DynamoDB: mockedDB,
478+
}
479+
480+
mockedClients := awsAuth.Clients{
481+
Dynamo: &dynamo,
482+
}
483+
484+
mockAuthProvider := &awsAuth.StaticAuth{}
485+
mockAuthProvider.WithMockClients(&mockedClients)
486+
s := StateStore{
487+
authProvider: mockAuthProvider,
488+
partitionKey: defaultPartitionKeyName,
489+
}
490+
491+
req := &state.SetRequest{
492+
Key: "key",
493+
Value: []byte("value"),
494+
}
495+
err := s.Set(t.Context(), req)
496+
require.NoError(t, err)
497+
})
498+
455499
t.Run("Successfully set item with matching etag", func(t *testing.T) {
456500
mockedDB := &awsAuth.MockDynamoDB{
457501
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {

0 commit comments

Comments
 (0)