diff --git a/apps/internal/local/server.go b/apps/internal/local/server.go index c6baf209..414f626a 100644 --- a/apps/internal/local/server.go +++ b/apps/internal/local/server.go @@ -5,6 +5,7 @@ package local import ( + "bytes" "context" "fmt" "html" @@ -28,7 +29,7 @@ var okPage = []byte(` `) -const failPage = ` +var failPage = []byte(` @@ -37,10 +38,19 @@ const failPage = `

Authentication failed. You can return to the application. Feel free to close this browser tab.

-

Error details: error %s error_description: %s

+

Error details: error {{.Code}}, error description: {{.Err}}

-` +`) + +var ( + // code is the html template variable name, + // which matches the Result Code variable + code = []byte("{{.Code}}") + // err is the html template variable name + // which matches the Result Err variable + err = []byte("{{.Err}}") +) // Result is the result from the redirect. type Result struct { @@ -53,14 +63,16 @@ type Result struct { // Server is an HTTP server. type Server struct { // Addr is the address the server is listening on. - Addr string - resultCh chan Result - s *http.Server - reqState string + Addr string + resultCh chan Result + s *http.Server + reqState string + successPage []byte + errorPage []byte } // New creates a local HTTP server and starts it. -func New(reqState string, port int) (*Server, error) { +func New(reqState string, port int, successPage []byte, errorPage []byte) (*Server, error) { var l net.Listener var err error var portStr string @@ -84,11 +96,21 @@ func New(reqState string, port int) (*Server, error) { return nil, err } + if len(successPage) == 0 { + successPage = okPage + } + + if len(errorPage) == 0 { + errorPage = failPage + } + serv := &Server{ - Addr: fmt.Sprintf("http://localhost:%s", portStr), - s: &http.Server{Addr: "localhost:0", ReadHeaderTimeout: time.Second}, - reqState: reqState, - resultCh: make(chan Result, 1), + Addr: fmt.Sprintf("http://localhost:%s", portStr), + s: &http.Server{Addr: "localhost:0", ReadHeaderTimeout: time.Second}, + reqState: reqState, + resultCh: make(chan Result, 1), + successPage: successPage, + errorPage: errorPage, } serv.s.Handler = http.HandlerFunc(serv.handler) @@ -142,12 +164,18 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) { headerErr := q.Get("error") if headerErr != "" { - desc := html.EscapeString(q.Get("error_description")) - escapedHeaderErr := html.EscapeString(headerErr) // Note: It is a little weird we handle some errors by not going to the failPage. If they all should, // change this to s.error() and make s.error() write the failPage instead of an error code. - _, _ = w.Write([]byte(fmt.Sprintf(failPage, escapedHeaderErr, desc))) - s.putResult(Result{Err: fmt.Errorf("%s", desc)}) + + escapedErrDesc := html.EscapeString(q.Get("error_description")) // provides XSS protection + escapedHeaderErr := html.EscapeString(headerErr) // provides XSS protection + + errorPage := bytes.ReplaceAll(s.errorPage, code, []byte(escapedHeaderErr)) + errorPage = bytes.ReplaceAll(errorPage, err, []byte(escapedErrDesc)) + + _, _ = w.Write(errorPage) + + s.putResult(Result{Err: fmt.Errorf("%s", escapedErrDesc)}) return } @@ -169,7 +197,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) { return } - _, _ = w.Write(okPage) + _, _ = w.Write(s.successPage) s.putResult(Result{Code: code}) } diff --git a/apps/internal/local/server_test.go b/apps/internal/local/server_test.go index 70af8b14..cafc223d 100644 --- a/apps/internal/local/server_test.go +++ b/apps/internal/local/server_test.go @@ -4,6 +4,7 @@ package local import ( + "bytes" "context" "io" "net/http" @@ -20,12 +21,18 @@ func TestServer(t *testing.T) { defer cancel() tests := []struct { - desc string - reqState string - port int - q url.Values - failPage bool - statusCode int + desc string + reqState string + port int + q url.Values + failPage bool + statusCode int + successPage []byte + errorPage []byte + testTemplate bool + testErrCodeXSS bool + testErrDescriptionXSS bool + expected string }{ { desc: "Error: Query Values has 'error' key", @@ -63,10 +70,99 @@ func TestServer(t *testing.T) { q: url.Values{"state": []string{"state"}, "code": []string{"code"}}, statusCode: 200, }, + { + desc: "Error: Query Values missing 'state' key, and optional error page, with template having code and error", + reqState: "state", + port: 0, + q: url.Values{"error": []string{"error_code"}, "error_description": []string{"error_description"}}, + statusCode: 200, + errorPage: []byte("test option error page {{.Code}} {{.Err}}"), + testTemplate: true, + expected: "test option error page error_code error_description", + }, + { + desc: "Error: Query Values missing 'state' key, and optional error page, with template having only code", + reqState: "state", + port: 0, + q: url.Values{"error": []string{"error_code"}, "error_description": []string{"error_description"}}, + statusCode: 200, + errorPage: []byte("test option error page {{.Code}}"), + testTemplate: true, + expected: "test option error page error_code", + }, + { + desc: "Error: Query Values missing 'state' key, and optional error page, with template having only error", + reqState: "state", + port: 0, + q: url.Values{"error": []string{"error_code"}, "error_description": []string{"error_description"}}, + statusCode: 200, + errorPage: []byte("test option error page {{.Err}}"), + testTemplate: true, + expected: "test option error page error_description", + }, + { + desc: "Error: Query Values missing 'state' key, and optional error page, having no code or error", + reqState: "state", + port: 0, + q: url.Values{"error": []string{"error_code"}, "error_description": []string{"error_description"}}, + statusCode: 200, + errorPage: []byte("test option error page"), + testTemplate: true, + expected: "test option error page", + }, + { + desc: "Error: Query Values missing 'state' key, using default fail error page", + reqState: "state", + port: 0, + q: url.Values{"error": []string{"error_code"}, "error_description": []string{"error_description"}}, + statusCode: 200, + testTemplate: true, + expected: "

Error details: error error_code, error description: error_description

", + }, + { + desc: "Error: Query Values missing 'state' key, using default fail error page - Error Code XSS test", + reqState: "state", + port: 0, + q: url.Values{"error": []string{""}, "error_description": []string{"error_description"}}, + statusCode: 200, + testTemplate: true, + testErrCodeXSS: true, + }, + { + desc: "Error: Query Values missing 'state' key, using default fail error page - Error Description XSS test", + reqState: "state", + port: 0, + q: url.Values{"error": []string{"error_code"}, "error_description": []string{""}}, + statusCode: 200, + testTemplate: true, + testErrDescriptionXSS: true, + }, + { + desc: "Error: Query Values missing 'state' key, using optional fail error page - Error Code XSS test", + reqState: "state", + port: 0, + q: url.Values{"error": []string{""}, "error_description": []string{"error_description"}}, + statusCode: 200, + errorPage: []byte("error: {{.Code}} error_description: {{.Err}}"), + testTemplate: true, + testErrCodeXSS: true, + expected: "<script>alert('this code snippet was executed')</script>", + }, + { + desc: "Error: Query Values missing 'state' key, using optional fail error page - Error Description XSS test", + reqState: "state", + port: 0, + q: url.Values{"error": []string{"error_code"}, "error_description": []string{""}}, + statusCode: 200, + errorPage: []byte("error: {{.Code}} error_description: {{.Err}}"), + testTemplate: true, + testErrDescriptionXSS: true, + expected: "<script>alert('this code snippet was executed')</script>", + }, } for _, test := range tests { - serv, err := New(test.reqState, test.port) + serv, err := New(test.reqState, test.port, test.successPage, test.errorPage) if err != nil { panic(err) } @@ -129,6 +225,20 @@ func TestServer(t *testing.T) { continue } + if len(test.successPage) > 0 { + if !bytes.Equal(content, test.successPage) { + t.Errorf("TestServer(%s): -want/+got:\ntest option error page", test.desc) + } + continue + } + + if test.testTemplate { + if !strings.Contains(string(content), test.expected) { + t.Errorf("TestServer(%s): -want:%s got:%s ", test.desc, test.expected, string(content)) + } + continue + } + if !strings.Contains(string(content), "Authentication Complete") { t.Errorf("TestServer(%s): got failed page, okay page", test.desc) } diff --git a/apps/public/public.go b/apps/public/public.go index 7beed261..6e498666 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -529,6 +529,8 @@ type interactiveAuthOptions struct { claims, domainHint, loginHint, redirectURI, tenantID string openURL func(url string) error authnScheme AuthenticationScheme + successPage []byte + errorPage []byte } // AcquireInteractiveOption is implemented by options for AcquireTokenInteractive @@ -536,6 +538,35 @@ type AcquireInteractiveOption interface { acquireInteractiveOption() } +// WithSystemBrowserOptions sets the optional success and error pages. +// The error page supports two optional html template variables {{.Code}} and {{.Err}}, +// which will be replaced with the corresponding error code, and descriptions. +func WithSystemBrowserOptions(successPage, errorPage []byte) interface { + AcquireInteractiveOption + options.CallOption +} { + return struct { + AcquireInteractiveOption + options.CallOption + }{ + CallOption: options.NewCallOption( + func(a any) error { + switch t := a.(type) { + case *interactiveAuthOptions: + t.successPage = make([]byte, len(successPage)) + copy(t.successPage, successPage) + + t.errorPage = make([]byte, len(errorPage)) + copy(t.errorPage, errorPage) + default: + return fmt.Errorf("unexpected options type %T", a) + } + return nil + }, + ), + } +} + // WithLoginHint pre-populates the login prompt with a username. func WithLoginHint(username string) interface { AcquireInteractiveOption @@ -678,7 +709,7 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, if o.authnScheme != nil { authParams.AuthnScheme = o.authnScheme } - res, err := pca.browserLogin(ctx, redirectURL, authParams, o.openURL) + res, err := pca.browserLogin(ctx, redirectURL, authParams, o.openURL, o.successPage, o.errorPage) if err != nil { return AuthResult{}, err } @@ -716,13 +747,13 @@ func parsePort(u *url.URL) (int, error) { } // browserLogin calls openURL and waits for a user to log in -func (pca Client) browserLogin(ctx context.Context, redirectURI *url.URL, params authority.AuthParams, openURL func(string) error) (interactiveAuthResult, error) { +func (pca Client) browserLogin(ctx context.Context, redirectURI *url.URL, params authority.AuthParams, openURL func(string) error, successPage []byte, errorPage []byte) (interactiveAuthResult, error) { // start local redirect server so login can call us back port, err := parsePort(redirectURI) if err != nil { return interactiveAuthResult{}, err } - srv, err := local.New(params.State, port) + srv, err := local.New(params.State, port, successPage, errorPage) if err != nil { return interactiveAuthResult{}, err }