diff --git a/oauthproxy.go b/oauthproxy.go index 21e5dfc74..c1821c778 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -469,7 +469,7 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { case path == p.SignOutPath: p.SignOut(rw, req) case path == p.OAuthStartPath: - p.OAuthStart(rw, req) + p.OAuthStart(rw, req, "") case path == p.OAuthCallbackPath: p.OAuthCallback(rw, req) case path == p.AuthOnlyPath: @@ -493,7 +493,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { http.Redirect(rw, req, redirect, 302) } else { if p.SkipProviderButton { - p.OAuthStart(rw, req) + p.OAuthStart(rw, req, "") } else { p.SignInPage(rw, req, http.StatusOK) } @@ -505,17 +505,20 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { http.Redirect(rw, req, "/", 302) } -func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { +func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request, redirect string) { nonce, err := cookie.Nonce() if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } p.SetCSRFCookie(rw, req, nonce) - redirect, err := p.GetRedirect(req) - if err != nil { - p.ErrorPage(rw, 500, "Internal Error", err.Error()) - return + // If not explicitly told where to redirect, try to get it from form parameters. + if redirect == "" { + redirect, err = p.GetRedirect(req) + if err != nil { + p.ErrorPage(rw, 500, "Internal Error", err.Error()) + return + } } redirectURI := p.GetRedirectURI(req.Host) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302) @@ -598,7 +601,8 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { "Internal Error", "Internal Error") } else if status == http.StatusForbidden { if p.SkipProviderButton { - p.OAuthStart(rw, req) + // Start OAuth but redirect back to this URI when complete + p.OAuthStart(rw, req, req.URL.RequestURI()) } else { p.SignInPage(rw, req, http.StatusForbidden) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 1e6b3140d..31890e03a 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -440,11 +440,19 @@ func TestSignInPageSkipProvider(t *testing.T) { t.Fatal("Did not find pattern in body: " + signInSkipProvider + "\nBody:\n" + body) } + + // State is a hex nonce, a colon (encoded as %3A), plus an escaped redirect path. + source_page_re := regexp.MustCompile("state=[0-9a-f]+%3A" + url.PathEscape(endpoint) + `[;"]`) + source_page_match := source_page_re.FindStringSubmatch(body) + if source_page_match == nil { + t.Fatal("Callback state should include redirect to original endpoint: " + + source_page_re.String() + "\nBody:\n" + body) + } } func TestSignInPageSkipProviderDirect(t *testing.T) { sip_test := NewSignInPageTest(true) - const endpoint = "/sign_in" + const endpoint = "/oauth2/sign_in" code, body := sip_test.GetEndpoint(endpoint) assert.Equal(t, 302, code) @@ -454,6 +462,14 @@ func TestSignInPageSkipProviderDirect(t *testing.T) { t.Fatal("Did not find pattern in body: " + signInSkipProvider + "\nBody:\n" + body) } + + // State is a hex nonce, a colon (encoded as %3A), plus an escaped redirect path. + source_page_re := regexp.MustCompile(`state=[0-9a-f]+%3A%2F[;"]`) + source_page_match := source_page_re.FindStringSubmatch(body) + if source_page_match == nil { + t.Fatal("Callback state should include redirect to /: " + + source_page_re.String() + "\nBody:\n" + body) + } } type ProcessCookieTest struct {