Skip to content

Commit a780509

Browse files
committed
Merge remote-tracking branch 'origin/master'
# Conflicts: # coverage/coverage.out
2 parents 84f4e3a + 27c26db commit a780509

File tree

8 files changed

+176
-19
lines changed

8 files changed

+176
-19
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ You can specify configuration options either via a config file (default: `config
202202
| `security.password_reset_max_rate` /<br> `WAKAPI_PASSWORD_RESET_MAX_RATE` | `5/1h` | Rate limiting config for password reset endpoint in format `<max_req>/<multiplier><unit>`, where `unit` is one of `s`, `m` or `h`. |
203203
| `security.oidc` | `[]` | List of OpenID Connect provider configurations (for details, see [wiki](https://github.com/muety/wakapi/wiki/OpenID-Connect-login-(SSO))) |
204204
| `security.oidc[0].name` /<br> `WAKAPI_OIDC_PROVIDERS_0_NAME` | - | Name / identifier for the OpenID Connect provider (e.g. `gitlab`) |
205+
| `security.oidc[0].display_name` /<br> `WAKAPI_OIDC_PROVIDERS_0_DISPLAY_NAME` | - | Optional "human-readable" display name for the provider presented to the user |
205206
| `security.oidc[0].client_id` /<br> `WAKAPI_OIDC_PROVIDERS_0_CLIENT_ID` | - | OAuth client name with this provider |
206207
| `security.oidc[0].client_secret` /<br> `WAKAPI_OIDC_PROVIDERS_0_CLIENT_SECRET` | - | OAuth client secret with this provider |
207208
| `security.oidc[0].endpoint` /<br> `WAKAPI_OIDC_PROVIDERS_0_ENDPOINT` | - | OpenID Connect provider API entrypoint (for [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html)) |

config/config.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"time"
1313

1414
"github.com/duke-git/lancet/v2/slice"
15+
"github.com/duke-git/lancet/v2/strutil"
1516

1617
"log/slog"
1718

@@ -207,6 +208,7 @@ type SMTPMailConfig struct {
207208
type oidcProviderConfig struct {
208209
// for environment variables format, see renameEnvVars() down below
209210
Name string `yaml:"name"`
211+
DisplayName string `yaml:"display_name"` // optional
210212
ClientID string `yaml:"client_id"`
211213
ClientSecret string `yaml:"client_secret"`
212214
Endpoint string `yaml:"endpoint"` // base url from which auto-discovery (.well-known/openid-configuration) can be found
@@ -228,6 +230,32 @@ type Config struct {
228230
Mail mailConfig
229231
}
230232

233+
func (c *oidcProviderConfig) String() string {
234+
if c.DisplayName != "" {
235+
return c.DisplayName
236+
}
237+
return strutil.Capitalize(c.Name)
238+
}
239+
240+
func (c *oidcProviderConfig) Validate() error {
241+
var namePattern = regexp.MustCompile("^[a-zA-Z0-9-]+$")
242+
var endpointPattern = regexp.MustCompile("^https?://")
243+
244+
if !namePattern.MatchString(c.Name) {
245+
return fmt.Errorf("invalid provider name '%s', must only contain alphanumeric characters or '-'", c.Name)
246+
}
247+
if c.ClientID == "" {
248+
return fmt.Errorf("provider '%s' is missing client id", c.Name)
249+
}
250+
if c.ClientSecret == "" {
251+
return fmt.Errorf("provider '%s' is missing client secret", c.Name)
252+
}
253+
if !endpointPattern.MatchString(c.Endpoint) {
254+
return fmt.Errorf("provider '%s' is missing endpoint", c.Name)
255+
}
256+
return nil
257+
}
258+
231259
func (c *Config) AppStartTimestamp() string {
232260
return fmt.Sprintf("%d", appStartTime.Unix())
233261
}
@@ -627,6 +655,11 @@ func Load(configFlag string, version string) *Config {
627655
if d, err := time.Parse(config.App.DateTimeFormat, config.App.DateTimeFormat); err != nil || !d.Equal(time.Date(2006, time.January, 2, 15, 4, 0, 0, d.Location())) {
628656
Log().Fatal("invalid datetime format", "format", config.App.DateTimeFormat)
629657
}
658+
for _, provider := range config.Security.OidcProviders {
659+
if err := provider.Validate(); err != nil {
660+
Log().Fatal("invalid oidc provider config", "provider", provider.Name, "error", err)
661+
}
662+
}
630663

631664
cronParser := cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
632665

config/config_test.go

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ func Test_Load_OidcProviders(t *testing.T) {
1919
defer oidcMock2.Shutdown()
2020

2121
os.Setenv("WAKAPI_OIDC_PROVIDERS_0_NAME", "testprovider1")
22+
os.Setenv("WAKAPI_OIDC_PROVIDERS_0_DISPLAY_NAME", "Test Provider 1")
2223
os.Setenv("WAKAPI_OIDC_PROVIDERS_0_CLIENT_ID", oidcMock1.ClientID)
2324
os.Setenv("WAKAPI_OIDC_PROVIDERS_0_CLIENT_SECRET", oidcMock1.ClientSecret)
2425
os.Setenv("WAKAPI_OIDC_PROVIDERS_0_ENDPOINT", oidcMock1.Addr()+"/oidc")
@@ -32,19 +33,127 @@ func Test_Load_OidcProviders(t *testing.T) {
3233

3334
assert.Len(t, oidcCfg, 2)
3435
assert.Equal(t, "testprovider1", oidcCfg[0].Name)
36+
assert.Equal(t, "Test Provider 1", oidcCfg[0].DisplayName)
37+
assert.Equal(t, "Test Provider 1", oidcCfg[0].String())
3538
assert.Equal(t, oidcMock1.ClientID, oidcCfg[0].ClientID)
3639
assert.Equal(t, oidcMock1.ClientSecret, oidcCfg[0].ClientSecret)
3740
assert.Equal(t, oidcMock1.Addr()+"/oidc", oidcCfg[0].Endpoint)
3841
assert.Equal(t, "testprovider2", oidcCfg[1].Name)
42+
assert.Equal(t, "", oidcCfg[1].DisplayName)
43+
assert.Equal(t, "Testprovider2", oidcCfg[1].String())
3944
assert.Equal(t, oidcMock2.ClientID, oidcCfg[1].ClientID)
4045
assert.Equal(t, oidcMock2.ClientSecret, oidcCfg[1].ClientSecret)
4146
assert.Equal(t, oidcMock2.Addr()+"/oidc", oidcCfg[1].Endpoint)
4247

43-
_, err1 := GetOidcProvider("testprovider1")
44-
_, err2 := GetOidcProvider("testprovider2")
45-
48+
p1, err1 := GetOidcProvider("testprovider1")
4649
assert.Nil(t, err1)
50+
assert.Equal(t, "Test Provider 1", p1.DisplayName)
51+
52+
p2, err2 := GetOidcProvider("testprovider2")
4753
assert.Nil(t, err2)
54+
assert.Equal(t, "Testprovider2", p2.DisplayName)
55+
}
56+
57+
func TestOidcProviderConfig_Validate(t *testing.T) {
58+
// note: test cases were generated by ai
59+
testCases := []struct {
60+
name string
61+
config oidcProviderConfig
62+
err string
63+
}{
64+
{
65+
name: "valid",
66+
config: oidcProviderConfig{
67+
Name: "test-provider-1",
68+
ClientID: "client-id",
69+
ClientSecret: "client-secret",
70+
Endpoint: "https://provider.com/oidc",
71+
},
72+
err: "",
73+
},
74+
{
75+
name: "valid with http",
76+
config: oidcProviderConfig{
77+
Name: "test-provider-1",
78+
ClientID: "client-id",
79+
ClientSecret: "client-secret",
80+
Endpoint: "http://provider.com/oidc",
81+
},
82+
err: "",
83+
},
84+
{
85+
name: "invalid name with spaces",
86+
config: oidcProviderConfig{
87+
Name: "test provider",
88+
},
89+
err: "invalid provider name 'test provider', must only contain alphanumeric characters or '-'",
90+
},
91+
{
92+
name: "invalid name with underscore",
93+
config: oidcProviderConfig{
94+
Name: "test_provider",
95+
},
96+
err: "invalid provider name 'test_provider', must only contain alphanumeric characters or '-'",
97+
},
98+
{
99+
name: "missing client id",
100+
config: oidcProviderConfig{
101+
Name: "test-provider",
102+
ClientSecret: "client-secret",
103+
Endpoint: "https://provider.com/oidc",
104+
},
105+
err: "provider 'test-provider' is missing client id",
106+
},
107+
{
108+
name: "missing client secret",
109+
config: oidcProviderConfig{
110+
Name: "test-provider",
111+
ClientID: "client-id",
112+
Endpoint: "https://provider.com/oidc",
113+
},
114+
err: "provider 'test-provider' is missing client secret",
115+
},
116+
{
117+
name: "missing endpoint",
118+
config: oidcProviderConfig{
119+
Name: "test-provider",
120+
ClientID: "client-id",
121+
ClientSecret: "client-secret",
122+
},
123+
err: "provider 'test-provider' is missing endpoint",
124+
},
125+
{
126+
name: "invalid endpoint scheme",
127+
config: oidcProviderConfig{
128+
Name: "test-provider",
129+
ClientID: "client-id",
130+
ClientSecret: "client-secret",
131+
Endpoint: "ftp://provider.com/oidc",
132+
},
133+
err: "provider 'test-provider' is missing endpoint",
134+
},
135+
{
136+
name: "endpoint without scheme",
137+
config: oidcProviderConfig{
138+
Name: "test-provider",
139+
ClientID: "client-id",
140+
ClientSecret: "client-secret",
141+
Endpoint: "provider.com/oidc",
142+
},
143+
err: "provider 'test-provider' is missing endpoint",
144+
},
145+
}
146+
147+
for _, tc := range testCases {
148+
t.Run(tc.name, func(t *testing.T) {
149+
err := tc.config.Validate()
150+
if tc.err == "" {
151+
assert.NoError(t, err)
152+
} else {
153+
assert.EqualError(t, err, tc.err)
154+
}
155+
})
156+
}
48157
}
49158

50159
func TestConfig_IsDev(t *testing.T) {

config/oidc.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ import (
1111
)
1212

1313
type OidcProvider struct {
14-
Name string
15-
OAuth2 *oauth2.Config
16-
Verifier *oidc.IDTokenVerifier
14+
Name string
15+
DisplayName string
16+
OAuth2 *oauth2.Config
17+
Verifier *oidc.IDTokenVerifier
1718
}
1819

1920
type IdTokenPayload struct {
@@ -70,9 +71,10 @@ func RegisterOidcProvider(providerCfg *oidcProviderConfig) {
7071
}
7172

7273
oidcProviders[providerCfg.Name] = &OidcProvider{
73-
Name: providerCfg.Name,
74-
OAuth2: &oauth2Conf,
75-
Verifier: provider.Verifier(&oidc.Config{ClientID: providerCfg.ClientID}),
74+
Name: providerCfg.Name,
75+
DisplayName: providerCfg.String(),
76+
OAuth2: &oauth2Conf,
77+
Verifier: provider.Verifier(&oidc.Config{ClientID: providerCfg.ClientID}),
7678
}
7779
}
7880

models/view/login.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ type LoginViewModel struct {
66
AllowSignup bool
77
CaptchaId string
88
InviteCode string
9-
OidcProviders []string
9+
OidcProviders []LoginViewModelOidcProvider
10+
}
11+
12+
type LoginViewModelOidcProvider struct {
13+
Name string
14+
DisplayName string
1015
}
1116

1217
type SetPasswordViewModel struct {

routes/login.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/dchest/captcha"
1212
"github.com/duke-git/lancet/v2/random"
13+
"github.com/duke-git/lancet/v2/slice"
1314
"github.com/go-chi/chi/v5"
1415
"github.com/go-chi/httprate"
1516
conf "github.com/muety/wakapi/config"
@@ -475,7 +476,13 @@ func (h *LoginHandler) buildViewModel(r *http.Request, w http.ResponseWriter, wi
475476
TotalUsers: int(numUsers),
476477
AllowSignup: h.config.IsDev() || h.config.Security.AllowSignup,
477478
InviteCode: r.URL.Query().Get("invite"),
478-
OidcProviders: h.config.Security.ListOidcProviders(),
479+
OidcProviders: slice.Map(h.config.Security.ListOidcProviders(), func(i int, providerName string) view.LoginViewModelOidcProvider {
480+
provider, _ := conf.GetOidcProvider(providerName) // no error, because only using registered provider names
481+
return view.LoginViewModelOidcProvider{
482+
Name: provider.Name,
483+
DisplayName: provider.DisplayName,
484+
}
485+
}),
479486
}
480487

481488
if withCaptcha {

views/login.tpl.html

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@ <h1 class="h1">Welcome!</h1>
4949
<div class="mt-10">
5050
<div class="font-semibold mb-2">Single Sign-On</div>
5151
{{ range $provider := .OidcProviders }}
52-
<a href="oidc/{{ $provider | lower }}/login" class="block mb-2">
52+
<a href="oidc/{{ $provider.Name | lower }}/login" class="block mb-2">
5353
<button type="button" class="btn-default w-full flex items-center gap-2 justify-center">
54-
{{ if $.OidcProviderIcon $provider }}
55-
<span class="iconify inline text-white text-base" data-icon="{{ ($.OidcProviderIcon $provider) | urlSafe }}"></span>
54+
{{ if $.OidcProviderIcon $provider.Name }}
55+
<span class="iconify inline text-white text-base" data-icon="{{ ($.OidcProviderIcon $provider.Name) | urlSafe }}"></span>
5656
{{ end }}
57-
Login with {{ $provider | capitalize }}
57+
Login with {{ $provider.DisplayName }}
5858
</button>
5959
</a>
6060
{{ end }}

views/signup.tpl.html

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,12 @@ <h1 class="h1">Sign up to Wakapi</h1>
114114
<div class="mt-10">
115115
<div class="font-semibold mb-2">Single Sign-On</div>
116116
{{ range $provider := .OidcProviders }}
117-
<a href="oidc/{{ $provider | lower }}/login" class="block mb-2">
117+
<a href="oidc/{{ $provider.Name | lower }}/login" class="block mb-2">
118118
<button type="button" class="btn-default w-full flex items-center gap-2 justify-center">
119-
{{ if $.OidcProviderIcon $provider }}
120-
<span class="iconify inline text-white text-base" data-icon="{{ ($.OidcProviderIcon $provider) | urlSafe }}"></span>
119+
{{ if $.OidcProviderIcon $provider.Name }}
120+
<span class="iconify inline text-white text-base" data-icon="{{ ($.OidcProviderIcon $provider.Name) | urlSafe }}"></span>
121121
{{ end }}
122-
Sign up {{ $provider | capitalize }}
122+
Sign up {{ $provider.DisplayName }}
123123
</button>
124124
</a>
125125
{{ end }}

0 commit comments

Comments
 (0)