@@ -6,25 +6,28 @@ import (
66 "fmt"
77 "net"
88 "net/http"
9+ "net/url"
910 "sync"
1011 "time"
1112
1213 "github.com/int128/listener"
1314 "golang.org/x/sync/errgroup"
1415)
1516
16- func receiveCodeViaLocalServer (ctx context.Context , c * Config ) (string , error ) {
17- l , err := listener .New (c .LocalServerBindAddress )
17+ func receiveCodeViaLocalServer (ctx context.Context , cfg * Config ) (string , error ) {
18+ localServerListener , err := listener .New (cfg .LocalServerBindAddress )
1819 if err != nil {
1920 return "" , fmt .Errorf ("could not start a local server: %w" , err )
2021 }
21- defer l .Close ()
22- c .OAuth2Config .RedirectURL = computeRedirectURL (l , c )
22+ defer localServerListener .Close ()
23+
24+ localServerPort := localServerListener .Addr ().(* net.TCPAddr ).Port
25+ cfg .OAuth2Config .RedirectURL = constructRedirectURL (cfg , localServerPort )
2326
2427 respCh := make (chan * authorizationResponse )
2528 server := http.Server {
26- Handler : c .LocalServerMiddleware (& localServerHandler {
27- config : c ,
29+ Handler : cfg .LocalServerMiddleware (& localServerHandler {
30+ config : cfg ,
2831 respCh : respCh ,
2932 }),
3033 }
@@ -33,15 +36,18 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
3336 var eg errgroup.Group
3437 eg .Go (func () error {
3538 defer close (respCh )
36- c .Logf ("oauth2cli: starting a server at %s" , l .Addr ())
37- defer c .Logf ("oauth2cli: stopped the server" )
38- if c .isLocalServerHTTPS () {
39- if err := server .ServeTLS (l , c .LocalServerCertFile , c .LocalServerKeyFile ); err != nil && err != http .ErrServerClosed {
39+ cfg .Logf ("oauth2cli: starting a server at %s" , localServerListener .Addr ())
40+ defer cfg .Logf ("oauth2cli: stopped the server" )
41+ if cfg .isLocalServerHTTPS () {
42+ if err := server .ServeTLS (localServerListener , cfg .LocalServerCertFile , cfg .LocalServerKeyFile ); err != nil {
43+ if errors .Is (err , http .ErrServerClosed ) {
44+ return nil
45+ }
4046 return fmt .Errorf ("could not start HTTPS server: %w" , err )
4147 }
4248 return nil
4349 }
44- if err := server .Serve (l ); err != nil && err != http .ErrServerClosed {
50+ if err := server .Serve (localServerListener ); err != nil && err != http .ErrServerClosed {
4551 return fmt .Errorf ("could not start HTTP server: %w" , err )
4652 }
4753 return nil
@@ -63,22 +69,22 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
6369 // Gracefully shutdown the server in the timeout.
6470 // If the server has not started, Shutdown returns nil and this returns immediately.
6571 // If Shutdown has failed, force-close the server.
66- c .Logf ("oauth2cli: shutting down the server" )
72+ cfg .Logf ("oauth2cli: shutting down the server" )
6773 ctx , cancel := context .WithTimeout (ctx , 500 * time .Millisecond )
6874 defer cancel ()
6975 if err := server .Shutdown (ctx ); err != nil {
70- c .Logf ("oauth2cli: force-closing the server: shutdown failed: %s" , err )
76+ cfg .Logf ("oauth2cli: force-closing the server: shutdown failed: %s" , err )
7177 _ = server .Close ()
7278 return nil
7379 }
7480 return nil
7581 })
7682 eg .Go (func () error {
77- if c .LocalServerReadyChan == nil {
83+ if cfg .LocalServerReadyChan == nil {
7884 return nil
7985 }
8086 select {
81- case c .LocalServerReadyChan <- c .OAuth2Config .RedirectURL :
87+ case cfg .LocalServerReadyChan <- cfg .OAuth2Config .RedirectURL :
8288 return nil
8389 case <- ctx .Done ():
8490 return ctx .Err ()
@@ -93,12 +99,14 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
9399 return resp .code , resp .err
94100}
95101
96- func computeRedirectURL (l net.Listener , c * Config ) string {
97- hostPort := fmt .Sprintf ("%s:%d" , c .RedirectURLHostname , l .Addr ().(* net.TCPAddr ).Port )
98- if c .LocalServerCertFile != "" {
99- return "https://" + hostPort
102+ func constructRedirectURL (cfg * Config , port int ) string {
103+ var redirect url.URL
104+ redirect .Host = fmt .Sprintf ("%s:%d" , cfg .RedirectURLHostname , port )
105+ redirect .Scheme = "http"
106+ if cfg .isLocalServerHTTPS () {
107+ redirect .Scheme = "https"
100108 }
101- return "http://" + hostPort
109+ return redirect . String ()
102110}
103111
104112type authorizationResponse struct {
@@ -133,7 +141,7 @@ func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
133141func (h * localServerHandler ) handleIndex (w http.ResponseWriter , r * http.Request ) {
134142 authCodeURL := h .config .OAuth2Config .AuthCodeURL (h .config .State , h .config .AuthCodeOptions ... )
135143 h .config .Logf ("oauth2cli: sending redirect to %s" , authCodeURL )
136- http .Redirect (w , r , authCodeURL , 302 )
144+ http .Redirect (w , r , authCodeURL , http . StatusFound )
137145}
138146
139147func (h * localServerHandler ) handleCodeResponse (w http.ResponseWriter , r * http.Request ) * authorizationResponse {
@@ -147,21 +155,20 @@ func (h *localServerHandler) handleCodeResponse(w http.ResponseWriter, r *http.R
147155
148156 if h .config .SuccessRedirectURL != "" {
149157 http .Redirect (w , r , h .config .SuccessRedirectURL , http .StatusFound )
150- } else {
151- w .Header ().Add ("Content-Type" , "text/html" )
152- if _ , err := fmt .Fprint (w , h .config .LocalServerSuccessHTML ); err != nil {
153- http .Error (w , "server error" , 500 )
154- return & authorizationResponse {err : fmt .Errorf ("write error: %w" , err )}
155- }
158+ return & authorizationResponse {code : code }
156159 }
157160
161+ w .Header ().Add ("Content-Type" , "text/html" )
162+ if _ , err := fmt .Fprint (w , h .config .LocalServerSuccessHTML ); err != nil {
163+ http .Error (w , "server error" , http .StatusInternalServerError )
164+ return & authorizationResponse {err : fmt .Errorf ("write error: %w" , err )}
165+ }
158166 return & authorizationResponse {code : code }
159167}
160168
161169func (h * localServerHandler ) handleErrorResponse (w http.ResponseWriter , r * http.Request ) * authorizationResponse {
162170 q := r .URL .Query ()
163171 errorCode , errorDescription := q .Get ("error" ), q .Get ("error_description" )
164-
165172 h .authorizationError (w , r )
166173 return & authorizationResponse {err : fmt .Errorf ("authorization error from server: %s %s" , errorCode , errorDescription )}
167174}
@@ -170,6 +177,6 @@ func (h *localServerHandler) authorizationError(w http.ResponseWriter, r *http.R
170177 if h .config .FailureRedirectURL != "" {
171178 http .Redirect (w , r , h .config .FailureRedirectURL , http .StatusFound )
172179 } else {
173- http .Error (w , "authorization error" , 500 )
180+ http .Error (w , "authorization error" , http . StatusInternalServerError )
174181 }
175182}
0 commit comments