diff --git a/internal/configuration/config.go b/internal/configuration/config.go index d4332bb..d4260fb 100644 --- a/internal/configuration/config.go +++ b/internal/configuration/config.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "golang.org/x/oauth2" "io" "io/ioutil" "net/http" @@ -30,6 +31,7 @@ var ( // Config holds app configuration type Config struct { + AuthStyle string `long:"auth-style" env:"AUTH_STYLE" default:"auto-detect" choice:"auto-detect" choice:"header" choice:"params" description:"Optionally choose the authentication style of the OAuth2 library."` LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"` LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"` @@ -133,6 +135,17 @@ func (c *Config) parseFlags(args []string) error { return nil } +func (c *Config) ParseAuthStyle() oauth2.AuthStyle { + switch c.AuthStyle { + case "header": + return oauth2.AuthStyleInHeader + case "params": + return oauth2.AuthStyleInParams + default: + return oauth2.AuthStyleAutoDetect + } +} + func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args []string) ([]string, error) { // Parse rules in the format "rule.." parts := strings.Split(option, ".") diff --git a/internal/configuration/config_test.go b/internal/configuration/config_test.go index ace68e5..ab40a44 100644 --- a/internal/configuration/config_test.go +++ b/internal/configuration/config_test.go @@ -1,6 +1,7 @@ package configuration import ( + "golang.org/x/oauth2" "os" "testing" "time" @@ -23,6 +24,7 @@ func TestConfigDefaults(t *testing.T) { assert.Equal("text", c.LogFormat) assert.Equal("", c.AuthHost) + assert.Equal("auto-detect", c.AuthStyle) assert.Len(c.CookieDomains, 0) assert.False(c.InsecureCookie) assert.Equal("_forward_auth", c.CookieName) @@ -37,6 +39,7 @@ func TestConfigDefaults(t *testing.T) { func TestConfigParseArgs(t *testing.T) { assert := assert.New(t) c, err := NewConfig([]string{ + "--auth-style=header", "--cookie-name=cookiename", "--csrf-cookie-name", "\"csrfcookiename\"", "--rule.1.action=allow", @@ -47,6 +50,7 @@ func TestConfigParseArgs(t *testing.T) { require.Nil(t, err) // Check normal flags + assert.Equal("header", c.AuthStyle) assert.Equal("cookiename", c.CookieName) assert.Equal("csrfcookiename", c.CSRFCookieName) @@ -118,6 +122,24 @@ func TestConfigParseIni(t *testing.T) { }, c.Rules) } +func TestConfigParseAuthStyle(t *testing.T) { + assert := assert.New(t) + + c1, err := NewConfig([]string{}) + assert.Nil(err) + assert.Equal(oauth2.AuthStyleAutoDetect, c1.ParseAuthStyle()) + + c2, err := NewConfig([]string{}) + c2.AuthStyle = "params" + assert.Nil(err) + assert.Equal(oauth2.AuthStyleInParams, c2.ParseAuthStyle()) + + c3, err := NewConfig([]string{}) + c3.AuthStyle = "header" + assert.Nil(err) + assert.Equal(oauth2.AuthStyleInHeader, c3.ParseAuthStyle()) +} + func TestConfigParseEnvironment(t *testing.T) { assert := assert.New(t) os.Setenv("COOKIE_NAME", "env_cookie_name") diff --git a/internal/handlers/server.go b/internal/handlers/server.go index d3f4cc1..07bc89f 100644 --- a/internal/handlers/server.go +++ b/internal/handlers/server.go @@ -290,12 +290,17 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { scope = []string{oidc.ScopeOpenID, "profile", "email", "groups"} } + endpoint := provider.Endpoint() oauth2Config := oauth2.Config{ ClientID: s.config.ClientID, ClientSecret: s.config.ClientSecret, RedirectURL: s.authenticator.ComposeRedirectURI(r), - Endpoint: provider.Endpoint(), - Scopes: scope, + Endpoint: oauth2.Endpoint{ + AuthStyle: s.config.ParseAuthStyle(), + AuthURL: endpoint.AuthURL, + TokenURL: endpoint.TokenURL, + }, + Scopes: scope, } // Exchange code for token