Skip to content

Commit b940982

Browse files
Thomas StrombergThomas Stromberg
authored andcommitted
make mock race-proof, add tests, fix lint
1 parent 9765bc5 commit b940982

File tree

4 files changed

+159
-30
lines changed

4 files changed

+159
-30
lines changed

datastore.go

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// It uses only the Go standard library and makes direct REST API calls
44
// to the Datastore API. Authentication is handled via the GCP metadata
55
// server when running on GCP, or via Application Default Credentials.
6+
//
7+
//nolint:revive // Public structs required for API compatibility with cloud.google.com/go/datastore
68
package ds9
79

810
import (
@@ -21,6 +23,7 @@ import (
2123
"reflect"
2224
"strconv"
2325
"strings"
26+
"sync/atomic"
2427
"testing"
2528
"time"
2629

@@ -40,8 +43,9 @@ var (
4043
// ErrNoSuchEntity is returned when an entity is not found.
4144
ErrNoSuchEntity = errors.New("datastore: no such entity")
4245

43-
// Package-level variable for easier testing.
44-
apiURL = "https://datastore.googleapis.com/v1"
46+
// atomicAPIURL stores the API URL for thread-safe access.
47+
// Use getAPIURL() to read and setAPIURL() to write.
48+
atomicAPIURL atomic.Pointer[string]
4549

4650
httpClient = &http.Client{
4751
Timeout: defaultTimeout,
@@ -62,6 +66,22 @@ var (
6266
}
6367
)
6468

69+
//nolint:gochecknoinits // Required for thread-safe initialization of atomic pointer
70+
func init() {
71+
defaultURL := "https://datastore.googleapis.com/v1"
72+
atomicAPIURL.Store(&defaultURL)
73+
}
74+
75+
// getAPIURL returns the current API URL in a thread-safe manner.
76+
func getAPIURL() string {
77+
return *atomicAPIURL.Load()
78+
}
79+
80+
// setAPIURL sets the API URL in a thread-safe manner.
81+
func setAPIURL(url string) {
82+
atomicAPIURL.Store(&url)
83+
}
84+
6585
// SetTestURLs configures custom metadata and API URLs for testing.
6686
// This is intended for use by testing packages like ds9mock.
6787
// Returns a function that restores the original URLs.
@@ -74,11 +94,11 @@ var (
7494
// defer restore()
7595
func SetTestURLs(metadata, api string) (restore func()) {
7696
// Auth package will log warning if called outside test environment
77-
oldAPI := apiURL
78-
apiURL = api
97+
oldAPI := getAPIURL()
98+
setAPIURL(api)
7999
restoreAuth := auth.SetMetadataURL(metadata)
80100
return func() {
81-
apiURL = oldAPI
101+
setAPIURL(oldAPI)
82102
restoreAuth()
83103
}
84104
}
@@ -88,6 +108,7 @@ type Client struct {
88108
logger *slog.Logger
89109
projectID string
90110
databaseID string
111+
baseURL string // API base URL, defaults to production but can be overridden for testing
91112
}
92113

93114
// NewClient creates a new Datastore client.
@@ -122,6 +143,7 @@ func NewClientWithDatabase(ctx context.Context, projID, dbID string) (*Client, e
122143
return &Client{
123144
projectID: projID,
124145
databaseID: dbID,
146+
baseURL: getAPIURL(),
125147
logger: logger,
126148
}, nil
127149
}
@@ -288,6 +310,8 @@ func DecodeCursor(s string) (Cursor, error) {
288310

289311
// Iterator is an iterator for query results.
290312
// API compatible with cloud.google.com/go/datastore.
313+
//
314+
//nolint:govet // Field alignment optimized for API compatibility over memory layout
291315
type Iterator struct {
292316
ctx context.Context //nolint:containedctx // Required for API compatibility with cloud.google.com/go/datastore
293317
client *Client
@@ -375,7 +399,7 @@ func (it *Iterator) fetch() error {
375399
}
376400

377401
// URL-encode project ID to prevent injection attacks
378-
reqURL := fmt.Sprintf("%s/projects/%s:runQuery", apiURL, neturl.PathEscape(it.client.projectID))
402+
reqURL := fmt.Sprintf("%s/projects/%s:runQuery", it.client.baseURL, neturl.PathEscape(it.client.projectID))
379403
body, err := doRequest(it.ctx, it.client.logger, reqURL, jsonData, token, it.client.projectID, it.client.databaseID)
380404
if err != nil {
381405
return err
@@ -567,7 +591,7 @@ func (c *Client) Get(ctx context.Context, key *Key, dst any) error {
567591
}
568592

569593
// URL-encode project ID to prevent injection attacks
570-
reqURL := fmt.Sprintf("%s/projects/%s:lookup", apiURL, neturl.PathEscape(c.projectID))
594+
reqURL := fmt.Sprintf("%s/projects/%s:lookup", c.baseURL, neturl.PathEscape(c.projectID))
571595
body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID)
572596
if err != nil {
573597
c.logger.ErrorContext(ctx, "lookup request failed", "error", err, "kind", key.Kind)
@@ -632,7 +656,7 @@ func (c *Client) Put(ctx context.Context, key *Key, src any) (*Key, error) {
632656
}
633657

634658
// URL-encode project ID to prevent injection attacks
635-
reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID))
659+
reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID))
636660
if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil {
637661
c.logger.ErrorContext(ctx, "commit request failed", "error", err, "kind", key.Kind)
638662
return nil, err
@@ -672,7 +696,7 @@ func (c *Client) Delete(ctx context.Context, key *Key) error {
672696
}
673697

674698
// URL-encode project ID to prevent injection attacks
675-
reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID))
699+
reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID))
676700
if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil {
677701
c.logger.ErrorContext(ctx, "delete request failed", "error", err, "kind", key.Kind)
678702
return err
@@ -724,7 +748,7 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error {
724748
}
725749

726750
// URL-encode project ID to prevent injection attacks
727-
reqURL := fmt.Sprintf("%s/projects/%s:lookup", apiURL, neturl.PathEscape(c.projectID))
751+
reqURL := fmt.Sprintf("%s/projects/%s:lookup", c.baseURL, neturl.PathEscape(c.projectID))
728752
body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID)
729753
if err != nil {
730754
c.logger.ErrorContext(ctx, "lookup request failed", "error", err)
@@ -841,7 +865,7 @@ func (c *Client) PutMulti(ctx context.Context, keys []*Key, src any) ([]*Key, er
841865
}
842866

843867
// URL-encode project ID to prevent injection attacks
844-
reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID))
868+
reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID))
845869
if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil {
846870
c.logger.ErrorContext(ctx, "commit request failed", "error", err)
847871
return nil, err
@@ -895,7 +919,7 @@ func (c *Client) DeleteMulti(ctx context.Context, keys []*Key) error {
895919
}
896920

897921
// URL-encode project ID to prevent injection attacks
898-
reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID))
922+
reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID))
899923
if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil {
900924
c.logger.ErrorContext(ctx, "delete request failed", "error", err)
901925
return err
@@ -985,7 +1009,7 @@ func (c *Client) AllocateIDs(ctx context.Context, keys []*Key) ([]*Key, error) {
9851009
}
9861010

9871011
// URL-encode project ID to prevent injection attacks
988-
reqURL := fmt.Sprintf("%s/projects/%s:allocateIds", apiURL, neturl.PathEscape(c.projectID))
1012+
reqURL := fmt.Sprintf("%s/projects/%s:allocateIds", c.baseURL, neturl.PathEscape(c.projectID))
9891013
body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID)
9901014
if err != nil {
9911015
c.logger.ErrorContext(ctx, "allocateIds request failed", "error", err)
@@ -1172,18 +1196,18 @@ func encodeValue(v any) (any, error) {
11721196
if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array {
11731197
length := rv.Len()
11741198
values := make([]map[string]any, length)
1175-
for i := 0; i < length; i++ {
1199+
for i := range length {
11761200
elem := rv.Index(i).Interface()
11771201
encodedElem, err := encodeValue(elem)
11781202
if err != nil {
11791203
return nil, fmt.Errorf("failed to encode array element %d: %w", i, err)
11801204
}
11811205
// encodedElem is already a map[string]any with the type wrapper
1182-
if m, ok := encodedElem.(map[string]any); ok {
1183-
values[i] = m
1184-
} else {
1206+
m, ok := encodedElem.(map[string]any)
1207+
if !ok {
11851208
return nil, fmt.Errorf("unexpected encoded value type for element %d", i)
11861209
}
1210+
values[i] = m
11871211
}
11881212
return map[string]any{"arrayValue": map[string]any{"values": values}}, nil
11891213
}
@@ -1691,7 +1715,7 @@ func (c *Client) AllKeys(ctx context.Context, q *Query) ([]*Key, error) {
16911715
}
16921716

16931717
// URL-encode project ID to prevent injection attacks
1694-
reqURL := fmt.Sprintf("%s/projects/%s:runQuery", apiURL, neturl.PathEscape(c.projectID))
1718+
reqURL := fmt.Sprintf("%s/projects/%s:runQuery", c.baseURL, neturl.PathEscape(c.projectID))
16951719
body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID)
16961720
if err != nil {
16971721
c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", q.kind)
@@ -1752,7 +1776,7 @@ func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, err
17521776
}
17531777

17541778
// URL-encode project ID to prevent injection attacks
1755-
reqURL := fmt.Sprintf("%s/projects/%s:runQuery", apiURL, neturl.PathEscape(c.projectID))
1779+
reqURL := fmt.Sprintf("%s/projects/%s:runQuery", c.baseURL, neturl.PathEscape(c.projectID))
17561780
body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID)
17571781
if err != nil {
17581782
c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", query.kind)
@@ -1846,7 +1870,7 @@ func (c *Client) Count(ctx context.Context, q *Query) (int, error) {
18461870
}
18471871

18481872
// URL-encode project ID to prevent injection attacks
1849-
reqURL := fmt.Sprintf("%s/projects/%s:runAggregationQuery", apiURL, neturl.PathEscape(c.projectID))
1873+
reqURL := fmt.Sprintf("%s/projects/%s:runAggregationQuery", c.baseURL, neturl.PathEscape(c.projectID))
18501874
body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID)
18511875
if err != nil {
18521876
c.logger.ErrorContext(ctx, "count query failed", "error", err, "kind", q.kind)
@@ -2052,7 +2076,7 @@ func (c *Client) Mutate(ctx context.Context, muts ...*Mutation) ([]*Key, error)
20522076
}
20532077

20542078
// URL-encode project ID to prevent injection attacks
2055-
reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID))
2079+
reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID))
20562080
body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID)
20572081
if err != nil {
20582082
c.logger.ErrorContext(ctx, "mutate request failed", "error", err)
@@ -2227,7 +2251,7 @@ func (c *Client) NewTransaction(ctx context.Context, opts ...TransactionOption)
22272251
}
22282252

22292253
// URL-encode project ID to prevent injection attacks
2230-
reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", apiURL, neturl.PathEscape(c.projectID))
2254+
reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", c.baseURL, neturl.PathEscape(c.projectID))
22312255
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData))
22322256
if err != nil {
22332257
return nil, err
@@ -2322,7 +2346,7 @@ func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) erro
23222346
}
23232347

23242348
// URL-encode project ID to prevent injection attacks
2325-
reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", apiURL, neturl.PathEscape(c.projectID))
2349+
reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", c.baseURL, neturl.PathEscape(c.projectID))
23262350
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData))
23272351
if err != nil {
23282352
return nil, err
@@ -2447,7 +2471,7 @@ func (tx *Transaction) Get(key *Key, dst any) error {
24472471
}
24482472

24492473
// URL-encode project ID to prevent injection attacks
2450-
reqURL := fmt.Sprintf("%s/projects/%s:lookup", apiURL, neturl.PathEscape(tx.client.projectID))
2474+
reqURL := fmt.Sprintf("%s/projects/%s:lookup", tx.client.baseURL, neturl.PathEscape(tx.client.projectID))
24512475
req, err := http.NewRequestWithContext(tx.ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData))
24522476
if err != nil {
24532477
return err
@@ -2734,7 +2758,7 @@ func (tx *Transaction) doCommit(ctx context.Context, token string) error {
27342758
}
27352759

27362760
// URL-encode project ID to prevent injection attacks
2737-
reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(tx.client.projectID))
2761+
reqURL := fmt.Sprintf("%s/projects/%s:commit", tx.client.baseURL, neturl.PathEscape(tx.client.projectID))
27382762
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData))
27392763
if err != nil {
27402764
return err

datastore_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8016,8 +8016,8 @@ func TestMutate(t *testing.T) {
80168016
t.Fatalf("Mutate with no mutations failed: %v", err)
80178017
}
80188018

8019-
if keys != nil && len(keys) != 0 {
8020-
t.Errorf("Expected nil or empty keys, got %d", len(keys))
8019+
if len(keys) != 0 {
8020+
t.Errorf("Expected empty keys, got %d", len(keys))
80218021
}
80228022
})
80238023
}

ds9mock/mock.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ import (
3232
const metadataFlavor = "Google"
3333

3434
// Store holds the in-memory entity storage.
35+
//
36+
//nolint:govet // Field alignment not optimized to maintain readability
3537
type Store struct {
3638
mu sync.RWMutex
3739
entities map[string]map[string]any
@@ -231,6 +233,8 @@ func (s *Store) handleLookup(w http.ResponseWriter, r *http.Request) {
231233
}
232234

233235
// handleCommit handles commit (put/delete) requests.
236+
//
237+
//nolint:gocognit,maintidx // Complex logic required for handling multiple mutation types
234238
func (s *Store) handleCommit(w http.ResponseWriter, r *http.Request) {
235239
var req struct {
236240
Mode string `json:"mode"`
@@ -553,7 +557,7 @@ func handleBeginTransaction(w http.ResponseWriter, r *http.Request) {
553557
}
554558

555559
// handleAllocateIDs handles :allocateIds requests.
556-
func (s *Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) {
560+
func (*Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) {
557561
var req struct {
558562
DatabaseID string `json:"databaseId"`
559563
Keys []map[string]any `json:"keys"`
@@ -619,6 +623,8 @@ func (s *Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) {
619623
}
620624

621625
// matchesFilter checks if an entity matches a filter.
626+
//
627+
//nolint:gocognit,nestif // Complex logic required for proper filter evaluation with multiple types and operators
622628
func matchesFilter(entity map[string]any, filterMap map[string]any) bool {
623629
// Handle propertyFilter
624630
if propFilter, ok := filterMap["propertyFilter"].(map[string]any); ok {
@@ -702,6 +708,8 @@ func matchesFilter(entity map[string]any, filterMap map[string]any) bool {
702708
return ev <= fv
703709
}
704710
}
711+
default:
712+
return false
705713
}
706714
}
707715

@@ -716,7 +724,8 @@ func matchesFilter(entity map[string]any, filterMap map[string]any) bool {
716724
return true
717725
}
718726

719-
if op == "AND" {
727+
switch op {
728+
case "AND":
720729
for _, f := range filters {
721730
if fm, ok := f.(map[string]any); ok {
722731
if !matchesFilter(entity, fm) {
@@ -725,7 +734,7 @@ func matchesFilter(entity map[string]any, filterMap map[string]any) bool {
725734
}
726735
}
727736
return true
728-
} else if op == "OR" {
737+
case "OR":
729738
for _, f := range filters {
730739
if fm, ok := f.(map[string]any); ok {
731740
if matchesFilter(entity, fm) {
@@ -734,6 +743,8 @@ func matchesFilter(entity map[string]any, filterMap map[string]any) bool {
734743
}
735744
}
736745
return false
746+
default:
747+
return true
737748
}
738749
}
739750

0 commit comments

Comments
 (0)