Skip to content

Commit acb3083

Browse files
committed
fix(host): pass db tx reader/writer into helper funcs
1 parent 9f40113 commit acb3083

File tree

6 files changed

+80
-6
lines changed

6 files changed

+80
-6
lines changed

internal/host/options.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ package host
66
import (
77
"errors"
88

9+
"github.com/hashicorp/boundary/internal/db"
910
"github.com/hashicorp/boundary/internal/pagination"
11+
"github.com/hashicorp/boundary/internal/util"
1012
)
1113

1214
// GetOpts - iterate the inbound Options and return a struct
@@ -26,6 +28,8 @@ type Option func(*options) error
2628
// options = how options are represented
2729
type options struct {
2830
WithLimit int
31+
WithReader db.Reader
32+
WithWriter db.Writer
2933
WithOrderByCreateTime bool
3034
Ascending bool
3135
WithStartPageAfterItem pagination.Item
@@ -66,3 +70,19 @@ func WithStartPageAfterItem(item pagination.Item) Option {
6670
return nil
6771
}
6872
}
73+
74+
// WithReaderWriter is used to share the same database reader
75+
// and writer when executing sql within a transaction.
76+
func WithReaderWriter(r db.Reader, w db.Writer) Option {
77+
return func(o *options) error {
78+
if util.IsNil(r) {
79+
return errors.New("reader cannot be nil")
80+
}
81+
if util.IsNil(w) {
82+
return errors.New("writer cannot be nil")
83+
}
84+
o.WithReader = r
85+
o.WithWriter = w
86+
return nil
87+
}
88+
}

internal/host/options_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,23 @@ func Test_GetOpts(t *testing.T) {
7777
assert.Equal(opts.WithStartPageAfterItem.GetPublicId(), "s_1")
7878
assert.Equal(opts.WithStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime))
7979
})
80+
t.Run("WithReaderWriter", func(t *testing.T) {
81+
t.Parallel()
82+
t.Run("nil writer", func(t *testing.T) {
83+
t.Parallel()
84+
_, err := GetOpts(WithReaderWriter(&db.Db{}, nil))
85+
require.Error(t, err)
86+
})
87+
t.Run("nil reader", func(t *testing.T) {
88+
t.Parallel()
89+
_, err := GetOpts(WithReaderWriter(nil, &db.Db{}))
90+
require.Error(t, err)
91+
})
92+
reader := &db.Db{}
93+
writer := &db.Db{}
94+
opts, err := GetOpts(WithReaderWriter(reader, writer))
95+
require.NoError(t, err)
96+
assert.Equal(t, reader, opts.WithReader)
97+
assert.Equal(t, writer, opts.WithWriter)
98+
})
8099
}

internal/host/plugin/options.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package plugin
55

66
import (
7+
"github.com/hashicorp/boundary/internal/db"
78
"github.com/hashicorp/boundary/internal/pagination"
89
"google.golang.org/protobuf/types/known/structpb"
910
)
@@ -37,6 +38,8 @@ type options struct {
3738
withSetIds []string
3839
withSecretsHmac []byte
3940
withStartPageAfterItem pagination.Item
41+
WithReader db.Reader
42+
withWriter db.Writer
4043
}
4144

4245
func getDefaultOptions() options {
@@ -153,3 +156,12 @@ func WithStartPageAfterItem(item pagination.Item) Option {
153156
o.withStartPageAfterItem = item
154157
}
155158
}
159+
160+
// WithReaderWriter is used to share the same database reader
161+
// and writer when executing sql within a transaction.
162+
func WithReaderWriter(r db.Reader, w db.Writer) Option {
163+
return func(o *options) {
164+
o.WithReader = r
165+
o.withWriter = w
166+
}
167+
}

internal/host/plugin/options_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"testing"
88
"time"
99

10+
"github.com/hashicorp/boundary/internal/db"
1011
"github.com/hashicorp/boundary/internal/db/timestamp"
1112
"github.com/hashicorp/boundary/internal/pagination"
1213
"github.com/stretchr/testify/assert"
@@ -107,4 +108,11 @@ func Test_GetOpts(t *testing.T) {
107108
assert.Equal(opts.withStartPageAfterItem.GetPublicId(), "s_1")
108109
assert.Equal(opts.withStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime))
109110
})
111+
t.Run("WithReaderWriter", func(t *testing.T) {
112+
reader := &db.Db{}
113+
writer := &db.Db{}
114+
opts := getOpts(WithReaderWriter(reader, writer))
115+
assert.Equal(t, reader, opts.WithReader)
116+
assert.Equal(t, writer, opts.withWriter)
117+
})
110118
}

internal/host/plugin/repository_host_catalog.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/hashicorp/boundary/internal/db"
1212
"github.com/hashicorp/boundary/internal/errors"
1313
"github.com/hashicorp/boundary/internal/event"
14+
"github.com/hashicorp/boundary/internal/host"
1415
"github.com/hashicorp/boundary/internal/kms"
1516
"github.com/hashicorp/boundary/internal/libs/patchstruct"
1617
"github.com/hashicorp/boundary/internal/oplog"
@@ -395,7 +396,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version
395396
ctx,
396397
db.StdRetryCnt,
397398
db.ExpBackoff{},
398-
func(_ db.Reader, w db.Writer) error {
399+
func(read db.Reader, w db.Writer) error {
399400
msgs := make([]*oplog.Message, 0, 3)
400401
ticket, err := w.GetTicket(ctx, newCatalog)
401402
if err != nil {
@@ -517,7 +518,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version
517518
if needSetSync {
518519
// We also need to mark all host sets in this catalog to be
519520
// synced as well.
520-
setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId)
521+
setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId, host.WithReaderWriter(read, w))
521522
if err != nil {
522523
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get sets for host catalog"))
523524
}
@@ -697,14 +698,19 @@ func (r *Repository) getCatalog(ctx context.Context, id string) (*HostCatalog, *
697698
return c, p, nil
698699
}
699700

700-
func (r *Repository) getPlugin(ctx context.Context, plgId string) (*plg.Plugin, error) {
701+
func (r *Repository) getPlugin(ctx context.Context, plgId string, opts ...Option) (*plg.Plugin, error) {
701702
const op = "plugin.(Repository).getPlugin"
702703
if plgId == "" {
703704
return nil, errors.New(ctx, errors.InvalidParameter, op, "no plugin id")
704705
}
706+
opt := getOpts(opts...)
707+
reader := r.reader
708+
if !util.IsNil(opt.WithReader) {
709+
reader = opt.WithReader
710+
}
705711
plg := plg.NewPlugin()
706712
plg.PublicId = plgId
707-
if err := r.reader.LookupByPublicId(ctx, plg); err != nil {
713+
if err := reader.LookupByPublicId(ctx, plg); err != nil {
708714
return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to get host plugin with id %q", plgId)))
709715
}
710716
return plg, nil

internal/host/plugin/repository_host_set.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,15 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
799799
limit = opts.WithLimit
800800
}
801801

802+
reader := r.reader
803+
writer := r.writer
804+
if !util.IsNil(opts.WithReader) {
805+
reader = opts.WithReader
806+
}
807+
if !util.IsNil(opts.WithWriter) {
808+
writer = opts.WithWriter
809+
}
810+
802811
args := make([]any, 0, 1)
803812
var where string
804813

@@ -820,7 +829,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
820829
}
821830

822831
var aggHostSets []*hostSetAgg
823-
if err := r.reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil {
832+
if err := reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil {
824833
return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("in %s", publicId)))
825834
}
826835

@@ -839,7 +848,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
839848
}
840849
var plg *plugin.Plugin
841850
if plgId != "" {
842-
plg, err = r.getPlugin(ctx, plgId)
851+
plg, err = r.getPlugin(ctx, plgId, WithReaderWriter(reader, writer))
843852
if err != nil {
844853
return nil, nil, errors.Wrap(ctx, err, op)
845854
}

0 commit comments

Comments
 (0)