Skip to content

Commit 0c48ced

Browse files
elena-kolevskaItalyPaleAleyaron2
authored
state.dynamodb: validate AWS connection (dapr#3285)
Signed-off-by: Elena Kolevska <[email protected]> Signed-off-by: Elena Kolevska <[email protected]> Co-authored-by: Alessandro (Ale) Segala <[email protected]> Co-authored-by: Yaron Schneider <[email protected]>
1 parent c0a21a0 commit 0c48ced

File tree

2 files changed

+81
-21
lines changed

2 files changed

+81
-21
lines changed

state/aws/dynamodb/dynamodb.go

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ import (
2323
"strconv"
2424
"time"
2525

26-
"github.com/aws/aws-sdk-go/aws"
26+
"github.com/google/uuid"
27+
28+
"github.com/dapr/kit/ptr"
29+
2730
"github.com/aws/aws-sdk-go/service/dynamodb"
2831
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
2932
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
@@ -74,25 +77,58 @@ func NewDynamoDBStateStore(_ logger.Logger) state.Store {
7477
}
7578

7679
// Init does metadata and connection parsing.
77-
func (d *StateStore) Init(_ context.Context, metadata state.Metadata) error {
80+
func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error {
7881
meta, err := d.getDynamoDBMetadata(metadata)
7982
if err != nil {
8083
return err
8184
}
8285

83-
client, err := d.getClient(meta)
84-
if err != nil {
85-
return err
86+
// We have this check because we need to set the client to a mock in tests
87+
if d.client == nil {
88+
d.client, err = d.getClient(meta)
89+
if err != nil {
90+
return err
91+
}
8692
}
87-
88-
d.client = client
8993
d.table = meta.Table
9094
d.ttlAttributeName = meta.TTLAttributeName
9195
d.partitionKey = meta.PartitionKey
9296

97+
if err := d.validateTableAccess(ctx); err != nil {
98+
return fmt.Errorf("error validating DynamoDB table '%s' access: %w", d.table, err)
99+
}
100+
93101
return nil
94102
}
95103

104+
// validateConnection runs a dummy Get operation to validate the connection credentials,
105+
// as well as validating that the table exists, and we have access to it
106+
func (d *StateStore) validateTableAccess(ctx context.Context) error {
107+
var tableName string
108+
if random, err := uuid.NewRandom(); err == nil {
109+
tableName = random.String()
110+
} else {
111+
// We would get to this block if the entropy pool is empty.
112+
// We don't want to fail initialising Dapr because of it though,
113+
// since it's a dummy table that is only needed to check access, anyway
114+
// So we'll just use a hardcoded table name
115+
tableName = "dapr-test-table"
116+
}
117+
118+
input := &dynamodb.GetItemInput{
119+
ConsistentRead: ptr.Of(false),
120+
TableName: ptr.Of(d.table),
121+
Key: map[string]*dynamodb.AttributeValue{
122+
d.partitionKey: {
123+
S: ptr.Of(tableName),
124+
},
125+
},
126+
}
127+
128+
_, err := d.client.GetItemWithContext(ctx, input)
129+
return err
130+
}
131+
96132
// Features returns the features available in this state store.
97133
func (d *StateStore) Features() []state.Feature {
98134
// TTLs are enabled only if ttlAttributeName is set
@@ -113,11 +149,11 @@ func (d *StateStore) Features() []state.Feature {
113149
// Get retrieves a dynamoDB item.
114150
func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
115151
input := &dynamodb.GetItemInput{
116-
ConsistentRead: aws.Bool(req.Options.Consistency == state.Strong),
117-
TableName: aws.String(d.table),
152+
ConsistentRead: ptr.Of(req.Options.Consistency == state.Strong),
153+
TableName: ptr.Of(d.table),
118154
Key: map[string]*dynamodb.AttributeValue{
119155
d.partitionKey: {
120-
S: aws.String(req.Key),
156+
S: ptr.Of(req.Key),
121157
},
122158
},
123159
}
@@ -211,10 +247,10 @@ func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error
211247
input := &dynamodb.DeleteItemInput{
212248
Key: map[string]*dynamodb.AttributeValue{
213249
d.partitionKey: {
214-
S: aws.String(req.Key),
250+
S: ptr.Of(req.Key),
215251
},
216252
},
217-
TableName: aws.String(d.table),
253+
TableName: ptr.Of(d.table),
218254
}
219255

220256
if req.HasETag() {
@@ -283,19 +319,19 @@ func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb
283319

284320
item := map[string]*dynamodb.AttributeValue{
285321
d.partitionKey: {
286-
S: aws.String(req.Key),
322+
S: ptr.Of(req.Key),
287323
},
288324
"value": {
289-
S: aws.String(value),
325+
S: ptr.Of(value),
290326
},
291327
"etag": {
292-
S: aws.String(strconv.FormatUint(newEtag, 16)),
328+
S: ptr.Of(strconv.FormatUint(newEtag, 16)),
293329
},
294330
}
295331

296332
if ttl != nil {
297333
item[d.ttlAttributeName] = &dynamodb.AttributeValue{
298-
N: aws.String(strconv.FormatInt(*ttl, 10)),
334+
N: ptr.Of(strconv.FormatInt(*ttl, 10)),
299335
}
300336
}
301337

@@ -381,23 +417,23 @@ func (d *StateStore) Multi(ctx context.Context, request *state.TransactionalStat
381417
return fmt.Errorf("dynamodb error: failed to marshal value for key %s: %w", req.Key, err)
382418
}
383419
twi.Put = &dynamodb.Put{
384-
TableName: aws.String(d.table),
420+
TableName: ptr.Of(d.table),
385421
Item: map[string]*dynamodb.AttributeValue{
386422
d.partitionKey: {
387-
S: aws.String(req.Key),
423+
S: ptr.Of(req.Key),
388424
},
389425
"value": {
390-
S: aws.String(value),
426+
S: ptr.Of(value),
391427
},
392428
},
393429
}
394430

395431
case state.DeleteRequest:
396432
twi.Delete = &dynamodb.Delete{
397-
TableName: aws.String(d.table),
433+
TableName: ptr.Of(d.table),
398434
Key: map[string]*dynamodb.AttributeValue{
399435
d.partitionKey: {
400-
S: aws.String(req.Key),
436+
S: ptr.Of(req.Key),
401437
},
402438
},
403439
}

state/aws/dynamodb/dynamodb_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package dynamodb
1717

1818
import (
1919
"context"
20+
"errors"
2021
"fmt"
2122
"testing"
2223
"time"
@@ -76,6 +77,12 @@ func TestInit(t *testing.T) {
7677
m := state.Metadata{}
7778
s := &StateStore{
7879
partitionKey: defaultPartitionKeyName,
80+
client: &mockedDynamoDB{
81+
// We're adding this so we can pass the connection check on Init
82+
GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) {
83+
return nil, nil
84+
},
85+
},
7986
}
8087

8188
t.Run("NewDynamoDBStateStore Default Partition Key", func(t *testing.T) {
@@ -124,6 +131,23 @@ func TestInit(t *testing.T) {
124131
require.NoError(t, err)
125132
assert.Equal(t, s.partitionKey, pkey)
126133
})
134+
135+
t.Run("Init with bad table name or permissions", func(t *testing.T) {
136+
m.Properties = map[string]string{
137+
"Table": "does-not-exist",
138+
"Region": "eu-west-1",
139+
}
140+
141+
s.client = &mockedDynamoDB{
142+
GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) {
143+
return nil, errors.New("Requested resource not found")
144+
},
145+
}
146+
147+
err := s.Init(context.Background(), m)
148+
require.Error(t, err)
149+
require.EqualError(t, err, "error validating DynamoDB table 'does-not-exist' access: Requested resource not found")
150+
})
127151
}
128152

129153
func TestGet(t *testing.T) {

0 commit comments

Comments
 (0)