Skip to content

Commit 01a3fe7

Browse files
authored
Add custom BulkGet method to Oracle Statestore (#3804)
Signed-off-by: Anton Troshin <[email protected]>
1 parent 397766a commit 01a3fe7

File tree

5 files changed

+209
-0
lines changed

5 files changed

+209
-0
lines changed

state/oracledatabase/dbaccess.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type dbAccess interface {
2626
Ping(ctx context.Context) error
2727
Set(ctx context.Context, req *state.SetRequest) error
2828
Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
29+
BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error)
2930
Delete(ctx context.Context, req *state.DeleteRequest) error
3031
ExecuteMulti(parentCtx context.Context, reqs []state.TransactionalStateOperation) error
3132
Close() error // io.Closer.

state/oracledatabase/oracledatabase.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ func (o *OracleDatabase) Get(ctx context.Context, req *state.GetRequest) (*state
7878
return o.dbaccess.Get(ctx, req)
7979
}
8080

81+
func (o *OracleDatabase) BulkGet(ctx context.Context, req []state.GetRequest, opts state.BulkGetOpts) ([]state.BulkGetResponse, error) {
82+
return o.dbaccess.BulkGet(ctx, req)
83+
}
84+
8185
// Set adds/updates an entity on store.
8286
func (o *OracleDatabase) Set(ctx context.Context, req *state.SetRequest) error {
8387
return o.dbaccess.Set(ctx, req)

state/oracledatabase/oracledatabase_integration_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ func TestOracleDatabaseIntegration(t *testing.T) {
120120
testBulkSetAndBulkDelete(t, ods)
121121
})
122122

123+
t.Run("Bulk get", func(t *testing.T) {
124+
testBulkGet(t, ods)
125+
})
126+
123127
t.Run("Update and delete with etag succeeds", func(t *testing.T) {
124128
updateAndDeleteWithEtagSucceeds(t, ods)
125129
})
@@ -647,6 +651,88 @@ func testBulkSetAndBulkDelete(t *testing.T, ods state.Store) {
647651
assert.False(t, storeItemExists(t, db, setReq[1].Key))
648652
}
649653

654+
func testBulkGet(t *testing.T, ods state.Store) {
655+
db := getDB(ods)
656+
657+
setReq := []state.SetRequest{
658+
{
659+
Key: randomKey(),
660+
Value: &fakeItem{Color: "red"},
661+
},
662+
{
663+
Key: randomKey(),
664+
Value: &fakeItem{Color: "blue"},
665+
},
666+
{
667+
Key: randomKey(),
668+
Value: &fakeItem{Color: "green"},
669+
},
670+
}
671+
672+
err := ods.BulkSet(t.Context(), setReq, state.BulkStoreOpts{})
673+
require.NoError(t, err)
674+
assert.True(t, storeItemExists(t, db, setReq[0].Key))
675+
assert.True(t, storeItemExists(t, db, setReq[1].Key))
676+
assert.True(t, storeItemExists(t, db, setReq[2].Key))
677+
678+
getReq := []state.GetRequest{
679+
{
680+
Key: setReq[0].Key,
681+
},
682+
{
683+
Key: setReq[1].Key,
684+
},
685+
{
686+
Key: setReq[2].Key,
687+
},
688+
{
689+
Key: randomKey(), // This key doesn't exist
690+
},
691+
}
692+
693+
responses, err := ods.BulkGet(t.Context(), getReq, state.BulkGetOpts{})
694+
require.NoError(t, err)
695+
require.Len(t, responses, 4)
696+
697+
// Verify the responses
698+
// First three items should exist
699+
for i := range 3 {
700+
assert.Equal(t, getReq[i].Key, responses[i].Key)
701+
assert.NotNil(t, responses[i].Data)
702+
assert.NotNil(t, responses[i].ETag)
703+
704+
// Verify the data
705+
var item fakeItem
706+
err = json.Unmarshal(responses[i].Data, &item)
707+
require.NoError(t, err)
708+
709+
// Check the color matches what we set
710+
originalItem := setReq[i].Value.(*fakeItem)
711+
assert.Equal(t, originalItem.Color, item.Color)
712+
}
713+
714+
// The fourth item should not exist (empty response)
715+
assert.Equal(t, getReq[3].Key, responses[3].Key)
716+
assert.Nil(t, responses[3].Data)
717+
assert.Nil(t, responses[3].ETag)
718+
719+
// Clean up
720+
deleteReq := []state.DeleteRequest{
721+
{
722+
Key: setReq[0].Key,
723+
},
724+
{
725+
Key: setReq[1].Key,
726+
},
727+
{
728+
Key: setReq[2].Key,
729+
},
730+
}
731+
732+
err = ods.BulkDelete(t.Context(), deleteReq, state.BulkStoreOpts{})
733+
require.NoError(t, err)
734+
}
735+
650736
// testInitConfiguration tests valid and invalid config settings.
651737
func testInitConfiguration(t *testing.T) {
652738
logger := logger.NewLogger("test")

state/oracledatabase/oracledatabase_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.G
6464
return nil, nil
6565
}
6666

67+
func (m *fakeDBaccess) BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) {
68+
return []state.BulkGetResponse{}, nil
69+
}
70+
6771
func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
6872
return nil
6973
}

state/oracledatabase/oracledatabaseaccess.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,120 @@ func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (
312312
}, nil
313313
}
314314

315+
func (o *oracleDatabaseAccess) BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) {
316+
if len(req) == 0 {
317+
return []state.BulkGetResponse{}, nil
318+
}
319+
320+
// Oracle supports the IN operator for bulk operations
321+
// Build the IN clause with bind variables
322+
// Oracle uses :1, :2, etc. for bind variables in the IN clause
323+
params := make([]any, len(req))
324+
bindVars := make([]string, len(req))
325+
for i, r := range req {
326+
if r.Key == "" {
327+
return nil, errors.New("missing key in bulk get operation")
328+
}
329+
params[i] = r.Key
330+
bindVars[i] = ":" + strconv.Itoa(i+1)
331+
}
332+
333+
inClause := strings.Join(bindVars, ",")
334+
// Concatenation is required for table name because sql.DB does not substitute parameters for table names.
335+
//nolint:gosec
336+
query := "SELECT key, value, binary_yn, etag, expiration_time FROM " + o.metadata.TableName + " WHERE key IN (" + inClause + ") AND (expiration_time IS NULL OR expiration_time > systimestamp)"
337+
338+
rows, err := o.db.QueryContext(ctx, query, params...)
339+
if err != nil {
340+
return nil, err
341+
}
342+
defer rows.Close()
343+
344+
var n int
345+
res := make([]state.BulkGetResponse, len(req))
346+
foundKeys := make(map[string]struct{}, len(req))
347+
348+
for rows.Next() {
349+
if n >= len(req) {
350+
// Sanity check to prevent panics, which should never happen
351+
return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req))
352+
}
353+
354+
var (
355+
key string
356+
value string
357+
binaryYN string
358+
etag string
359+
expireTime sql.NullTime
360+
)
361+
362+
err = rows.Scan(&key, &value, &binaryYN, &etag, &expireTime)
363+
if err != nil {
364+
res[n] = state.BulkGetResponse{
365+
Key: key,
366+
Error: err.Error(),
367+
}
368+
} else {
369+
response := state.BulkGetResponse{
370+
Key: key,
371+
ETag: &etag,
372+
}
373+
374+
if expireTime.Valid {
375+
response.Metadata = map[string]string{
376+
state.GetRespMetaKeyTTLExpireTime: expireTime.Time.UTC().Format(time.RFC3339),
377+
}
378+
}
379+
380+
if binaryYN == "Y" {
381+
var (
382+
s string
383+
data []byte
384+
)
385+
if err = json.Unmarshal([]byte(value), &s); err != nil {
386+
return nil, err
387+
}
388+
if data, err = base64.StdEncoding.DecodeString(s); err != nil {
389+
return nil, err
390+
}
391+
response.Data = data
392+
} else {
393+
response.Data = []byte(value)
394+
}
395+
396+
res[n] = response
397+
}
398+
399+
foundKeys[key] = struct{}{}
400+
n++
401+
}
402+
403+
if err = rows.Err(); err != nil {
404+
return nil, err
405+
}
406+
407+
// Populate missing keys with empty values
408+
// This is to ensure consistency with the other state stores that implement BulkGet as a loop over Get, and with the Get method
409+
if len(foundKeys) < len(req) {
410+
var ok bool
411+
for _, r := range req {
412+
_, ok = foundKeys[r.Key]
413+
if !ok {
414+
if n >= len(req) {
415+
// Sanity check to prevent panics, which should never happen
416+
return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req))
417+
}
418+
res[n] = state.BulkGetResponse{
419+
Key: r.Key,
420+
}
421+
n++
422+
}
423+
}
424+
}
425+
426+
return res[:n], nil
427+
}
428+
315429
// Delete removes an item from the state store.
316430
func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
317431
return o.doDelete(ctx, o.db, req)

0 commit comments

Comments
 (0)