@@ -3,11 +3,8 @@ package main
33import (
44 "bytes"
55 "context"
6- "crypto/rand"
7- "crypto/sha256"
86 "crypto/tls"
97 "crypto/x509"
10- "encoding/base64"
118 "encoding/json"
129 "errors"
1310 "fmt"
@@ -17,7 +14,7 @@ import (
1714 "net/http/httputil"
1815 "net/url"
1916 "os"
20- "strings "
17+ "slices "
2118 "time"
2219
2320 "github.com/coreos/go-oidc/v3/oidc"
3330)
3431
3532func init () {
36- codeVerifier = generateCodeVerifier ()
37- codeChallenge = generateCodeChallenge (codeVerifier )
33+ codeVerifier = oauth2 . GenerateVerifier ()
34+ codeChallenge = oauth2 . S256ChallengeFromVerifier (codeVerifier )
3835}
3936
4037type app struct {
@@ -43,8 +40,9 @@ type app struct {
4340 pkce bool
4441 redirectURI string
4542
46- verifier * oidc.IDTokenVerifier
47- provider * oidc.Provider
43+ verifier * oidc.IDTokenVerifier
44+ provider * oidc.Provider
45+ scopesSupported []string
4846
4947 // Does the provider use "offline_access" scope to request a refresh token
5048 // or does it use "access_type=offline" (e.g. Google)?
@@ -188,7 +186,9 @@ func cmd() *cobra.Command {
188186
189187 a .provider = provider
190188 a .verifier = provider .Verifier (& oidc.Config {ClientID : a .clientID })
189+ a .scopesSupported = s .ScopesSupported
191190
191+ http .Handle ("/static/" , http .StripPrefix ("/static/" , staticHandler ))
192192 http .HandleFunc ("/" , a .handleIndex )
193193 http .HandleFunc ("/login" , a .handleLogin )
194194 http .HandleFunc (u .Path , a .handleCallback )
@@ -226,7 +226,10 @@ func main() {
226226}
227227
228228func (a * app ) handleIndex (w http.ResponseWriter , r * http.Request ) {
229- renderIndex (w )
229+ renderIndex (w , indexPageData {
230+ ScopesSupported : a .scopesSupported ,
231+ LogoURI : dexLogoDataURI ,
232+ })
230233}
231234
232235func (a * app ) oauth2Config (scopes []string ) * oauth2.Config {
@@ -240,15 +243,19 @@ func (a *app) oauth2Config(scopes []string) *oauth2.Config {
240243}
241244
242245func (a * app ) handleLogin (w http.ResponseWriter , r * http.Request ) {
243- var scopes []string
244- if extraScopes := r .FormValue ("extra_scopes" ); extraScopes != "" {
245- scopes = strings .Split (extraScopes , " " )
246- }
247- var clients []string
248- if crossClients := r .FormValue ("cross_client" ); crossClients != "" {
249- clients = strings .Split (crossClients , " " )
246+ if err := r .ParseForm (); err != nil {
247+ http .Error (w , fmt .Sprintf ("failed to parse form: %v" , err ), http .StatusBadRequest )
248+ return
250249 }
250+
251+ // Only use scopes that are checked in the form
252+ scopes := r .Form ["extra_scopes" ]
253+
254+ clients := r .Form ["cross_client" ]
251255 for _ , client := range clients {
256+ if client == "" {
257+ continue
258+ }
252259 scopes = append (scopes , "audience:server:client_id:" + client )
253260 }
254261 connectorID := ""
@@ -257,7 +264,7 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) {
257264 }
258265
259266 authCodeURL := ""
260- scopes = append (scopes , "openid" , "profile" , "email" )
267+ scopes = uniqueStrings (scopes )
261268
262269 var authCodeOptions []oauth2.AuthCodeOption
263270
@@ -266,13 +273,28 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) {
266273 authCodeOptions = append (authCodeOptions , oauth2 .SetAuthURLParam ("code_challenge_method" , "S256" ))
267274 }
268275
269- a .oauth2Config (scopes )
270- if r .FormValue ("offline_access" ) == "yes" {
271- authCodeOptions = append (authCodeOptions , oauth2 .AccessTypeOffline )
276+ // Check if offline_access scope is present to determine offline access mode
277+ hasOfflineAccess := false
278+ for _ , scope := range scopes {
279+ if scope == "offline_access" {
280+ hasOfflineAccess = true
281+ break
282+ }
272283 }
273- if a .offlineAsScope {
274- scopes = append (scopes , "offline_access" )
284+
285+ if hasOfflineAccess && ! a .offlineAsScope {
286+ // Provider uses access_type=offline instead of offline_access scope
287+ authCodeOptions = append (authCodeOptions , oauth2 .AccessTypeOffline )
288+ // Remove offline_access from scopes as it's not supported
289+ filteredScopes := make ([]string , 0 , len (scopes ))
290+ for _ , scope := range scopes {
291+ if scope != "offline_access" {
292+ filteredScopes = append (filteredScopes , scope )
293+ }
294+ }
295+ scopes = filteredScopes
275296 }
297+
276298 authCodeURL = a .oauth2Config (scopes ).AuthCodeURL (exampleAppState , authCodeOptions ... )
277299
278300 // Parse the auth code URL and safely add connector_id parameter if provided
@@ -369,23 +391,17 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
369391 }
370392
371393 buff := new (bytes.Buffer )
372- if err := json .Indent (buff , [] byte ( claims ) , "" , " " ); err != nil {
394+ if err := json .Indent (buff , claims , "" , " " ); err != nil {
373395 http .Error (w , fmt .Sprintf ("error indenting ID token claims: %v" , err ), http .StatusInternalServerError )
374396 return
375397 }
376398
377- renderToken (w , a .redirectURI , rawIDToken , accessToken , token .RefreshToken , buff .String ())
399+ renderToken (w , r . Context (), a . provider , a .redirectURI , rawIDToken , accessToken , token .RefreshToken , buff .String ())
378400}
379401
380- func generateCodeVerifier () string {
381- bytes := make ([]byte , 64 ) // 86 symbols Base64URL
382- if _ , err := rand .Read (bytes ); err != nil {
383- log .Fatalf ("rand.Read error: %v" , err )
384- }
385- return base64 .RawURLEncoding .EncodeToString (bytes )
386- }
402+ func uniqueStrings (values []string ) []string {
403+ slices .Sort (values )
404+ values = slices .Compact (values )
387405
388- func generateCodeChallenge (verifier string ) string {
389- hash := sha256 .Sum256 ([]byte (verifier ))
390- return base64 .RawURLEncoding .EncodeToString (hash [:])
406+ return values
391407}
0 commit comments