Skip to content

Commit 4e64f88

Browse files
authored
Merge pull request #5526 from hashicorp/ddebko-backport-db-fix-to-17
backport database transaction fix
2 parents 6db3d4d + 65aa040 commit 4e64f88

28 files changed

+148
-36
lines changed

internal/auth/ldap/repository_auth_method_create.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (r *Repository) CreateAuthMethod(ctx context.Context, am *AuthMethod, opt .
6262
return nil, errors.Wrap(ctx, err, op)
6363
}
6464

65-
dbWrapper, err := r.kms.GetWrapper(context.Background(), am.ScopeId, kms.KeyPurposeDatabase)
65+
dbWrapper, err := r.kms.GetWrapper(ctx, am.ScopeId, kms.KeyPurposeDatabase)
6666
if err != nil {
6767
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
6868
}

internal/auth/oidc/repository_auth_method.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ func (r *Repository) upsertAccount(ctx context.Context, am *AuthMethod, IdTokenC
179179
var rowCnt int
180180
for rows.Next() {
181181
rowCnt += 1
182-
err = r.reader.ScanRows(ctx, rows, &result)
182+
err = reader.ScanRows(ctx, rows, &result)
183183
if err != nil {
184184
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to scan rows for account"))
185185
}

internal/auth/oidc/repository_auth_method_create.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func (r *Repository) CreateAuthMethod(ctx context.Context, am *AuthMethod, opt .
6161
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get oplog wrapper"))
6262
}
6363

64-
databaseWrapper, err := r.kms.GetWrapper(context.Background(), am.ScopeId, kms.KeyPurposeDatabase)
64+
databaseWrapper, err := r.kms.GetWrapper(ctx, am.ScopeId, kms.KeyPurposeDatabase)
6565
if err != nil {
6666
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
6767
}

internal/auth/oidc/repository_managed_group_members.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/hashicorp/boundary/internal/errors"
1212
"github.com/hashicorp/boundary/internal/kms"
1313
"github.com/hashicorp/boundary/internal/oplog"
14+
"github.com/hashicorp/boundary/internal/util"
1415
)
1516

1617
// SetManagedGroupMemberships will set the managed groups for the given account
@@ -207,7 +208,7 @@ func (r *Repository) ListManagedGroupMembershipsByMember(ctx context.Context, wi
207208
limit = opts.withLimit
208209
}
209210
reader := r.reader
210-
if opts.withReader != nil {
211+
if !util.IsNil(opts.withReader) {
211212
reader = opts.withReader
212213
}
213214
var mgs []*ManagedGroupMemberAccount
@@ -232,7 +233,7 @@ func (r *Repository) ListManagedGroupMembershipsByGroup(ctx context.Context, wit
232233
limit = opts.withLimit
233234
}
234235
reader := r.reader
235-
if opts.withReader != nil {
236+
if !util.IsNil(opts.withReader) {
236237
reader = opts.withReader
237238
}
238239
var mgs []*ManagedGroupMemberAccount

internal/auth/repository_auth_method.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func (amr *AuthMethodRepository) ListDeletedIds(ctx context.Context, since time.
147147
var deletedAuthMethodIDs []string
148148
var transactionTimestamp time.Time
149149
if _, err := amr.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
150-
rows, err := amr.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
150+
rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
151151
if err != nil {
152152
return errors.Wrap(ctx, err, op)
153153
}

internal/credential/repository_store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func (s *StoreRepository) ListDeletedIds(ctx context.Context, since time.Time) (
118118
var deletedStoreIDs []string
119119
var transactionTimestamp time.Time
120120
if _, err := s.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
121-
rows, err := s.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
121+
rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
122122
if err != nil {
123123
return errors.Wrap(ctx, err, op)
124124
}

internal/credential/vault/jobs.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ func nextRenewal(ctx context.Context, j scheduler.Job) (time.Duration, error) {
264264
return 0, errors.New(ctx, errors.Unknown, op, "unknown job")
265265
}
266266

267-
rows, err := r.Query(context.Background(), query, nil)
267+
rows, err := r.Query(ctx, query, nil)
268268
if err != nil {
269269
return 0, errors.Wrap(ctx, err, op)
270270
}

internal/credential/vault/vault_token.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func newToken(ctx context.Context, storeId string, token TokenSecret, accessor [
7575
accessorCopy := make([]byte, len(accessor))
7676
copy(accessorCopy, accessor)
7777

78-
hmac, err := crypto.HmacSha256WithPrk(context.Background(), tokenCopy, accessorCopy)
78+
hmac, err := crypto.HmacSha256WithPrk(ctx, tokenCopy, accessorCopy)
7979
if err != nil {
8080
return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt))
8181
}

internal/daemon/controller/handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ func wrapHandlerWithCallbackInterceptor(h http.Handler, c *Controller) http.Hand
686686

687687
if strings.HasSuffix(req.URL.Path, "oidc:authenticate") {
688688
if s, ok := values["state"].(string); ok {
689-
stateWrapper, err := oidc.UnwrapMessage(context.Background(), s)
689+
stateWrapper, err := oidc.UnwrapMessage(ctx, s)
690690
if err != nil {
691691
event.WriteError(ctx, op, err, event.WithInfoMsg("error marshaling state"))
692692
w.WriteHeader(http.StatusInternalServerError)

internal/host/options.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ package host
66
import (
77
"errors"
88

9+
"github.com/hashicorp/boundary/internal/db"
910
"github.com/hashicorp/boundary/internal/pagination"
11+
"github.com/hashicorp/boundary/internal/util"
1012
)
1113

1214
// GetOpts - iterate the inbound Options and return a struct
@@ -26,6 +28,8 @@ type Option func(*options) error
2628
// options = how options are represented
2729
type options struct {
2830
WithLimit int
31+
WithReader db.Reader
32+
WithWriter db.Writer
2933
WithOrderByCreateTime bool
3034
Ascending bool
3135
WithStartPageAfterItem pagination.Item
@@ -66,3 +70,19 @@ func WithStartPageAfterItem(item pagination.Item) Option {
6670
return nil
6771
}
6872
}
73+
74+
// WithReaderWriter is used to share the same database reader
75+
// and writer when executing sql within a transaction.
76+
func WithReaderWriter(r db.Reader, w db.Writer) Option {
77+
return func(o *options) error {
78+
if util.IsNil(r) {
79+
return errors.New("reader cannot be nil")
80+
}
81+
if util.IsNil(w) {
82+
return errors.New("writer cannot be nil")
83+
}
84+
o.WithReader = r
85+
o.WithWriter = w
86+
return nil
87+
}
88+
}

0 commit comments

Comments
 (0)