Skip to content

Commit 8d9790b

Browse files
more specific
1 parent c9feca8 commit 8d9790b

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

services/auth/source/oauth2/providers.go

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,6 @@ var azureProviders = []string{
8282
"azureadv2",
8383
}
8484

85-
func isAzureProvider(providerName string) bool {
86-
return slices.Contains(azureProviders, providerName)
87-
}
88-
8985
// RegisterGothProvider registers a GothProvider
9086
func RegisterGothProvider(provider GothProvider) {
9187
if _, has := gothProviders[provider.Name()]; has {
@@ -94,23 +90,24 @@ func RegisterGothProvider(provider GothProvider) {
9490
gothProviders[provider.Name()] = provider
9591
}
9692

97-
// hasExistingAzureADAuthSources checks if there are any existing Azure AD auth sources configured
98-
func hasExistingAzureADAuthSources(ctx context.Context) bool {
93+
// getExistingAzureADAuthSources returns a list of Azure AD provider names that are already configured
94+
func getExistingAzureADAuthSources(ctx context.Context) []string {
9995
authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
10096
LoginType: auth.OAuth2,
10197
})
10298
if err != nil {
103-
return false
99+
return nil
104100
}
105101

102+
var existingAzureProviders []string
106103
for _, source := range authSources {
107104
if oauth2Cfg, ok := source.Cfg.(*Source); ok {
108-
if isAzureProvider(oauth2Cfg.Provider) {
109-
return true
105+
if slices.Contains(azureProviders, oauth2Cfg.Provider) {
106+
existingAzureProviders = append(existingAzureProviders, oauth2Cfg.Provider)
110107
}
111108
}
112109
}
113-
return false
110+
return existingAzureProviders
114111
}
115112

116113
// GetSupportedOAuth2Providers returns the map of unconfigured OAuth2 providers
@@ -125,10 +122,10 @@ func GetSupportedOAuth2Providers() []Provider {
125122
// GetSupportedOAuth2ProvidersWithContext returns the list of supported OAuth2 providers with context for filtering
126123
func GetSupportedOAuth2ProvidersWithContext(ctx context.Context) []Provider {
127124
providers := make([]Provider, 0, len(gothProviders))
128-
hasExistingAzure := hasExistingAzureADAuthSources(ctx)
125+
existAuthSources := getExistingAzureADAuthSources(ctx)
129126

130127
for _, provider := range gothProviders {
131-
if isAzureProvider(provider.Name()) && !hasExistingAzure {
128+
if slices.Contains(azureProviders, provider.Name()) && !slices.Contains(existAuthSources, provider.Name()) {
132129
continue
133130
}
134131
providers = append(providers, provider)

0 commit comments

Comments
 (0)