Skip to content

Commit 2c21a2f

Browse files
committed
refactor: address review concerns
1 parent 78891fa commit 2c21a2f

File tree

5 files changed

+83
-22
lines changed

5 files changed

+83
-22
lines changed

internal/proxy/balancer.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ func (rr *LatencyBased) Next() *Server {
108108
rr.mu.Lock()
109109
defer rr.mu.Unlock()
110110

111+
// Return nil if no servers are available
112+
if len(rr.servers) == 0 {
113+
return nil
114+
}
115+
111116
r := rr.randomizer.Float64()
112117
cumulative := 0.0
113118

internal/proxy/proxy.go

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ func (p *Proxy) Start(ctx context.Context, update Updater) {
133133

134134
// newRedirectFollowingReverseProxy creates a configured httputil.ReverseProxy that automatically follows 301 redirects.
135135
// It handles 301 redirects by following them automatically and returning the final response status and content.
136-
func newRedirectFollowingReverseProxy(srv *Server, log *slog.Logger) *httputil.ReverseProxy {
136+
func newRedirectFollowingReverseProxy(srv *Server, log *slog.Logger, proxyType string) *httputil.ReverseProxy {
137137
// Create a custom HTTP client that doesn't follow redirects automatically
138-
client := &http.Client{
138+
redirectClient := &http.Client{
139139
CheckRedirect: func(req *http.Request, via []*http.Request) error {
140140
// Stop automatic redirect following
141141
return http.ErrUseLastResponse
@@ -149,13 +149,13 @@ func newRedirectFollowingReverseProxy(srv *Server, log *slog.Logger) *httputil.R
149149
request.URL.Path = srv.Url.Path + request.URL.Path
150150
request.Host = srv.Url.Host
151151
},
152-
Transport: client.Transport,
152+
Transport: http.DefaultTransport,
153153
ModifyResponse: func(response *http.Response) error {
154154
cors.DeleteCorsHeaders(response)
155155

156-
// Handle 301 redirects by following them and returning content with final status
157-
if response.StatusCode == http.StatusMovedPermanently {
158-
metrics.IncrementRequestStatusCount("rest", srv.Url.String(), http.StatusMovedPermanently)
156+
// Handle redirect responses by following them and returning content with final status
157+
if isRedirect(response.StatusCode) {
158+
metrics.IncrementRequestStatusCount(proxyType, srv.Url.String(), response.StatusCode)
159159

160160
location := response.Header.Get("Location")
161161
if location != "" {
@@ -164,20 +164,27 @@ func newRedirectFollowingReverseProxy(srv *Server, log *slog.Logger) *httputil.R
164164
return fmt.Errorf("failed to parse redirect location %q: %w", location, err)
165165
}
166166

167-
log.Info("following 301 redirect", "location", location, "resolved_url", redirectURL.String())
167+
log.Info("following redirect", "original_status", response.StatusCode, "location", location, "resolved_url", redirectURL.String())
168168

169-
redirectReq, err := http.NewRequest(response.Request.Method, redirectURL.String(), response.Request.Body)
169+
// Only follow redirects for safe/idempotent methods to avoid body consumption issues
170+
if !isIdempotentMethod(response.Request.Method) {
171+
log.Warn("skipping redirect for non-idempotent method", "method", response.Request.Method)
172+
metrics.IncrementRequestStatusCount(proxyType, srv.Url.String(), response.StatusCode)
173+
return nil
174+
}
175+
176+
redirectReq, err := http.NewRequestWithContext(response.Request.Context(), response.Request.Method, redirectURL.String(), nil)
170177
if err != nil {
171178
return fmt.Errorf("failed to create redirect request to %q: %w", redirectURL.String(), err)
172179
}
173180

174-
for name, headers := range response.Request.Header {
175-
for _, h := range headers {
176-
redirectReq.Header.Add(name, h)
177-
}
181+
copyHeaders(response.Request.Header, redirectReq.Header)
182+
183+
if response.Body != nil {
184+
response.Body.Close()
178185
}
179186

180-
redirectResp, err := client.Do(redirectReq)
187+
redirectResp, err := redirectClient.Do(redirectReq)
181188
if err != nil {
182189
return fmt.Errorf("failed to follow redirect to %q: %w", redirectURL.String(), err)
183190
}
@@ -188,13 +195,16 @@ func newRedirectFollowingReverseProxy(srv *Server, log *slog.Logger) *httputil.R
188195
response.ContentLength = redirectResp.ContentLength
189196
response.Header = redirectResp.Header.Clone()
190197

191-
log.Info("successfully handled 301 redirect", "original_status", "301", "final_status", redirectResp.StatusCode)
192-
metrics.IncrementRequestStatusCount("rest", srv.Url.String(), response.StatusCode)
198+
cors.DeleteCorsHeaders(response)
199+
200+
log.Info("successfully handled redirect", "original_status", "redirect", "final_status", redirectResp.StatusCode)
201+
metrics.IncrementRequestStatusCount(proxyType, srv.Url.String(), response.StatusCode)
193202
} else {
194-
log.Warn("301 redirect without Location header, serving as-is")
203+
log.Warn("redirect without Location header, serving as-is", "status", response.StatusCode)
204+
metrics.IncrementRequestStatusCount(proxyType, srv.Url.String(), response.StatusCode)
195205
}
196206
} else {
197-
metrics.IncrementRequestStatusCount("rest", srv.Url.String(), response.StatusCode)
207+
metrics.IncrementRequestStatusCount(proxyType, srv.Url.String(), response.StatusCode)
198208
}
199209

200210
return nil
@@ -205,3 +215,49 @@ func newRedirectFollowingReverseProxy(srv *Server, log *slog.Logger) *httputil.R
205215
},
206216
}
207217
}
218+
219+
// isRedirect checks if the status code represents a redirect
220+
func isRedirect(statusCode int) bool {
221+
switch statusCode {
222+
case http.StatusMovedPermanently,
223+
http.StatusFound,
224+
http.StatusTemporaryRedirect,
225+
http.StatusPermanentRedirect:
226+
return true
227+
default:
228+
return false
229+
}
230+
}
231+
232+
// isIdempotentMethod checks if the HTTP method is safe to replay without side effects
233+
func isIdempotentMethod(method string) bool {
234+
switch method {
235+
case http.MethodGet, http.MethodHead, http.MethodOptions:
236+
return true
237+
default:
238+
return false
239+
}
240+
}
241+
242+
// copyHeaders copies headers from src to dst, excluding hop-by-hop headers
243+
func copyHeaders(src, dst http.Header) {
244+
// Hop-by-hop headers that should not be forwarded
245+
hopByHopHeaders := map[string]bool{
246+
"Connection": true,
247+
"Keep-Alive": true,
248+
"Proxy-Authenticate": true,
249+
"Proxy-Authorization": true,
250+
"Te": true,
251+
"Trailers": true,
252+
"Transfer-Encoding": true,
253+
"Upgrade": true,
254+
}
255+
256+
for name, headers := range src {
257+
if !hopByHopHeaders[name] {
258+
for _, h := range headers {
259+
dst.Add(name, h)
260+
}
261+
}
262+
}
263+
}

internal/proxy/proxy_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ func TestNewReverseProxy(t *testing.T) {
191191
}
192192

193193
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
194-
proxy := newRedirectFollowingReverseProxy(srv, logger)
194+
proxy := newRedirectFollowingReverseProxy(srv, logger, "rest")
195195

196196
// Create test request
197197
req := httptest.NewRequest("GET", tt.reqPath, nil)
@@ -210,7 +210,7 @@ func TestReverseProxy_ModifyResponse(t *testing.T) {
210210
targetURL, _ := url.Parse("http://node.com")
211211
srv := &Server{Url: targetURL}
212212
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
213-
proxy := newRedirectFollowingReverseProxy(srv, logger)
213+
proxy := newRedirectFollowingReverseProxy(srv, logger, "rest")
214214

215215
resp := &http.Response{Header: make(http.Header)}
216216
resp.Header.Set("Access-Control-Allow-Origin", "*")
@@ -229,7 +229,7 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
229229
targetURL, _ := url.Parse("http://node.com")
230230
srv := &Server{Url: targetURL}
231231
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
232-
proxy := newRedirectFollowingReverseProxy(srv, logger)
232+
proxy := newRedirectFollowingReverseProxy(srv, logger, "rest")
233233

234234
w := httptest.NewRecorder()
235235
req := httptest.NewRequest("GET", "/test", nil)

internal/proxy/rest.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (p *RestProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4848

4949
r.URL.Path = strings.TrimPrefix(r.URL.Path, "/rest")
5050
if srv := p.lb.Next(); srv != nil {
51-
proxy := newRedirectFollowingReverseProxy(srv, p.log)
51+
proxy := newRedirectFollowingReverseProxy(srv, p.log, "rest")
5252
proxy.ServeHTTP(w, r)
5353
metrics.IncrementRequestCount("rest", srv.Url.String())
5454
return

internal/proxy/rpc.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (p *RPCProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4848

4949
r.URL.Path = strings.TrimPrefix(r.URL.Path, "/rpc")
5050
if srv := p.lb.Next(); srv != nil {
51-
proxy := newRedirectFollowingReverseProxy(srv, p.log)
51+
proxy := newRedirectFollowingReverseProxy(srv, p.log, "rpc")
5252
proxy.ServeHTTP(w, r)
5353
metrics.IncrementRequestCount("rpc", srv.Url.String())
5454
return

0 commit comments

Comments
 (0)