@@ -2,9 +2,7 @@ package middleware
22
33import (
44 "crypto/subtle"
5- "errors"
65 "net/http"
7- "strings"
86 "time"
97
108 "github.com/labstack/echo/v4"
@@ -21,13 +19,15 @@ type (
2119 TokenLength uint8 `yaml:"token_length"`
2220 // Optional. Default value 32.
2321
24- // TokenLookup is a string in the form of "<source>:<key >" that is used
22+ // TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name >" that is used
2523 // to extract token from the request.
2624 // Optional. Default value "header:X-CSRF-Token".
2725 // Possible values:
28- // - "header:<name>"
29- // - "form:<name>"
26+ // - "header:<name>" or "header:<name>:<cut-prefix>"
3027 // - "query:<name>"
28+ // - "form:<name>"
29+ // Multiple sources example:
30+ // - "header:X-CSRF-Token,query:csrf"
3131 TokenLookup string `yaml:"token_lookup"`
3232
3333 // Context key to store generated CSRF token into context.
@@ -62,12 +62,11 @@ type (
6262 // Optional. Default value SameSiteDefaultMode.
6363 CookieSameSite http.SameSite `yaml:"cookie_same_site"`
6464 }
65-
66- // csrfTokenExtractor defines a function that takes `echo.Context` and returns
67- // either a token or an error.
68- csrfTokenExtractor func (echo.Context ) (string , error )
6965)
7066
67+ // ErrCSRFInvalid is returned when CSRF check fails
68+ var ErrCSRFInvalid = echo .NewHTTPError (http .StatusForbidden , "invalid csrf token" )
69+
7170var (
7271 // DefaultCSRFConfig is the default CSRF middleware config.
7372 DefaultCSRFConfig = CSRFConfig {
@@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
114113 config .CookieSecure = true
115114 }
116115
117- // Initialize
118- parts := strings .Split (config .TokenLookup , ":" )
119- extractor := csrfTokenFromHeader (parts [1 ])
120- switch parts [0 ] {
121- case "form" :
122- extractor = csrfTokenFromForm (parts [1 ])
123- case "query" :
124- extractor = csrfTokenFromQuery (parts [1 ])
116+ extractors , err := createExtractors (config .TokenLookup , "" )
117+ if err != nil {
118+ panic (err )
125119 }
126120
127121 return func (next echo.HandlerFunc ) echo.HandlerFunc {
@@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
130124 return next (c )
131125 }
132126
133- req := c .Request ()
134- k , err := c .Cookie (config .CookieName )
135127 token := ""
136-
137- // Generate token
138- if err != nil {
139- token = random .String (config .TokenLength )
128+ if k , err := c .Cookie (config .CookieName ); err != nil {
129+ token = random .String (config .TokenLength ) // Generate token
140130 } else {
141- // Reuse token
142- token = k .Value
131+ token = k .Value // Reuse token
143132 }
144133
145- switch req .Method {
134+ switch c . Request () .Method {
146135 case http .MethodGet , http .MethodHead , http .MethodOptions , http .MethodTrace :
147136 default :
148137 // Validate token only for requests which are not defined as 'safe' by RFC7231
149- clientToken , err := extractor (c )
150- if err != nil {
151- return echo .NewHTTPError (http .StatusBadRequest , err .Error ())
138+ var lastExtractorErr error
139+ var lastTokenErr error
140+ outer:
141+ for _ , extractor := range extractors {
142+ clientTokens , err := extractor (c )
143+ if err != nil {
144+ lastExtractorErr = err
145+ continue
146+ }
147+
148+ for _ , clientToken := range clientTokens {
149+ if validateCSRFToken (token , clientToken ) {
150+ lastTokenErr = nil
151+ lastExtractorErr = nil
152+ break outer
153+ }
154+ lastTokenErr = ErrCSRFInvalid
155+ }
152156 }
153- if ! validateCSRFToken (token , clientToken ) {
154- return echo .NewHTTPError (http .StatusForbidden , "invalid csrf token" )
157+ if lastTokenErr != nil {
158+ return lastTokenErr
159+ } else if lastExtractorErr != nil {
160+ // ugly part to preserve backwards compatible errors. someone could rely on them
161+ if lastExtractorErr == errQueryExtractorValueMissing {
162+ lastExtractorErr = echo .NewHTTPError (http .StatusBadRequest , "missing csrf token in the query string" )
163+ } else if lastExtractorErr == errFormExtractorValueMissing {
164+ lastExtractorErr = echo .NewHTTPError (http .StatusBadRequest , "missing csrf token in the form parameter" )
165+ } else if lastExtractorErr == errHeaderExtractorValueMissing {
166+ lastExtractorErr = echo .NewHTTPError (http .StatusBadRequest , "missing csrf token in request header" )
167+ } else {
168+ lastExtractorErr = echo .NewHTTPError (http .StatusBadRequest , lastExtractorErr .Error ())
169+ }
170+ return lastExtractorErr
155171 }
156172 }
157173
@@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
184200 }
185201}
186202
187- // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
188- // provided request header.
189- func csrfTokenFromHeader (header string ) csrfTokenExtractor {
190- return func (c echo.Context ) (string , error ) {
191- return c .Request ().Header .Get (header ), nil
192- }
193- }
194-
195- // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
196- // provided form parameter.
197- func csrfTokenFromForm (param string ) csrfTokenExtractor {
198- return func (c echo.Context ) (string , error ) {
199- token := c .FormValue (param )
200- if token == "" {
201- return "" , errors .New ("missing csrf token in the form parameter" )
202- }
203- return token , nil
204- }
205- }
206-
207- // csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
208- // provided query parameter.
209- func csrfTokenFromQuery (param string ) csrfTokenExtractor {
210- return func (c echo.Context ) (string , error ) {
211- token := c .QueryParam (param )
212- if token == "" {
213- return "" , errors .New ("missing csrf token in the query string" )
214- }
215- return token , nil
216- }
217- }
218-
219203func validateCSRFToken (token , clientToken string ) bool {
220204 return subtle .ConstantTimeCompare ([]byte (token ), []byte (clientToken )) == 1
221205}
0 commit comments