@@ -2,10 +2,10 @@ package middleware
22
33import (
44 "context"
5+ "github.com/labstack/echo/v4"
56 "net/http"
7+ "sync"
68 "time"
7-
8- "github.com/labstack/echo/v4"
99)
1010
1111// ---------------------------------------------------------------------------------------------------------------
@@ -55,29 +55,27 @@ import (
5555// })
5656//
5757
58- type (
59- // TimeoutConfig defines the config for Timeout middleware.
60- TimeoutConfig struct {
61- // Skipper defines a function to skip middleware.
62- Skipper Skipper
63-
64- // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
65- // It can be used to define a custom timeout error message
66- ErrorMessage string
67-
68- // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
69- // request timeouted and we already had sent the error code (503) and message response to the client.
70- // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
71- // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
72- OnTimeoutRouteErrorHandler func (err error , c echo.Context )
73-
74- // Timeout configures a timeout for the middleware, defaults to 0 for no timeout
75- // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
76- // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
77- // difference over 500microseconds (0.5millisecond) response seems to be reliable
78- Timeout time.Duration
79- }
80- )
58+ // TimeoutConfig defines the config for Timeout middleware.
59+ type TimeoutConfig struct {
60+ // Skipper defines a function to skip middleware.
61+ Skipper Skipper
62+
63+ // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
64+ // It can be used to define a custom timeout error message
65+ ErrorMessage string
66+
67+ // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
68+ // request timeouted and we already had sent the error code (503) and message response to the client.
69+ // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
70+ // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
71+ OnTimeoutRouteErrorHandler func (err error , c echo.Context )
72+
73+ // Timeout configures a timeout for the middleware, defaults to 0 for no timeout
74+ // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
75+ // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
76+ // difference over 500microseconds (0.5millisecond) response seems to be reliable
77+ Timeout time.Duration
78+ }
8179
8280var (
8381 // DefaultTimeoutConfig is the default Timeout middleware config.
@@ -94,10 +92,17 @@ func Timeout() echo.MiddlewareFunc {
9492 return TimeoutWithConfig (DefaultTimeoutConfig )
9593}
9694
97- // TimeoutWithConfig returns a Timeout middleware with config.
98- // See: `Timeout()`.
95+ // TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration.
9996func TimeoutWithConfig (config TimeoutConfig ) echo.MiddlewareFunc {
100- // Defaults
97+ mw , err := config .ToMiddleware ()
98+ if err != nil {
99+ panic (err )
100+ }
101+ return mw
102+ }
103+
104+ // ToMiddleware converts Config to middleware or returns an error for invalid configuration
105+ func (config TimeoutConfig ) ToMiddleware () (echo.MiddlewareFunc , error ) {
101106 if config .Skipper == nil {
102107 config .Skipper = DefaultTimeoutConfig .Skipper
103108 }
@@ -108,26 +113,29 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
108113 return next (c )
109114 }
110115
116+ errChan := make (chan error , 1 )
111117 handlerWrapper := echoHandlerFuncWrapper {
118+ writer : & ignorableWriter {ResponseWriter : c .Response ().Writer },
112119 ctx : c ,
113120 handler : next ,
114- errChan : make ( chan error , 1 ) ,
121+ errChan : errChan ,
115122 errHandler : config .OnTimeoutRouteErrorHandler ,
116123 }
117124 handler := http .TimeoutHandler (handlerWrapper , config .Timeout , config .ErrorMessage )
118- handler .ServeHTTP (c . Response (). Writer , c .Request ())
125+ handler .ServeHTTP (handlerWrapper . writer , c .Request ())
119126
120127 select {
121- case err := <- handlerWrapper . errChan :
128+ case err := <- errChan :
122129 return err
123130 default :
124131 return nil
125132 }
126133 }
127- }
134+ }, nil
128135}
129136
130137type echoHandlerFuncWrapper struct {
138+ writer * ignorableWriter
131139 ctx echo.Context
132140 handler echo.HandlerFunc
133141 errHandler func (err error , c echo.Context )
@@ -160,23 +168,53 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques
160168 }
161169 return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
162170 }
163- // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client
164- // and should not anymore send additional headers/data
165- // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body
166171 if err != nil {
167- // Error must be written into Writer created in `http.TimeoutHandler` so to get Response into `commited` state.
168- // So call global error handler to write error to the client. This is needed or `http.TimeoutHandler` will send
169- // status code by itself and after that our tries to write status code will not work anymore and/or create errors in
170- // log about `superfluous response.WriteHeader call from`
171- t .ctx .Error (err )
172- // we pass error from handler to middlewares up in handler chain to act on it if needed. But this means that
173- // global error handler is probably be called twice as `t.ctx.Error` already does that.
174-
175- // NB: later call of the global error handler or middlewares will not take any effect, as echo.Response will be
176- // already marked as `committed` because we called global error handler above.
177- t .ctx .Response ().Writer = originalWriter // make sure we restore before we signal original coroutine about the error
172+ // This is needed as `http.TimeoutHandler` will write status code by itself on error and after that our tries to write
173+ // status code will not work anymore as Echo.Response thinks it has been already "committed" and further writes
174+ // create errors in log about `superfluous response.WriteHeader call from`
175+ t .writer .Ignore (true )
176+ t .ctx .Response ().Writer = originalWriter // make sure we restore writer before we signal original coroutine about the error
177+ // we pass error from handler to middlewares up in handler chain to act on it if needed.
178178 t .errChan <- err
179179 return
180180 }
181+ // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client
182+ // and should not anymore send additional headers/data
183+ // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body
181184 t .ctx .Response ().Writer = originalWriter
182185}
186+
187+ // ignorableWriter is ResponseWriter implementations that allows us to mark writer to ignore further write calls. This
188+ // is handy in cases when you do not have direct control of code being executed (3rd party middleware) but want to make
189+ // sure that external code will not be able to write response to the client.
190+ // Writer is coroutine safe for writes.
191+ type ignorableWriter struct {
192+ http.ResponseWriter
193+
194+ lock sync.Mutex
195+ ignoreWrites bool
196+ }
197+
198+ func (w * ignorableWriter ) Ignore (ignore bool ) {
199+ w .lock .Lock ()
200+ w .ignoreWrites = ignore
201+ w .lock .Unlock ()
202+ }
203+
204+ func (w * ignorableWriter ) WriteHeader (code int ) {
205+ w .lock .Lock ()
206+ defer w .lock .Unlock ()
207+ if w .ignoreWrites {
208+ return
209+ }
210+ w .ResponseWriter .WriteHeader (code )
211+ }
212+
213+ func (w * ignorableWriter ) Write (b []byte ) (int , error ) {
214+ w .lock .Lock ()
215+ defer w .lock .Unlock ()
216+ if w .ignoreWrites {
217+ return len (b ), nil
218+ }
219+ return w .ResponseWriter .Write (b )
220+ }
0 commit comments