Skip to content

Commit b6849b1

Browse files
authored
fix(sess): grab kms wrapper before tx (#6052)
1 parent ea6ff60 commit b6849b1

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

internal/daemon/controller/handlers/targets/target_service.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession
805805
if err != nil {
806806
return nil, err
807807
}
808-
t, err := repo.LookupTargetForSessionAuthorization(ctx, roundTripTarget.GetPublicId(), target.WithAlias(targetAlias))
808+
t, err := repo.LookupTargetForSessionAuthorization(ctx, roundTripTarget.GetPublicId(), roundTripTarget.GetProjectId(), target.WithAlias(targetAlias))
809809
if err != nil {
810810
if errors.IsNotFoundError(err) {
811811
return nil, handlers.NotFoundErrorf("Target %q not found.", roundTripTarget.GetPublicId())

internal/target/repository.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,19 @@ func NewRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms,
9393
// with its host source ids, credential source ids, and server certificate, if applicable. If the target is not
9494
// found, it will return nil, nil.
9595
// Supported option: WithAlias if the session authorization uses a target alias
96-
func (r *Repository) LookupTargetForSessionAuthorization(ctx context.Context, publicId string, opt ...Option) (Target, error) {
96+
func (r *Repository) LookupTargetForSessionAuthorization(ctx context.Context, publicId string, projectId string, opt ...Option) (Target, error) {
9797
const op = "target.(Repository).LookupTargetForSessionAuthorization"
9898
opts := GetOpts(opt...)
99-
100-
if publicId == "" {
99+
switch {
100+
case publicId == "":
101101
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing public id")
102+
case projectId == "":
103+
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing project id")
104+
}
105+
106+
databaseWrapper, err := r.kms.GetWrapper(ctx, projectId, kms.KeyPurposeDatabase)
107+
if err != nil {
108+
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
102109
}
103110

104111
target := allocTargetView()
@@ -107,7 +114,7 @@ func (r *Repository) LookupTargetForSessionAuthorization(ctx context.Context, pu
107114
var hostSources []HostSource
108115
var credSources []CredentialSource
109116
var cert *ServerCertificate
110-
_, err := r.writer.DoTx(
117+
_, err = r.writer.DoTx(
111118
ctx,
112119
db.StdRetryCnt,
113120
db.ExpBackoff{},
@@ -133,11 +140,6 @@ func (r *Repository) LookupTargetForSessionAuthorization(ctx context.Context, pu
133140
address = targetAddress.GetAddress()
134141
}
135142

136-
databaseWrapper, err := r.kms.GetWrapper(ctx, target.GetProjectId(), kms.KeyPurposeDatabase)
137-
if err != nil {
138-
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
139-
}
140-
141143
if opts.WithAlias != nil {
142144
cert, err = fetchTargetAliasProxyServerCertificate(ctx, read, w, target.PublicId, target.ProjectId, opts.WithAlias, databaseWrapper, target.GetSessionMaxSeconds())
143145
if err != nil && !errors.IsNotFoundError(err) {

internal/target/repository_proxy_server_certificate_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ func Test_FetchCertsWithinLookupTargetForSessionAuthorization(t *testing.T) {
372372
for _, tt := range tests {
373373
t.Run(tt.name, func(t *testing.T) {
374374
assert, require := assert.New(t), require.New(t)
375-
got, err := repo.LookupTargetForSessionAuthorization(ctx, tt.publicId, tt.opt...)
375+
got, err := repo.LookupTargetForSessionAuthorization(ctx, tt.publicId, proj.PublicId, tt.opt...)
376376
require.NoError(err)
377377
assert.NotNil(got)
378378
if tt.wantCert {

0 commit comments

Comments
 (0)