Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8309,7 +8309,7 @@ func (a *Server) validateMFAAuthResponseInternal(
Webauthn: webConfig,
Identity: a.Services,
}
loginData, err = webLogin.Finish(ctx, user, wantypes.CredentialAssertionResponseFromProto(res.Webauthn), requiredExtensions)
loginData, err = webLogin.Finish(ctx, user, wantypes.CredentialAssertionResponseFromProto(res.Webauthn), requiredExtensions, false /* validateOnly */)
}
if err != nil {
if requiredExtensions.AllowReuse == mfav1.ChallengeAllowReuse_CHALLENGE_ALLOW_REUSE_YES &&
Expand Down
1 change: 1 addition & 0 deletions lib/auth/mfa/mfav1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ func (s *Service) validateWebauthnResponse(
username,
wantypes.CredentialAssertionResponseFromProto(resp.Webauthn),
&mfav1.ChallengeExtensions{Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_USER_SESSION},
false, /* validateOnly */
)
if err != nil {
return nil, trace.AccessDenied("validate Webauthn response: %v", err)
Expand Down
19 changes: 18 additions & 1 deletion lib/auth/webauthn/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,13 @@ type LoginData struct {
TargetCluster string
}

func (f *loginFlow) finish(ctx context.Context, user string, resp *wantypes.CredentialAssertionResponse, requiredExtensions *mfav1.ChallengeExtensions) (*LoginData, error) {
func (f *loginFlow) finish(
ctx context.Context,
user string,
resp *wantypes.CredentialAssertionResponse,
requiredExtensions *mfav1.ChallengeExtensions,
validateOnly bool,
) (*LoginData, error) {
if requiredExtensions == nil {
return nil, trace.BadParameter("requested challenge extensions must be supplied.")
}
Expand Down Expand Up @@ -413,6 +419,17 @@ func (f *loginFlow) finish(ctx context.Context, user string, resp *wantypes.Cred
}
}

if validateOnly {
return &LoginData{
User: user,
Device: dev,
AllowReuse: sd.ChallengeExtensions.AllowReuse,
Payload: sd.Payload,
SourceCluster: sd.SourceCluster,
TargetCluster: sd.TargetCluster,
}, nil
}

// Update last used timestamp and device counter.
if err := updateCredentialAndTimestamps(dev, credential, discoverableLogin); err != nil {
return nil, trace.Wrap(err)
Expand Down
13 changes: 11 additions & 2 deletions lib/auth/webauthn/login_mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,23 @@ func (f *LoginFlow) Begin(ctx context.Context, params BeginParams) (*wantypes.Cr
// user name, and other login properties. If login is successful, Finish has the
// side effect of updating the counter and last used timestamp of the MFADevice
// used.
func (f *LoginFlow) Finish(ctx context.Context, user string, resp *wantypes.CredentialAssertionResponse, requiredExtensions *mfav1.ChallengeExtensions) (*LoginData, error) {
// If validateOnly is true, the response will be validated without consuming it.
// This is useful for flows like Browser MFA where the browser needs to validate
// the response before returning it tsh.
func (f *LoginFlow) Finish(
ctx context.Context,
user string,
resp *wantypes.CredentialAssertionResponse,
requiredExtensions *mfav1.ChallengeExtensions,
validateOnly bool,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify, I was OK keeping the public variants of Finish and Validate. The refactor was more about having them both delegate to the private finish func.

) (*LoginData, error) {
lf := &loginFlow{
U2F: f.U2F,
Webauthn: f.Webauthn,
identity: mfaIdentity{f.Identity},
sessionData: (*userSessionStorage)(f),
}
return lf.finish(ctx, user, resp, requiredExtensions)
return lf.finish(ctx, user, resp, requiredExtensions, validateOnly)
}

type mfaIdentity struct {
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/webauthn/login_passwordless.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (f *PasswordlessFlow) Finish(ctx context.Context, resp *wantypes.Credential
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_PASSWORDLESS_LOGIN,
AllowReuse: mfav1.ChallengeAllowReuse_CHALLENGE_ALLOW_REUSE_NO,
}
return lf.finish(ctx, "" /* user */, resp, requiredExt)
return lf.finish(ctx, "" /* user */, resp, requiredExt, false /* validateOnly */)
}

type passwordlessIdentity struct {
Expand Down
112 changes: 93 additions & 19 deletions lib/auth/webauthn/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"crypto/x509"
"fmt"
"maps"
"slices"
"testing"
"time"
Expand Down Expand Up @@ -110,22 +111,25 @@ func TestLoginFlow_BeginFinish(t *testing.T) {
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
identity := test.identity
runTest := func(t *testing.T, validate bool) {
t.Parallel()

identity := test.identity.clone()
user := test.user

webLogin := &wanlib.LoginFlow{
U2F: u2fConfig,
Webauthn: webConfig,
Identity: test.identity,
Identity: identity,
}

// 1st step of the login ceremony.
chalExts := &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
}
assertion, err := webLogin.Begin(ctx, wanlib.BeginParams{
User: test.user,
ChallengeExtensions: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
},
User: test.user,
ChallengeExtensions: chalExts,
})
require.NoError(t, err)
// We care about a few specific settings, for everything else defaults are
Expand All @@ -152,11 +156,25 @@ func TestLoginFlow_BeginFinish(t *testing.T) {
assertionResp, err := test.key.SignAssertion(test.origin, assertion)
require.NoError(t, err)

if validate {
// Capture state before validation
initialCounter := getSignatureCounter(identity.User.GetLocalAuth().MFA[0])
initialSessionDataCount := len(identity.SessionData)
initialUpdatedDevicesCount := len(identity.UpdatedDevices)

// Finish(validateOnly=true) (shouldn't consume)
_, err = webLogin.Finish(ctx, user, assertionResp, chalExts, true)
require.NoError(t, err)

// Assert state unchanged
assert.Equal(t, initialCounter, getSignatureCounter(identity.User.GetLocalAuth().MFA[0]))
assert.Len(t, identity.SessionData, initialSessionDataCount)
assert.Len(t, identity.UpdatedDevices, initialUpdatedDevicesCount)
}

// 2nd and last step of the login ceremony.
beforeLastUsed := time.Now().Add(-1 * time.Second)
loginData, err := webLogin.Finish(ctx, user, assertionResp, &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
})
loginData, err := webLogin.Finish(ctx, user, assertionResp, chalExts, false /* validateOnly */)
require.NoError(t, err)
// Last used time and counter are updated.
require.True(t, beforeLastUsed.Before(loginData.Device.LastUsed))
Expand All @@ -169,7 +187,10 @@ func TestLoginFlow_BeginFinish(t *testing.T) {
}
// Did we delete the challenge?
require.Empty(t, identity.SessionData)
})
}

t.Run(test.name+"/Finish", func(t *testing.T) { runTest(t, false) })
t.Run(test.name+"/ValidateAndFinish", func(t *testing.T) { runTest(t, true) })
}
}

Expand Down Expand Up @@ -287,21 +308,37 @@ func TestLoginFlow_Finish_errors(t *testing.T) {
name string
user string
createResp func() *wantypes.CredentialAssertionResponse
exts *mfav1.ChallengeExtensions
}{
{
name: "NOK empty user",
user: "",
createResp: func() *wantypes.CredentialAssertionResponse { return okResp },
exts: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
},
Comment on lines +317 to +319
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull into okExts (or some other name) and reuse throughout?

Suggested change
exts: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
},
exts: okExts,

},
{
name: "NOK nil resp",
user: user,
createResp: func() *wantypes.CredentialAssertionResponse { return nil },
exts: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
},
},
{
name: "NOK empty resp",
user: user,
createResp: func() *wantypes.CredentialAssertionResponse { return &wantypes.CredentialAssertionResponse{} },
exts: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
},
},
{
name: "NOK nil required extensions",
user: user,
createResp: func() *wantypes.CredentialAssertionResponse { return okResp },
exts: nil,
},
{
name: "NOK assertion with bad origin",
Expand All @@ -318,6 +355,9 @@ func TestLoginFlow_Finish_errors(t *testing.T) {
require.NoError(t, err)
return resp
},
exts: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
},
},
{
name: "NOK assertion with bad RPID",
Expand All @@ -336,6 +376,9 @@ func TestLoginFlow_Finish_errors(t *testing.T) {
require.NoError(t, err)
return resp
},
exts: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
},
},
{
name: "NOK assertion signed by unknown device",
Expand All @@ -358,6 +401,9 @@ func TestLoginFlow_Finish_errors(t *testing.T) {
require.NoError(t, err)
return resp
},
exts: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
},
},
{
name: "NOK assertion with invalid signature",
Expand All @@ -378,15 +424,19 @@ func TestLoginFlow_Finish_errors(t *testing.T) {
require.NoError(t, err)
return resp
},
exts: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := webLogin.Finish(ctx, test.user, test.createResp(), &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
})
runTest := func(t *testing.T, validateOnly bool) {
_, err := webLogin.Finish(ctx, test.user, test.createResp(), test.exts, validateOnly)
require.Error(t, err)
})
}

t.Run(test.name+"/Finish", func(t *testing.T) { runTest(t, false) })
t.Run(test.name+"/ValidateOnly", func(t *testing.T) { runTest(t, true) })
}
}

Expand Down Expand Up @@ -633,7 +683,7 @@ func TestCredentialRPID(t *testing.T) {

loginData, err := webLogin.Finish(ctx, user, car, &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_LOGIN,
})
}, false)
require.NoError(t, err, "Finish failed")
assert.Equal(t, rpID, loginData.Device.GetWebauthn().CredentialRpId, "CredentialRpId mismatch")
})
Expand Down Expand Up @@ -893,7 +943,7 @@ func TestLoginFlow_scopeAndReuse(t *testing.T) {
assertionResp, err := webKey.SignAssertion(webOrigin, assertion)
require.NoError(t, err)

loginData, err := webLogin.Finish(ctx, user, assertionResp, test.requiredExt)
loginData, err := webLogin.Finish(ctx, user, assertionResp, test.requiredExt, false)
if test.assertErr != nil {
test.assertErr(t, err)
return
Expand Down Expand Up @@ -1061,7 +1111,7 @@ func TestLoginFlow_userVerification(t *testing.T) {
assertionResp, err := test.dev.SignAssertion(origin, assertion)
require.NoError(t, err, "dev.SignAssertion")

_, err = lf.Finish(ctx, user, assertionResp, test.requiredExts)
_, err = lf.Finish(ctx, user, assertionResp, test.requiredExts, false)
if test.wantErr != "" {
assert.ErrorContains(t, err, test.wantErr, "lf.Finish error mismatch")
} else {
Expand Down Expand Up @@ -1258,6 +1308,30 @@ func newFakeIdentity(user string, devices ...*types.MFADevice) *fakeIdentity {
}
}

func (f *fakeIdentity) clone() *fakeIdentity {
// Deep copy the User
userCopy := &types.UserV2{
Metadata: f.User.Metadata,
Spec: types.UserSpecV2{
LocalAuth: &types.LocalAuthSecrets{
MFA: slices.Clone(f.User.GetLocalAuth().MFA),
Webauthn: f.User.GetLocalAuth().Webauthn,
},
},
}

// Copy SessionData map
sessionDataCopy := make(map[string]*wantypes.SessionData, len(f.SessionData))
maps.Copy(sessionDataCopy, f.SessionData)

return &fakeIdentity{
User: userCopy,
MappedUser: f.MappedUser,
UpdatedDevices: slices.Clone(f.UpdatedDevices),
SessionData: sessionDataCopy,
}
}

func (f *fakeIdentity) GetMFADevices(ctx context.Context, user string, withSecrets bool) ([]*types.MFADevice, error) {
// Return a defensive copy of the slice, the caller might modify it.
return slices.Clone(f.User.GetLocalAuth().MFA), nil
Expand Down
Loading