@@ -29,6 +29,33 @@ type (
2929 // Required.
3030 Balancer ProxyBalancer
3131
32+ // RetryCount defines the number of times a failed proxied request should be retried
33+ // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
34+ RetryCount int
35+
36+ // RetryFilter defines a function used to determine if a failed request to a
37+ // ProxyTarget should be retried. The RetryFilter will only be called when the number
38+ // of previous retries is less than RetryCount. If the function returns true, the
39+ // request will be retried. The provided error indicates the reason for the request
40+ // failure. When the ProxyTarget is unavailable, the error will be an instance of
41+ // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error
42+ // will indicate an internal error in the Proxy middleware. When a RetryFilter is not
43+ // specified, all requests that fail with http.StatusBadGateway will be retried. A custom
44+ // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
45+ // only called when the request to the target fails, or an internal error in the Proxy
46+ // middleware has occurred. Successful requests that return a non-200 response code cannot
47+ // be retried.
48+ RetryFilter func (c echo.Context , e error ) bool
49+
50+ // ErrorHandler defines a function which can be used to return custom errors from
51+ // the Proxy middleware. ErrorHandler is only invoked when there has been
52+ // either an internal error in the Proxy middleware or the ProxyTarget is
53+ // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
54+ // when a ProxyTarget returns a non-200 response. In these cases, the response
55+ // is already written so errors cannot be modified. ErrorHandler is only
56+ // invoked after all retry attempts have been exhausted.
57+ ErrorHandler func (c echo.Context , err error ) error
58+
3259 // Rewrite defines URL path rewrite rules. The values captured in asterisk can be
3360 // retrieved by index e.g. $1, $2 and so on.
3461 // Examples:
7198 Next (echo.Context ) * ProxyTarget
7299 }
73100
74- // TargetProvider defines an interface that gives the opportunity for balancer to return custom errors when selecting target.
101+ // TargetProvider defines an interface that gives the opportunity for balancer
102+ // to return custom errors when selecting target.
75103 TargetProvider interface {
76104 NextTarget (echo.Context ) (* ProxyTarget , error )
77105 }
@@ -107,22 +135,22 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
107135 return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
108136 in , _ , err := c .Response ().Hijack ()
109137 if err != nil {
110- c .Set ("_error" , fmt .Sprintf ("proxy raw, hijack error=%v , url=%s" , t .URL , err ))
138+ c .Set ("_error" , fmt .Errorf ("proxy raw, hijack error=%w , url=%s" , err , t .URL ))
111139 return
112140 }
113141 defer in .Close ()
114142
115143 out , err := net .Dial ("tcp" , t .URL .Host )
116144 if err != nil {
117- c .Set ("_error" , echo .NewHTTPError (http .StatusBadGateway , fmt .Sprintf ("proxy raw, dial error=%v, url=%s" , t .URL , err )))
145+ c .Set ("_error" , echo .NewHTTPError (http .StatusBadGateway , fmt .Sprintf ("proxy raw, dial error=%v, url=%s" , err , t .URL )))
118146 return
119147 }
120148 defer out .Close ()
121149
122150 // Write header
123151 err = r .Write (out )
124152 if err != nil {
125- c .Set ("_error" , echo .NewHTTPError (http .StatusBadGateway , fmt .Sprintf ("proxy raw, request header copy error=%v, url=%s" , t .URL , err )))
153+ c .Set ("_error" , echo .NewHTTPError (http .StatusBadGateway , fmt .Sprintf ("proxy raw, request header copy error=%v, url=%s" , err , t .URL )))
126154 return
127155 }
128156
@@ -136,7 +164,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
136164 go cp (in , out )
137165 err = <- errCh
138166 if err != nil && err != io .EOF {
139- c .Set ("_error" , fmt .Errorf ("proxy raw, copy body error=%v , url=%s" , t .URL , err ))
167+ c .Set ("_error" , fmt .Errorf ("proxy raw, copy body error=%w , url=%s" , err , t .URL ))
140168 }
141169 })
142170}
@@ -200,7 +228,12 @@ func (b *randomBalancer) Next(c echo.Context) *ProxyTarget {
200228 return b .targets [b .random .Intn (len (b .targets ))]
201229}
202230
203- // Next returns an upstream target using round-robin technique.
231+ // Next returns an upstream target using round-robin technique. In the case
232+ // where a previously failed request is being retried, the round-robin
233+ // balancer will attempt to use the next target relative to the original
234+ // request. If the list of targets held by the balancer is modified while a
235+ // failed request is being retried, it is possible that the balancer will
236+ // return the original failed target.
204237//
205238// Note: `nil` is returned in case upstream target list is empty.
206239func (b * roundRobinBalancer ) Next (c echo.Context ) * ProxyTarget {
@@ -211,13 +244,29 @@ func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
211244 } else if len (b .targets ) == 1 {
212245 return b .targets [0 ]
213246 }
214- // reset the index if out of bounds
215- if b .i >= len (b .targets ) {
216- b .i = 0
247+
248+ var i int
249+ const lastIdxKey = "_round_robin_last_index"
250+ // This request is a retry, start from the index of the previous
251+ // target to ensure we don't attempt to retry the request with
252+ // the same failed target
253+ if c .Get (lastIdxKey ) != nil {
254+ i = c .Get (lastIdxKey ).(int )
255+ i ++
256+ if i >= len (b .targets ) {
257+ i = 0
258+ }
259+ } else {
260+ // This is a first time request, use the global index
261+ if b .i >= len (b .targets ) {
262+ b .i = 0
263+ }
264+ i = b .i
265+ b .i ++
217266 }
218- t := b . targets [ b . i ]
219- b . i ++
220- return t
267+
268+ c . Set ( lastIdxKey , i )
269+ return b . targets [ i ]
221270}
222271
223272// Proxy returns a Proxy middleware.
@@ -232,14 +281,26 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
232281// ProxyWithConfig returns a Proxy middleware with config.
233282// See: `Proxy()`
234283func ProxyWithConfig (config ProxyConfig ) echo.MiddlewareFunc {
284+ if config .Balancer == nil {
285+ panic ("echo: proxy middleware requires balancer" )
286+ }
235287 // Defaults
236288 if config .Skipper == nil {
237289 config .Skipper = DefaultProxyConfig .Skipper
238290 }
239- if config .Balancer == nil {
240- panic ("echo: proxy middleware requires balancer" )
291+ if config .RetryFilter == nil {
292+ config .RetryFilter = func (c echo.Context , e error ) bool {
293+ if httpErr , ok := e .(* echo.HTTPError ); ok {
294+ return httpErr .Code == http .StatusBadGateway
295+ }
296+ return false
297+ }
298+ }
299+ if config .ErrorHandler == nil {
300+ config .ErrorHandler = func (c echo.Context , err error ) error {
301+ return err
302+ }
241303 }
242-
243304 if config .Rewrite != nil {
244305 if config .RegexRewrite == nil {
245306 config .RegexRewrite = make (map [* regexp.Regexp ]string )
@@ -250,28 +311,17 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
250311 }
251312
252313 provider , isTargetProvider := config .Balancer .(TargetProvider )
314+
253315 return func (next echo.HandlerFunc ) echo.HandlerFunc {
254- return func (c echo.Context ) ( err error ) {
316+ return func (c echo.Context ) error {
255317 if config .Skipper (c ) {
256318 return next (c )
257319 }
258320
259321 req := c .Request ()
260322 res := c .Response ()
261-
262- var tgt * ProxyTarget
263- if isTargetProvider {
264- tgt , err = provider .NextTarget (c )
265- if err != nil {
266- return err
267- }
268- } else {
269- tgt = config .Balancer .Next (c )
270- }
271- c .Set (config .ContextKey , tgt )
272-
273323 if err := rewriteURL (config .RegexRewrite , req ); err != nil {
274- return err
324+ return config . ErrorHandler ( c , err )
275325 }
276326
277327 // Fix header
@@ -287,19 +337,49 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
287337 req .Header .Set (echo .HeaderXForwardedFor , c .RealIP ())
288338 }
289339
290- // Proxy
291- switch {
292- case c .IsWebSocket ():
293- proxyRaw (tgt , c ).ServeHTTP (res , req )
294- case req .Header .Get (echo .HeaderAccept ) == "text/event-stream" :
295- default :
296- proxyHTTP (tgt , c , config ).ServeHTTP (res , req )
297- }
298- if e , ok := c .Get ("_error" ).(error ); ok {
299- err = e
300- }
340+ retries := config .RetryCount
341+ for {
342+ var tgt * ProxyTarget
343+ var err error
344+ if isTargetProvider {
345+ tgt , err = provider .NextTarget (c )
346+ if err != nil {
347+ return config .ErrorHandler (c , err )
348+ }
349+ } else {
350+ tgt = config .Balancer .Next (c )
351+ }
301352
302- return
353+ c .Set (config .ContextKey , tgt )
354+
355+ //If retrying a failed request, clear any previous errors from
356+ //context here so that balancers have the option to check for
357+ //errors that occurred using previous target
358+ if retries < config .RetryCount {
359+ c .Set ("_error" , nil )
360+ }
361+
362+ // Proxy
363+ switch {
364+ case c .IsWebSocket ():
365+ proxyRaw (tgt , c ).ServeHTTP (res , req )
366+ case req .Header .Get (echo .HeaderAccept ) == "text/event-stream" :
367+ default :
368+ proxyHTTP (tgt , c , config ).ServeHTTP (res , req )
369+ }
370+
371+ err , hasError := c .Get ("_error" ).(error )
372+ if ! hasError {
373+ return nil
374+ }
375+
376+ retry := retries > 0 && config .RetryFilter (c , err )
377+ if ! retry {
378+ return config .ErrorHandler (c , err )
379+ }
380+
381+ retries --
382+ }
303383 }
304384 }
305385}
0 commit comments