Skip to content

Commit a6300b0

Browse files
authored
Refactor production code (#233)
* Refactor code * Fix * Fix
1 parent 3f5303a commit a6300b0

File tree

2 files changed

+65
-56
lines changed

2 files changed

+65
-56
lines changed

oauth2cli.go

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,7 @@ type Config struct {
5252
// OAuth2 config.
5353
// RedirectURL will be automatically set to the local server.
5454
OAuth2Config oauth2.Config
55-
// Hostname of the redirect URL.
56-
// You can set this if your provider does not accept localhost.
57-
// Default to localhost.
58-
RedirectURLHostname string
55+
5956
// Options for an authorization request.
6057
// You can set oauth2.AccessTypeOffline and the PKCE options here.
6158
AuthCodeOptions []oauth2.AuthCodeOption
@@ -66,6 +63,11 @@ type Config struct {
6663
// Default to a string of random 32 bytes.
6764
State string
6865

66+
// Hostname of the redirect URL.
67+
// You can set this if your provider does not accept localhost.
68+
// Default to localhost.
69+
RedirectURLHostname string
70+
6971
// Candidates of hostname and port which the local server binds to.
7072
// You can set port number to 0 to allocate a free port.
7173
// If multiple addresses are given, it will try the ports in order.
@@ -98,37 +100,37 @@ type Config struct {
98100
Logf func(format string, args ...interface{})
99101
}
100102

101-
func (c *Config) isLocalServerHTTPS() bool {
102-
return c.LocalServerCertFile != "" && c.LocalServerKeyFile != ""
103+
func (cfg *Config) isLocalServerHTTPS() bool {
104+
return cfg.LocalServerCertFile != "" && cfg.LocalServerKeyFile != ""
103105
}
104106

105-
func (c *Config) validateAndSetDefaults() error {
106-
if (c.LocalServerCertFile != "" && c.LocalServerKeyFile == "") ||
107-
(c.LocalServerCertFile == "" && c.LocalServerKeyFile != "") {
107+
func (cfg *Config) validateAndSetDefaults() error {
108+
if (cfg.LocalServerCertFile != "" && cfg.LocalServerKeyFile == "") ||
109+
(cfg.LocalServerCertFile == "" && cfg.LocalServerKeyFile != "") {
108110
return fmt.Errorf("both LocalServerCertFile and LocalServerKeyFile must be set")
109111
}
110-
if c.RedirectURLHostname == "" {
111-
c.RedirectURLHostname = "localhost"
112+
if cfg.RedirectURLHostname == "" {
113+
cfg.RedirectURLHostname = "localhost"
112114
}
113-
if c.State == "" {
114-
s, err := oauth2params.NewState()
115+
if cfg.State == "" {
116+
state, err := oauth2params.NewState()
115117
if err != nil {
116118
return fmt.Errorf("could not generate a state parameter: %w", err)
117119
}
118-
c.State = s
120+
cfg.State = state
119121
}
120-
if c.LocalServerMiddleware == nil {
121-
c.LocalServerMiddleware = noopMiddleware
122+
if cfg.LocalServerMiddleware == nil {
123+
cfg.LocalServerMiddleware = noopMiddleware
122124
}
123-
if c.LocalServerSuccessHTML == "" {
124-
c.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML
125+
if cfg.LocalServerSuccessHTML == "" {
126+
cfg.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML
125127
}
126-
if (c.SuccessRedirectURL != "" && c.FailureRedirectURL == "") ||
127-
(c.SuccessRedirectURL == "" && c.FailureRedirectURL != "") {
128+
if (cfg.SuccessRedirectURL != "" && cfg.FailureRedirectURL == "") ||
129+
(cfg.SuccessRedirectURL == "" && cfg.FailureRedirectURL != "") {
128130
return fmt.Errorf("when using success and failure redirect URLs, set both URLs")
129131
}
130-
if c.Logf == nil {
131-
c.Logf = func(string, ...interface{}) {}
132+
if cfg.Logf == nil {
133+
cfg.Logf = func(string, ...interface{}) {}
132134
}
133135
return nil
134136
}
@@ -144,16 +146,16 @@ func (c *Config) validateAndSetDefaults() error {
144146
// 4. Receive a code via an authorization response (HTTP redirect).
145147
// 5. Exchange the code and a token.
146148
// 6. Return the code.
147-
func GetToken(ctx context.Context, c Config) (*oauth2.Token, error) {
148-
if err := c.validateAndSetDefaults(); err != nil {
149+
func GetToken(ctx context.Context, cfg Config) (*oauth2.Token, error) {
150+
if err := cfg.validateAndSetDefaults(); err != nil {
149151
return nil, fmt.Errorf("invalid config: %w", err)
150152
}
151-
code, err := receiveCodeViaLocalServer(ctx, &c)
153+
code, err := receiveCodeViaLocalServer(ctx, &cfg)
152154
if err != nil {
153155
return nil, fmt.Errorf("authorization error: %w", err)
154156
}
155-
c.Logf("oauth2cli: exchanging the code and token")
156-
token, err := c.OAuth2Config.Exchange(ctx, code, c.TokenRequestOptions...)
157+
cfg.Logf("oauth2cli: exchanging the code and token")
158+
token, err := cfg.OAuth2Config.Exchange(ctx, code, cfg.TokenRequestOptions...)
157159
if err != nil {
158160
return nil, fmt.Errorf("could not exchange the code and token: %w", err)
159161
}

server.go

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,28 @@ import (
66
"fmt"
77
"net"
88
"net/http"
9+
"net/url"
910
"sync"
1011
"time"
1112

1213
"github.com/int128/listener"
1314
"golang.org/x/sync/errgroup"
1415
)
1516

16-
func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
17-
l, err := listener.New(c.LocalServerBindAddress)
17+
func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error) {
18+
localServerListener, err := listener.New(cfg.LocalServerBindAddress)
1819
if err != nil {
1920
return "", fmt.Errorf("could not start a local server: %w", err)
2021
}
21-
defer l.Close()
22-
c.OAuth2Config.RedirectURL = computeRedirectURL(l, c)
22+
defer localServerListener.Close()
23+
24+
localServerPort := localServerListener.Addr().(*net.TCPAddr).Port
25+
cfg.OAuth2Config.RedirectURL = constructRedirectURL(cfg, localServerPort)
2326

2427
respCh := make(chan *authorizationResponse)
2528
server := http.Server{
26-
Handler: c.LocalServerMiddleware(&localServerHandler{
27-
config: c,
29+
Handler: cfg.LocalServerMiddleware(&localServerHandler{
30+
config: cfg,
2831
respCh: respCh,
2932
}),
3033
}
@@ -33,15 +36,18 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
3336
var eg errgroup.Group
3437
eg.Go(func() error {
3538
defer close(respCh)
36-
c.Logf("oauth2cli: starting a server at %s", l.Addr())
37-
defer c.Logf("oauth2cli: stopped the server")
38-
if c.isLocalServerHTTPS() {
39-
if err := server.ServeTLS(l, c.LocalServerCertFile, c.LocalServerKeyFile); err != nil && err != http.ErrServerClosed {
39+
cfg.Logf("oauth2cli: starting a server at %s", localServerListener.Addr())
40+
defer cfg.Logf("oauth2cli: stopped the server")
41+
if cfg.isLocalServerHTTPS() {
42+
if err := server.ServeTLS(localServerListener, cfg.LocalServerCertFile, cfg.LocalServerKeyFile); err != nil {
43+
if errors.Is(err, http.ErrServerClosed) {
44+
return nil
45+
}
4046
return fmt.Errorf("could not start HTTPS server: %w", err)
4147
}
4248
return nil
4349
}
44-
if err := server.Serve(l); err != nil && err != http.ErrServerClosed {
50+
if err := server.Serve(localServerListener); err != nil && err != http.ErrServerClosed {
4551
return fmt.Errorf("could not start HTTP server: %w", err)
4652
}
4753
return nil
@@ -63,22 +69,22 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
6369
// Gracefully shutdown the server in the timeout.
6470
// If the server has not started, Shutdown returns nil and this returns immediately.
6571
// If Shutdown has failed, force-close the server.
66-
c.Logf("oauth2cli: shutting down the server")
72+
cfg.Logf("oauth2cli: shutting down the server")
6773
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
6874
defer cancel()
6975
if err := server.Shutdown(ctx); err != nil {
70-
c.Logf("oauth2cli: force-closing the server: shutdown failed: %s", err)
76+
cfg.Logf("oauth2cli: force-closing the server: shutdown failed: %s", err)
7177
_ = server.Close()
7278
return nil
7379
}
7480
return nil
7581
})
7682
eg.Go(func() error {
77-
if c.LocalServerReadyChan == nil {
83+
if cfg.LocalServerReadyChan == nil {
7884
return nil
7985
}
8086
select {
81-
case c.LocalServerReadyChan <- c.OAuth2Config.RedirectURL:
87+
case cfg.LocalServerReadyChan <- cfg.OAuth2Config.RedirectURL:
8288
return nil
8389
case <-ctx.Done():
8490
return ctx.Err()
@@ -93,12 +99,14 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
9399
return resp.code, resp.err
94100
}
95101

96-
func computeRedirectURL(l net.Listener, c *Config) string {
97-
hostPort := fmt.Sprintf("%s:%d", c.RedirectURLHostname, l.Addr().(*net.TCPAddr).Port)
98-
if c.LocalServerCertFile != "" {
99-
return "https://" + hostPort
102+
func constructRedirectURL(cfg *Config, port int) string {
103+
var redirect url.URL
104+
redirect.Host = fmt.Sprintf("%s:%d", cfg.RedirectURLHostname, port)
105+
redirect.Scheme = "http"
106+
if cfg.isLocalServerHTTPS() {
107+
redirect.Scheme = "https"
100108
}
101-
return "http://" + hostPort
109+
return redirect.String()
102110
}
103111

104112
type authorizationResponse struct {
@@ -133,7 +141,7 @@ func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
133141
func (h *localServerHandler) handleIndex(w http.ResponseWriter, r *http.Request) {
134142
authCodeURL := h.config.OAuth2Config.AuthCodeURL(h.config.State, h.config.AuthCodeOptions...)
135143
h.config.Logf("oauth2cli: sending redirect to %s", authCodeURL)
136-
http.Redirect(w, r, authCodeURL, 302)
144+
http.Redirect(w, r, authCodeURL, http.StatusFound)
137145
}
138146

139147
func (h *localServerHandler) handleCodeResponse(w http.ResponseWriter, r *http.Request) *authorizationResponse {
@@ -147,21 +155,20 @@ func (h *localServerHandler) handleCodeResponse(w http.ResponseWriter, r *http.R
147155

148156
if h.config.SuccessRedirectURL != "" {
149157
http.Redirect(w, r, h.config.SuccessRedirectURL, http.StatusFound)
150-
} else {
151-
w.Header().Add("Content-Type", "text/html")
152-
if _, err := fmt.Fprint(w, h.config.LocalServerSuccessHTML); err != nil {
153-
http.Error(w, "server error", 500)
154-
return &authorizationResponse{err: fmt.Errorf("write error: %w", err)}
155-
}
158+
return &authorizationResponse{code: code}
156159
}
157160

161+
w.Header().Add("Content-Type", "text/html")
162+
if _, err := fmt.Fprint(w, h.config.LocalServerSuccessHTML); err != nil {
163+
http.Error(w, "server error", http.StatusInternalServerError)
164+
return &authorizationResponse{err: fmt.Errorf("write error: %w", err)}
165+
}
158166
return &authorizationResponse{code: code}
159167
}
160168

161169
func (h *localServerHandler) handleErrorResponse(w http.ResponseWriter, r *http.Request) *authorizationResponse {
162170
q := r.URL.Query()
163171
errorCode, errorDescription := q.Get("error"), q.Get("error_description")
164-
165172
h.authorizationError(w, r)
166173
return &authorizationResponse{err: fmt.Errorf("authorization error from server: %s %s", errorCode, errorDescription)}
167174
}
@@ -170,6 +177,6 @@ func (h *localServerHandler) authorizationError(w http.ResponseWriter, r *http.R
170177
if h.config.FailureRedirectURL != "" {
171178
http.Redirect(w, r, h.config.FailureRedirectURL, http.StatusFound)
172179
} else {
173-
http.Error(w, "authorization error", 500)
180+
http.Error(w, "authorization error", http.StatusInternalServerError)
174181
}
175182
}

0 commit comments

Comments
 (0)