Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions internal/configuration/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"golang.org/x/oauth2"
"io"
"io/ioutil"
"net/http"
Expand All @@ -30,6 +31,7 @@ var (

// Config holds app configuration
type Config struct {
AuthStyle string `long:"auth-style" env:"AUTH_STYLE" default:"auto-detect" choice:"auto-detect" choice:"header" choice:"params" description:"Optionally choose the authentication style of the OAuth2 library."`
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"`

Expand Down Expand Up @@ -133,6 +135,17 @@ func (c *Config) parseFlags(args []string) error {
return nil
}

func (c *Config) ParseAuthStyle() oauth2.AuthStyle {
switch c.AuthStyle {
case "header":
return oauth2.AuthStyleInHeader
case "params":
return oauth2.AuthStyleInParams
default:
return oauth2.AuthStyleAutoDetect
}
}

func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args []string) ([]string, error) {
// Parse rules in the format "rule.<name>.<param>"
parts := strings.Split(option, ".")
Expand Down
22 changes: 22 additions & 0 deletions internal/configuration/config_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package configuration

import (
"golang.org/x/oauth2"
"os"
"testing"
"time"
Expand All @@ -23,6 +24,7 @@ func TestConfigDefaults(t *testing.T) {
assert.Equal("text", c.LogFormat)

assert.Equal("", c.AuthHost)
assert.Equal("auto-detect", c.AuthStyle)
assert.Len(c.CookieDomains, 0)
assert.False(c.InsecureCookie)
assert.Equal("_forward_auth", c.CookieName)
Expand All @@ -37,6 +39,7 @@ func TestConfigDefaults(t *testing.T) {
func TestConfigParseArgs(t *testing.T) {
assert := assert.New(t)
c, err := NewConfig([]string{
"--auth-style=header",
"--cookie-name=cookiename",
"--csrf-cookie-name", "\"csrfcookiename\"",
"--rule.1.action=allow",
Expand All @@ -47,6 +50,7 @@ func TestConfigParseArgs(t *testing.T) {
require.Nil(t, err)

// Check normal flags
assert.Equal("header", c.AuthStyle)
assert.Equal("cookiename", c.CookieName)
assert.Equal("csrfcookiename", c.CSRFCookieName)

Expand Down Expand Up @@ -118,6 +122,24 @@ func TestConfigParseIni(t *testing.T) {
}, c.Rules)
}

func TestConfigParseAuthStyle(t *testing.T) {
assert := assert.New(t)

c1, err := NewConfig([]string{})
assert.Nil(err)
assert.Equal(oauth2.AuthStyleAutoDetect, c1.ParseAuthStyle())

c2, err := NewConfig([]string{})
c2.AuthStyle = "params"
assert.Nil(err)
assert.Equal(oauth2.AuthStyleInParams, c2.ParseAuthStyle())

c3, err := NewConfig([]string{})
c3.AuthStyle = "header"
assert.Nil(err)
assert.Equal(oauth2.AuthStyleInHeader, c3.ParseAuthStyle())
}

func TestConfigParseEnvironment(t *testing.T) {
assert := assert.New(t)
os.Setenv("COOKIE_NAME", "env_cookie_name")
Expand Down
9 changes: 7 additions & 2 deletions internal/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,17 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
scope = []string{oidc.ScopeOpenID, "profile", "email", "groups"}
}

endpoint := provider.Endpoint()
oauth2Config := oauth2.Config{
ClientID: s.config.ClientID,
ClientSecret: s.config.ClientSecret,
RedirectURL: s.authenticator.ComposeRedirectURI(r),
Endpoint: provider.Endpoint(),
Scopes: scope,
Endpoint: oauth2.Endpoint{
AuthStyle: s.config.ParseAuthStyle(),
AuthURL: endpoint.AuthURL,
TokenURL: endpoint.TokenURL,
},
Scopes: scope,
}

// Exchange code for token
Expand Down