Skip to content

Commit fc6c107

Browse files
Merge pull request #600 from openziti/fix.599.auth.deadlock
Fix.599.auth.deadlock
2 parents 5256b59 + 85bd0a6 commit fc6c107

File tree

4 files changed

+125
-35
lines changed

4 files changed

+125
-35
lines changed

edge-apis/authwrapper.go

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ type ApiSession interface {
7878

7979
//RequiresRouterTokenUpdate returns true if the token is a bearer token requires updating on edge router connections.
8080
RequiresRouterTokenUpdate() bool
81+
82+
GetRequestHeaders() http.Header
8183
}
8284

8385
var _ ApiSession = (*ApiSessionLegacy)(nil)
@@ -86,7 +88,12 @@ var _ ApiSession = (*ApiSessionOidc)(nil)
8688
// ApiSessionLegacy represents OpenZiti's original authentication API Session Detail, supplied in the `zt-session` header.
8789
// It has been supplanted by OIDC authentication represented by ApiSessionOidc.
8890
type ApiSessionLegacy struct {
89-
Detail *rest_model.CurrentAPISessionDetail
91+
Detail *rest_model.CurrentAPISessionDetail
92+
RequestHeaders http.Header
93+
}
94+
95+
func (a *ApiSessionLegacy) GetRequestHeaders() http.Header {
96+
return a.RequestHeaders
9097
}
9198

9299
func (a *ApiSessionLegacy) RequiresRouterTokenUpdate() bool {
@@ -119,8 +126,15 @@ func (a *ApiSessionLegacy) AuthenticateRequest(request runtime.ClientRequest, _
119126
return errors.New("api session is nil")
120127
}
121128

122-
header, val := a.GetAccessHeader()
129+
for h, v := range a.RequestHeaders {
130+
err := request.SetHeaderParam(h, v...)
131+
if err != nil {
132+
return err
133+
}
134+
}
123135

136+
//legacy does not support multiple zt-session headers, so we can it sfely
137+
header, val := a.GetAccessHeader()
124138
err := request.SetHeaderParam(header, val)
125139
if err != nil {
126140
return err
@@ -151,7 +165,12 @@ func (a *ApiSessionLegacy) GetExpiresAt() *time.Time {
151165

152166
// ApiSessionOidc represents an authenticated session backed by OIDC tokens.
153167
type ApiSessionOidc struct {
154-
OidcTokens *oidc.Tokens[*oidc.IDTokenClaims]
168+
OidcTokens *oidc.Tokens[*oidc.IDTokenClaims]
169+
RequestHeaders http.Header
170+
}
171+
172+
func (a *ApiSessionOidc) GetRequestHeaders() http.Header {
173+
return a.RequestHeaders
155174
}
156175

157176
func (a *ApiSessionOidc) RequiresRouterTokenUpdate() bool {
@@ -203,9 +222,31 @@ func (a *ApiSessionOidc) AuthenticateRequest(request runtime.ClientRequest, _ st
203222
return errors.New("api session is nil")
204223
}
205224

206-
header, val := a.GetAccessHeader()
225+
if a.RequestHeaders == nil {
226+
a.RequestHeaders = http.Header{}
227+
}
228+
229+
//multiple Authorization headers are allowed, obtain all auth header candidates
230+
primaryAuthHeader, primaryAuthValue := a.GetAccessHeader()
231+
altAuthValues := a.RequestHeaders.Get(primaryAuthHeader)
232+
233+
authValues := []string{primaryAuthValue}
234+
235+
if len(altAuthValues) > 0 {
236+
authValues = append(authValues, altAuthValues)
237+
}
238+
239+
//set request headers
240+
for h, v := range a.RequestHeaders {
241+
err := request.SetHeaderParam(h, v...)
242+
if err != nil {
243+
return err
244+
}
245+
}
246+
247+
//restore auth headers
248+
err := request.SetHeaderParam(primaryAuthHeader, authValues...)
207249

208-
err := request.SetHeaderParam(header, val)
209250
if err != nil {
210251
return err
211252
}
@@ -320,7 +361,9 @@ func (self *ZitiEdgeManagement) legacyAuth(credentials Credentials, configTypes
320361
return nil, err
321362
}
322363

323-
return &ApiSessionLegacy{Detail: resp.GetPayload().Data}, err
364+
return &ApiSessionLegacy{
365+
Detail: resp.GetPayload().Data,
366+
RequestHeaders: credentials.GetRequestHeaders()}, err
324367
}
325368

326369
func (self *ZitiEdgeManagement) oidcAuth(credentials Credentials, configTypeOverrides []string, httpClient *http.Client) (ApiSession, error) {
@@ -355,7 +398,8 @@ func (self *ZitiEdgeManagement) RefreshApiSession(apiSession ApiSession, httpCli
355398
}
356399

357400
return &ApiSessionOidc{
358-
OidcTokens: tokens,
401+
OidcTokens: tokens,
402+
RequestHeaders: apiSession.GetRequestHeaders(),
359403
}, nil
360404
}
361405

@@ -453,7 +497,7 @@ func (self *ZitiEdgeClient) legacyAuth(credentials Credentials, configTypes []st
453497
return nil, err
454498
}
455499

456-
return &ApiSessionLegacy{Detail: resp.GetPayload().Data}, err
500+
return &ApiSessionLegacy{Detail: resp.GetPayload().Data, RequestHeaders: credentials.GetRequestHeaders()}, err
457501
}
458502

459503
func (self *ZitiEdgeClient) oidcAuth(credentials Credentials, configTypeOverrides []string, httpClient *http.Client) (ApiSession, error) {
@@ -480,7 +524,8 @@ func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession, httpClient
480524
}
481525

482526
newApiSession := &ApiSessionLegacy{
483-
Detail: newApiSessionDetail.Payload.Data,
527+
Detail: newApiSessionDetail.Payload.Data,
528+
RequestHeaders: apiSession.GetRequestHeaders(),
484529
}
485530

486531
return newApiSession, nil
@@ -492,7 +537,8 @@ func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession, httpClient
492537
}
493538

494539
return &ApiSessionOidc{
495-
OidcTokens: tokens,
540+
OidcTokens: tokens,
541+
RequestHeaders: apiSession.GetRequestHeaders(),
496542
}, nil
497543
}
498544

@@ -748,7 +794,8 @@ func oidcAuth(clientTransportPool ClientTransportPool, credentials Credentials,
748794
}
749795

750796
return &ApiSessionOidc{
751-
OidcTokens: outTokens,
797+
OidcTokens: outTokens,
798+
RequestHeaders: credentials.GetRequestHeaders(),
752799
}, nil
753800
}
754801

edge-apis/credentials.go

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@ type Credentials interface {
2828
// Method returns the authentication necessary to complete an authentication request.
2929
Method() string
3030

31-
// AddHeader adds a header to the request.
32-
AddHeader(key, value string)
31+
// AddAuthHeader adds a header for all authentication requests.
32+
AddAuthHeader(key, value string)
33+
34+
// AddRequestHeader adds a header for all requests after authentication
35+
AddRequestHeader(key, value string)
3336

3437
// AddJWT adds additional JWTs to the credentials. Used to satisfy secondary authentication/MFA requirements. The
3538
// provided token should be the base64 encoded version of the token.
@@ -38,6 +41,9 @@ type Credentials interface {
3841
// ClientAuthInfoWriter is used to pass a Credentials instance to the openapi runtime to authenticate outgoing
3942
//requests.
4043
runtime.ClientAuthInfoWriter
44+
45+
// GetRequestHeaders returns a set of headers to use after authentication during normal HTTP operations
46+
GetRequestHeaders() http.Header
4147
}
4248

4349
// IdentityProvider is a sentinel interface used to determine whether the backing Credentials instance can provide
@@ -83,8 +89,11 @@ type BaseCredentials struct {
8389
// ConfigTypes is used to set the configuration types for services during authentication
8490
ConfigTypes []string
8591

86-
// Headers is a map of strings to string arrays of headers to send with auth requests.
87-
Headers *http.Header
92+
// AuthHeaders is a map of strings to string arrays of headers to send with auth requests.
93+
AuthHeaders http.Header
94+
95+
// RequestHeaders is a map of string to string arrays of headers to send on non-authentication requests.
96+
RequestHeaders http.Header
8897

8998
// EnvInfo is provided during authentication to set environmental information about the client.
9099
EnvInfo *rest_model.EnvInfo
@@ -121,41 +130,75 @@ func (c *BaseCredentials) GetCaPool() *x509.CertPool {
121130
return c.CaPool
122131
}
123132

124-
// AddHeader provides a base implementation to add a header to the request.
125-
func (c *BaseCredentials) AddHeader(key, value string) {
126-
if c.Headers == nil {
127-
c.Headers = &http.Header{}
133+
// AddAuthHeader provides a base implementation to add a header to authentication requests.
134+
func (c *BaseCredentials) AddAuthHeader(key, value string) {
135+
if c.AuthHeaders == nil {
136+
c.AuthHeaders = http.Header{}
128137
}
129-
c.Headers.Add(key, value)
138+
c.AuthHeaders.Add(key, value)
139+
}
140+
141+
// AddRequestHeader provides a base implementation to add a header to all requests after authentication.
142+
func (c *BaseCredentials) AddRequestHeader(key, value string) {
143+
if c.RequestHeaders == nil {
144+
c.RequestHeaders = http.Header{}
145+
}
146+
147+
c.RequestHeaders.Add(key, value)
130148
}
131149

132150
// AddJWT adds additional JWTs to the credentials. Used to satisfy secondary authentication/MFA requirements. The
133151
// provided token should be the base64 encoded version of the token. Convenience function for AddHeader.
134152
func (c *BaseCredentials) AddJWT(token string) {
135-
c.AddHeader("Authorization", "Bearer "+token)
153+
c.AddAuthHeader("Authorization", "Bearer "+token)
154+
c.AddRequestHeader("Authorization", "Bearer "+token)
136155
}
137156

138157
// AuthenticateRequest provides a base implementation to authenticate an outgoing request. This is provided here
139158
// for authentication methods such as `cert` which do not have to provide any more request level information.
140159
func (c *BaseCredentials) AuthenticateRequest(request runtime.ClientRequest, _ strfmt.Registry) error {
141160
var errors []error
142161

143-
if c.Headers != nil {
144-
for hName, hVals := range *c.Headers {
145-
for _, hVal := range hVals {
146-
err := request.SetHeaderParam(hName, hVal)
147-
if err != nil {
148-
errors = append(errors, err)
149-
}
162+
for hName, hVals := range c.AuthHeaders {
163+
for _, hVal := range hVals {
164+
err := request.SetHeaderParam(hName, hVal)
165+
if err != nil {
166+
errors = append(errors, err)
167+
}
168+
}
169+
}
170+
171+
if len(errors) > 0 {
172+
return network.MultipleErrors(errors)
173+
}
174+
return nil
175+
}
176+
177+
// ProcessRequest proves a base implemmentation mutate runtime.ClientRequests as they are sent out after
178+
// authentication. Useful for adding headers.
179+
func (c *BaseCredentials) ProcessRequest(request runtime.ClientRequest, _ strfmt.Registry) error {
180+
var errors []error
181+
182+
for hName, hVals := range c.RequestHeaders {
183+
for _, hVal := range hVals {
184+
err := request.SetHeaderParam(hName, hVal)
185+
if err != nil {
186+
errors = append(errors, err)
150187
}
151188
}
152189
}
190+
153191
if len(errors) > 0 {
154192
return network.MultipleErrors(errors)
155193
}
156194
return nil
157195
}
158196

197+
// GetRequestHeaders returns headers that should be sent on requests post authentication.
198+
func (c *BaseCredentials) GetRequestHeaders() http.Header {
199+
return c.RequestHeaders
200+
}
201+
159202
// TlsCerts provides a base implementation of returning the tls.Certificate array that will be used to setup
160203
// mTLS connections. This is provided here for authentication methods that do not initially require mTLS (e.g. JWTs).
161204
func (c *BaseCredentials) TlsCerts() []tls.Certificate {

ziti/sdkinfo/build_info.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ziti/ziti.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ type Context interface {
150150
Close()
151151

152152
// Deprecated: AddZitiMfaHandler adds a Ziti MFA handler, invoked during authentication.
153-
// Replaced with event functionality. Use `zitiContext.AddListener(MfaTotpCode, handler)` instead.
153+
// Replaced with event functionality. Use `zitiContext.Events().AddMfaTotpCodeListener(func(Context, *rest_model.AuthQueryDetail, MfaCodeResponse))` instead.
154154
AddZitiMfaHandler(handler func(query *rest_model.AuthQueryDetail, resp MfaCodeResponse) error)
155155

156156
// EnrollZitiMfa will attempt to enable TOTP 2FA on the currently authenticating identity if not already enrolled.
@@ -193,7 +193,6 @@ type ContextImpl struct {
193193
authQueryHandlers map[string]func(query *rest_model.AuthQueryDetail, response MfaCodeResponse) error
194194

195195
events.EventEmmiter
196-
apiSessionLock sync.Mutex
197196
lastSuccessfulApiSessionRefresh time.Time
198197
}
199198

@@ -928,9 +927,6 @@ func (context *ContextImpl) Reauthenticate() error {
928927
}
929928

930929
func (context *ContextImpl) Authenticate() error {
931-
context.apiSessionLock.Lock()
932-
defer context.apiSessionLock.Unlock()
933-
934930
if context.CtrlClt.GetCurrentApiSession() != nil {
935931
if time.Since(context.lastSuccessfulApiSessionRefresh) < 5*time.Second {
936932
return nil
@@ -1040,6 +1036,10 @@ func (context *ContextImpl) authenticateMfa(code string) error {
10401036
func (context *ContextImpl) handleAuthQuery(authQuery *rest_model.AuthQueryDetail) error {
10411037
context.Emit(EventAuthQuery, authQuery)
10421038

1039+
if authQuery.Provider == nil {
1040+
return fmt.Errorf("unhandled response from controller: authentication query has no provider specified")
1041+
}
1042+
10431043
if *authQuery.Provider == rest_model.MfaProvidersZiti {
10441044
handler := context.authQueryHandlers[string(rest_model.MfaProvidersZiti)]
10451045

@@ -1054,7 +1054,7 @@ func (context *ContextImpl) handleAuthQuery(authQuery *rest_model.AuthQueryDetai
10541054
return nil
10551055
}
10561056

1057-
return fmt.Errorf("unsupported MFA provider: %v", authQuery.Provider)
1057+
return fmt.Errorf("unsupported MFA provider: %v", *authQuery.Provider)
10581058
}
10591059

10601060
func (context *ContextImpl) Dial(serviceName string) (edge.Conn, error) {

0 commit comments

Comments
 (0)