Skip to content

Commit e0dcefd

Browse files
authored
Fix deadlock on http.Server.Shutdown error (#40)
1 parent e6ee980 commit e0dcefd

File tree

1 file changed

+29
-26
lines changed

1 file changed

+29
-26
lines changed

server.go

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"net"
88
"net/http"
9+
"sync"
910

1011
"github.com/int128/listener"
1112
"golang.org/x/sync/errgroup"
@@ -22,35 +23,32 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
2223
respCh := make(chan *authorizationResponse)
2324
server := http.Server{
2425
Handler: c.LocalServerMiddleware(&localServerHandler{
25-
config: c,
26-
responseCh: respCh,
26+
config: c,
27+
respCh: respCh,
2728
}),
2829
}
2930
var resp *authorizationResponse
3031
var eg errgroup.Group
3132
eg.Go(func() error {
32-
for {
33-
select {
34-
case received, ok := <-respCh:
35-
if !ok {
36-
c.Logf("oauth2cli: response channel has been closed")
37-
return nil // channel is closed (after the server is stopped)
38-
}
39-
if resp == nil {
40-
resp = received // pick only the first response
41-
}
42-
c.Logf("oauth2cli: shutting down the server at %s", l.Addr())
43-
if err := server.Shutdown(ctx); err != nil {
44-
return fmt.Errorf("could not shutdown the local server: %w", err)
45-
}
46-
case <-ctx.Done():
47-
c.Logf("oauth2cli: context cancelled: %s", ctx.Err())
48-
c.Logf("oauth2cli: shutting down the server at %s", l.Addr())
49-
if err := server.Shutdown(ctx); err != nil {
50-
return fmt.Errorf("could not shutdown the local server: %w", err)
51-
}
52-
return fmt.Errorf("context cancelled while waiting for authorization response: %w", ctx.Err())
33+
select {
34+
case gotResp, ok := <-respCh:
35+
if !ok {
36+
c.Logf("oauth2cli: response channel has been closed")
37+
return nil
5338
}
39+
resp = gotResp
40+
c.Logf("oauth2cli: shutting down the server at %s", l.Addr())
41+
if err := server.Shutdown(ctx); err != nil {
42+
return fmt.Errorf("could not shutdown the local server: %w", err)
43+
}
44+
return nil
45+
case <-ctx.Done():
46+
c.Logf("oauth2cli: context cancelled: %s", ctx.Err())
47+
c.Logf("oauth2cli: shutting down the server at %s", l.Addr())
48+
if err := server.Shutdown(ctx); err != nil {
49+
return fmt.Errorf("could not shutdown the local server: %w", err)
50+
}
51+
return fmt.Errorf("context cancelled while waiting for authorization response: %w", ctx.Err())
5452
}
5553
})
5654
eg.Go(func() error {
@@ -98,16 +96,21 @@ type authorizationResponse struct {
9896

9997
type localServerHandler struct {
10098
config *Config
101-
responseCh chan<- *authorizationResponse
99+
respCh chan<- *authorizationResponse // channel to send a response to
100+
onceRespCh sync.Once // ensure send once
102101
}
103102

104103
func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
105104
q := r.URL.Query()
106105
switch {
107106
case r.Method == "GET" && r.URL.Path == "/" && q.Get("error") != "":
108-
h.responseCh <- h.handleErrorResponse(w, r)
107+
h.onceRespCh.Do(func() {
108+
h.respCh <- h.handleErrorResponse(w, r)
109+
})
109110
case r.Method == "GET" && r.URL.Path == "/" && q.Get("code") != "":
110-
h.responseCh <- h.handleCodeResponse(w, r)
111+
h.onceRespCh.Do(func() {
112+
h.respCh <- h.handleCodeResponse(w, r)
113+
})
111114
case r.Method == "GET" && r.URL.Path == "/":
112115
h.handleIndex(w, r)
113116
default:

0 commit comments

Comments
 (0)