Skip to content

Commit 3671259

Browse files
Thomas StrombergThomas Stromberg
authored andcommitted
Improve compatibility w/ official library
1 parent 021a1cf commit 3671259

File tree

5 files changed

+742
-40
lines changed

5 files changed

+742
-40
lines changed

datastore.go

Lines changed: 111 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,13 @@ func NewClientWithDatabase(ctx context.Context, projID, dbID string) (*Client, e
116116
}, nil
117117
}
118118

119+
// Close closes the client connection.
120+
// This is a no-op for ds9 since it uses a shared HTTP client with connection pooling,
121+
// but is provided for API compatibility with cloud.google.com/go/datastore.
122+
func (*Client) Close() error {
123+
return nil
124+
}
125+
119126
// Key represents a Datastore key.
120127
type Key struct {
121128
Parent *Key // Parent key for hierarchical keys
@@ -991,6 +998,94 @@ func (c *Client) AllKeys(ctx context.Context, q *Query) ([]*Key, error) {
991998
return keys, nil
992999
}
9931000

1001+
// GetAll retrieves all entities matching the query and stores them in dst.
1002+
// dst must be a pointer to a slice of structs.
1003+
// Returns the keys of the retrieved entities and any error.
1004+
// This matches the API of cloud.google.com/go/datastore.
1005+
func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, error) {
1006+
c.logger.DebugContext(ctx, "querying for entities", "kind", query.kind, "limit", query.limit)
1007+
1008+
token, err := auth.AccessToken(ctx)
1009+
if err != nil {
1010+
c.logger.ErrorContext(ctx, "failed to get access token", "error", err)
1011+
return nil, fmt.Errorf("failed to get access token: %w", err)
1012+
}
1013+
1014+
queryObj := map[string]any{
1015+
"kind": []map[string]any{{"name": query.kind}},
1016+
}
1017+
if query.limit > 0 {
1018+
queryObj["limit"] = query.limit
1019+
}
1020+
1021+
reqBody := map[string]any{"query": queryObj}
1022+
if c.databaseID != "" {
1023+
reqBody["databaseId"] = c.databaseID
1024+
}
1025+
1026+
jsonData, err := json.Marshal(reqBody)
1027+
if err != nil {
1028+
c.logger.ErrorContext(ctx, "failed to marshal request", "error", err)
1029+
return nil, fmt.Errorf("failed to marshal request: %w", err)
1030+
}
1031+
1032+
// URL-encode project ID to prevent injection attacks
1033+
reqURL := fmt.Sprintf("%s/projects/%s:runQuery", apiURL, neturl.PathEscape(c.projectID))
1034+
body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID)
1035+
if err != nil {
1036+
c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", query.kind)
1037+
return nil, err
1038+
}
1039+
1040+
var result struct {
1041+
Batch struct {
1042+
EntityResults []struct {
1043+
Entity map[string]any `json:"entity"`
1044+
} `json:"entityResults"`
1045+
} `json:"batch"`
1046+
}
1047+
1048+
if err := json.Unmarshal(body, &result); err != nil {
1049+
c.logger.ErrorContext(ctx, "failed to parse response", "error", err)
1050+
return nil, fmt.Errorf("failed to parse response: %w", err)
1051+
}
1052+
1053+
// Verify dst is a pointer to slice
1054+
v := reflect.ValueOf(dst)
1055+
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Slice {
1056+
return nil, errors.New("dst must be a pointer to slice")
1057+
}
1058+
1059+
sliceType := v.Elem().Type()
1060+
elemType := sliceType.Elem()
1061+
1062+
// Create new slice of correct size
1063+
slice := reflect.MakeSlice(sliceType, 0, len(result.Batch.EntityResults))
1064+
keys := make([]*Key, 0, len(result.Batch.EntityResults))
1065+
1066+
for _, er := range result.Batch.EntityResults {
1067+
// Extract key
1068+
key, err := keyFromJSON(er.Entity["key"])
1069+
if err != nil {
1070+
c.logger.ErrorContext(ctx, "failed to parse key from response", "error", err)
1071+
return nil, err
1072+
}
1073+
keys = append(keys, key)
1074+
1075+
// Decode entity
1076+
elem := reflect.New(elemType).Elem()
1077+
if err := decodeEntity(er.Entity, elem.Addr().Interface()); err != nil {
1078+
c.logger.ErrorContext(ctx, "failed to decode entity", "error", err)
1079+
return nil, err
1080+
}
1081+
slice = reflect.Append(slice, elem)
1082+
}
1083+
1084+
v.Elem().Set(slice)
1085+
c.logger.DebugContext(ctx, "query completed successfully", "kind", query.kind, "entities_found", len(keys))
1086+
return keys, nil
1087+
}
1088+
9941089
// keyFromJSON converts a JSON key representation to a Key.
9951090
func keyFromJSON(keyData any) (*Key, error) {
9961091
keyMap, ok := keyData.(map[string]any)
@@ -1031,6 +1126,10 @@ func keyFromJSON(keyData any) (*Key, error) {
10311126
return key, nil
10321127
}
10331128

1129+
// Commit represents the result of a committed transaction.
1130+
// This is provided for API compatibility with cloud.google.com/go/datastore.
1131+
type Commit struct{}
1132+
10341133
// Transaction represents a Datastore transaction.
10351134
// Note: This struct stores context for API compatibility with Google's official
10361135
// cloud.google.com/go/datastore library, which uses the same pattern.
@@ -1044,14 +1143,14 @@ type Transaction struct {
10441143
// RunInTransaction runs a function in a transaction.
10451144
// The function should use the transaction's Get and Put methods.
10461145
// API compatible with cloud.google.com/go/datastore.
1047-
func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) error) error {
1146+
func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) error) (*Commit, error) {
10481147
const maxTxRetries = 3
10491148
var lastErr error
10501149

10511150
for attempt := range maxTxRetries {
10521151
token, err := auth.AccessToken(ctx)
10531152
if err != nil {
1054-
return fmt.Errorf("failed to get access token: %w", err)
1153+
return nil, fmt.Errorf("failed to get access token: %w", err)
10551154
}
10561155

10571156
// Begin transaction
@@ -1062,14 +1161,14 @@ func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) erro
10621161

10631162
jsonData, err := json.Marshal(reqBody)
10641163
if err != nil {
1065-
return err
1164+
return nil, err
10661165
}
10671166

10681167
// URL-encode project ID to prevent injection attacks
10691168
reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", apiURL, neturl.PathEscape(c.projectID))
10701169
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData))
10711170
if err != nil {
1072-
return err
1171+
return nil, err
10731172
}
10741173

10751174
req.Header.Set("Authorization", "Bearer "+token)
@@ -1084,7 +1183,7 @@ func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) erro
10841183

10851184
resp, err := httpClient.Do(req)
10861185
if err != nil {
1087-
return err
1186+
return nil, err
10881187
}
10891188

10901189
body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize))
@@ -1093,19 +1192,19 @@ func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) erro
10931192
c.logger.Warn("failed to close response body", "error", closeErr)
10941193
}
10951194
if err != nil {
1096-
return err
1195+
return nil, err
10971196
}
10981197

10991198
if resp.StatusCode != http.StatusOK {
1100-
return fmt.Errorf("begin transaction failed with status %d: %s", resp.StatusCode, string(body))
1199+
return nil, fmt.Errorf("begin transaction failed with status %d: %s", resp.StatusCode, string(body))
11011200
}
11021201

11031202
var txResp struct {
11041203
Transaction string `json:"transaction"`
11051204
}
11061205

11071206
if err := json.Unmarshal(body, &txResp); err != nil {
1108-
return fmt.Errorf("failed to parse transaction response: %w", err)
1207+
return nil, fmt.Errorf("failed to parse transaction response: %w", err)
11091208
}
11101209

11111210
tx := &Transaction{
@@ -1117,14 +1216,14 @@ func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) erro
11171216
// Run the function
11181217
if err := f(tx); err != nil {
11191218
// Rollback is implicit if commit is not called
1120-
return err
1219+
return nil, err
11211220
}
11221221

11231222
// Commit the transaction
11241223
err = tx.commit(ctx, token)
11251224
if err == nil {
11261225
c.logger.Debug("transaction committed successfully", "attempt", attempt+1)
1127-
return nil // Success
1226+
return &Commit{}, nil // Success
11281227
}
11291228

11301229
c.logger.Warn("transaction commit failed", "attempt", attempt+1, "error", err)
@@ -1154,10 +1253,10 @@ func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) erro
11541253

11551254
// Non-retriable error
11561255
c.logger.Warn("non-retriable transaction error", "error", err)
1157-
return err
1256+
return nil, err
11581257
}
11591258

1160-
return fmt.Errorf("transaction failed after %d attempts: %w", maxTxRetries, lastErr)
1259+
return nil, fmt.Errorf("transaction failed after %d attempts: %w", maxTxRetries, lastErr)
11611260
}
11621261

11631262
// Get retrieves an entity within the transaction.

0 commit comments

Comments
 (0)