Skip to content

Commit 2847ed3

Browse files
committed
some improvements
1 parent 11d68dd commit 2847ed3

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

services/auth/source/oauth2/providers.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,8 @@ func (p *AuthSourceProvider) IconHTML(size int) template.HTML {
7676
// value is used to store display data
7777
var gothProviders = map[string]GothProvider{}
7878

79-
var azureProviders = []string{
80-
"azuread",
81-
"microsoftonline",
82-
"azureadv2",
79+
func isAzureProvider(name string) bool {
80+
return name == "azuread" || name == "microsoftonline" || name == "azureadv2"
8381
}
8482

8583
// RegisterGothProvider registers a GothProvider
@@ -91,23 +89,23 @@ func RegisterGothProvider(provider GothProvider) {
9189
}
9290

9391
// getExistingAzureADAuthSources returns a list of Azure AD provider names that are already configured
94-
func getExistingAzureADAuthSources(ctx context.Context) []string {
92+
func getExistingAzureADAuthSources(ctx context.Context) ([]string, error) {
9593
authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
9694
LoginType: auth.OAuth2,
9795
})
9896
if err != nil {
99-
return nil
97+
return nil, err
10098
}
10199

102100
var existingAzureProviders []string
103101
for _, source := range authSources {
104102
if oauth2Cfg, ok := source.Cfg.(*Source); ok {
105-
if slices.Contains(azureProviders, oauth2Cfg.Provider) {
103+
if isAzureProvider(oauth2Cfg.Provider) {
106104
existingAzureProviders = append(existingAzureProviders, oauth2Cfg.Provider)
107105
}
108106
}
109107
}
110-
return existingAzureProviders
108+
return existingAzureProviders, nil
111109
}
112110

113111
// GetSupportedOAuth2Providers returns the map of unconfigured OAuth2 providers
@@ -122,10 +120,13 @@ func GetSupportedOAuth2Providers() []Provider {
122120
// GetSupportedOAuth2ProvidersWithContext returns the list of supported OAuth2 providers with context for filtering
123121
func GetSupportedOAuth2ProvidersWithContext(ctx context.Context) []Provider {
124122
providers := make([]Provider, 0, len(gothProviders))
125-
existAuthSources := getExistingAzureADAuthSources(ctx)
123+
existingAzureSources, err := getExistingAzureADAuthSources(ctx)
124+
if err != nil {
125+
log.Error("Failed to get existing OAuth2 auth sources: %v", err)
126+
}
126127

127128
for _, provider := range gothProviders {
128-
if slices.Contains(azureProviders, provider.Name()) && !slices.Contains(existAuthSources, provider.Name()) {
129+
if isAzureProvider(provider.Name()) && !slices.Contains(existingAzureSources, provider.Name()) {
129130
continue
130131
}
131132
providers = append(providers, provider)

0 commit comments

Comments
 (0)