Skip to content

Commit 8361797

Browse files
committed
fix: change reverse proxy implementation to follow redirects
chore: improve log refactor: address review concerns refactor: metric
1 parent 707e5cf commit 8361797

File tree

6 files changed

+125
-13
lines changed

6 files changed

+125
-13
lines changed

internal/proxy/balancer.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ func (rr *LatencyBased) Next() *Server {
108108
rr.mu.Lock()
109109
defer rr.mu.Unlock()
110110

111+
// Return nil if no servers are available
112+
if len(rr.servers) == 0 {
113+
return nil
114+
}
115+
111116
r := rr.randomizer.Float64()
112117
cumulative := 0.0
113118

internal/proxy/grpc.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func (p *GRPCProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
7171
p.log.Info("serving request", "target", srv.Url, "source", r.URL)
7272

7373
proxy.ServeHTTP(w, r)
74-
metrics.IncrementRequestCount("grpc", srv.Url.Host)
74+
metrics.IncrementRequestCount("grpc", srv.Url.String())
7575
return
7676
}
7777

internal/proxy/proxy.go

Lines changed: 113 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,25 +131,132 @@ func (p *Proxy) Start(ctx context.Context, update Updater) {
131131
})
132132
}
133133

134-
// newReverseProxy creates a configured httputil.ReverseProxy with common settings.
135-
func newReverseProxy(srv *Server, log *slog.Logger) *httputil.ReverseProxy {
134+
// newRedirectFollowingReverseProxy creates a configured httputil.ReverseProxy that automatically follows 301 redirects.
135+
// It handles 301 redirects by following them automatically and returning the final response status and content.
136+
func newRedirectFollowingReverseProxy(srv *Server, log *slog.Logger, proxyType string) *httputil.ReverseProxy {
137+
// Create a custom HTTP client that doesn't follow redirects automatically
138+
redirectClient := &http.Client{
139+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
140+
// Stop automatic redirect following
141+
return http.ErrUseLastResponse
142+
},
143+
}
144+
136145
return &httputil.ReverseProxy{
137146
Director: func(request *http.Request) {
138147
request.URL.Scheme = srv.Url.Scheme
139148
request.URL.Host = srv.Url.Host
140149
request.URL.Path = srv.Url.Path + request.URL.Path
141150
request.Host = srv.Url.Host
142-
143-
log.Info("proxying request", "method", request.Method, "target", request.URL, "source", request.URL)
144151
},
152+
Transport: http.DefaultTransport,
145153
ModifyResponse: func(response *http.Response) error {
146154
cors.DeleteCorsHeaders(response)
147-
metrics.IncrementRequestStatusCount("rpc", srv.Url.String(), response.StatusCode)
155+
156+
metrics.IncrementRequestStatusCount(proxyType, srv.Url.String(), response.StatusCode)
157+
158+
// Handle redirect responses by following them and returning content with final status
159+
if isRedirect(response.StatusCode) {
160+
161+
location := response.Header.Get("Location")
162+
if location != "" {
163+
redirectURL, err := response.Request.URL.Parse(location)
164+
if err != nil {
165+
return fmt.Errorf("failed to parse redirect location %q: %w", location, err)
166+
}
167+
168+
log.Info("following redirect", "original_status", response.StatusCode, "location", location, "resolved_url", redirectURL.String())
169+
170+
// Only follow redirects for safe/idempotent methods to avoid body consumption issues
171+
if !isIdempotentMethod(response.Request.Method) {
172+
log.Warn("skipping redirect for non-idempotent method", "method", response.Request.Method)
173+
metrics.IncrementRequestStatusCount(proxyType, srv.Url.String(), response.StatusCode)
174+
return nil
175+
}
176+
177+
redirectReq, err := http.NewRequestWithContext(response.Request.Context(), response.Request.Method, redirectURL.String(), nil)
178+
if err != nil {
179+
return fmt.Errorf("failed to create redirect request to %q: %w", redirectURL.String(), err)
180+
}
181+
182+
copyHeaders(response.Request.Header, redirectReq.Header)
183+
184+
if response.Body != nil {
185+
response.Body.Close()
186+
}
187+
188+
redirectResp, err := redirectClient.Do(redirectReq)
189+
if err != nil {
190+
return fmt.Errorf("failed to follow redirect to %q: %w", redirectURL.String(), err)
191+
}
192+
193+
response.StatusCode = redirectResp.StatusCode
194+
response.Status = redirectResp.Status
195+
response.Body = redirectResp.Body
196+
response.ContentLength = redirectResp.ContentLength
197+
response.Header = redirectResp.Header.Clone()
198+
199+
cors.DeleteCorsHeaders(response)
200+
201+
metrics.IncrementRequestStatusCount(proxyType, srv.Url.String(), response.StatusCode)
202+
} else {
203+
log.Warn("redirect without Location header, serving as-is", "status", response.StatusCode)
204+
metrics.IncrementRequestStatusCount(proxyType, srv.Url.String(), response.StatusCode)
205+
}
206+
}
207+
148208
return nil
149209
},
150210
ErrorHandler: func(writer http.ResponseWriter, request *http.Request, err error) {
151-
log.Error("proxy error", "error", err)
211+
log.Error("reverse proxy error", "error", err)
152212
http.Error(writer, "could not proxy request", http.StatusInternalServerError)
153213
},
154214
}
155215
}
216+
217+
// isRedirect checks if the status code represents a redirect
218+
func isRedirect(statusCode int) bool {
219+
switch statusCode {
220+
case http.StatusMovedPermanently,
221+
http.StatusFound,
222+
http.StatusTemporaryRedirect,
223+
http.StatusPermanentRedirect:
224+
return true
225+
default:
226+
return false
227+
}
228+
}
229+
230+
// isIdempotentMethod checks if the HTTP method is safe to replay without side effects
231+
func isIdempotentMethod(method string) bool {
232+
switch method {
233+
case http.MethodGet, http.MethodHead, http.MethodOptions:
234+
return true
235+
default:
236+
return false
237+
}
238+
}
239+
240+
// copyHeaders copies headers from src to dst, excluding hop-by-hop headers
241+
func copyHeaders(src, dst http.Header) {
242+
// Hop-by-hop headers that should not be forwarded
243+
hopByHopHeaders := map[string]bool{
244+
"Connection": true,
245+
"Keep-Alive": true,
246+
"Proxy-Authenticate": true,
247+
"Proxy-Authorization": true,
248+
"Te": true,
249+
"Trailer": true,
250+
"Transfer-Encoding": true,
251+
"Upgrade": true,
252+
"Proxy-Connection": true,
253+
}
254+
255+
for name, headers := range src {
256+
if !hopByHopHeaders[name] {
257+
for _, h := range headers {
258+
dst.Add(name, h)
259+
}
260+
}
261+
}
262+
}

internal/proxy/proxy_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ func TestNewReverseProxy(t *testing.T) {
191191
}
192192

193193
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
194-
proxy := newReverseProxy(srv, logger)
194+
proxy := newRedirectFollowingReverseProxy(srv, logger, "rest")
195195

196196
// Create test request
197197
req := httptest.NewRequest("GET", tt.reqPath, nil)
@@ -210,7 +210,7 @@ func TestReverseProxy_ModifyResponse(t *testing.T) {
210210
targetURL, _ := url.Parse("http://node.com")
211211
srv := &Server{Url: targetURL}
212212
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
213-
proxy := newReverseProxy(srv, logger)
213+
proxy := newRedirectFollowingReverseProxy(srv, logger, "rest")
214214

215215
resp := &http.Response{Header: make(http.Header)}
216216
resp.Header.Set("Access-Control-Allow-Origin", "*")
@@ -229,7 +229,7 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
229229
targetURL, _ := url.Parse("http://node.com")
230230
srv := &Server{Url: targetURL}
231231
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
232-
proxy := newReverseProxy(srv, logger)
232+
proxy := newRedirectFollowingReverseProxy(srv, logger, "rest")
233233

234234
w := httptest.NewRecorder()
235235
req := httptest.NewRequest("GET", "/test", nil)

internal/proxy/rest.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ func (p *RestProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4848

4949
r.URL.Path = strings.TrimPrefix(r.URL.Path, "/rest")
5050
if srv := p.lb.Next(); srv != nil {
51-
proxy := newReverseProxy(srv, p.log)
51+
proxy := newRedirectFollowingReverseProxy(srv, p.log, "rest")
5252
proxy.ServeHTTP(w, r)
53-
metrics.IncrementRequestCount("rest", srv.Url.Host)
53+
metrics.IncrementRequestCount("rest", srv.Url.String())
5454
return
5555
}
5656

internal/proxy/rpc.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (p *RPCProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4848

4949
r.URL.Path = strings.TrimPrefix(r.URL.Path, "/rpc")
5050
if srv := p.lb.Next(); srv != nil {
51-
proxy := newReverseProxy(srv, p.log)
51+
proxy := newRedirectFollowingReverseProxy(srv, p.log, "rpc")
5252
proxy.ServeHTTP(w, r)
5353
metrics.IncrementRequestCount("rpc", srv.Url.String())
5454
return

0 commit comments

Comments
 (0)