@@ -6,22 +6,21 @@ package web
66
77import (
88 "context"
9- "crypto/rand"
109 "encoding/json"
1110 "errors"
1211 "fmt"
1312 "io"
1413 "log"
1514 "net/http"
1615 "net/netip"
16+ "net/url"
1717 "os"
1818 "path"
19- "path/filepath "
19+ "slices "
2020 "strings"
2121 "sync"
2222 "time"
2323
24- "github.com/gorilla/csrf"
2524 "tailscale.com/client/local"
2625 "tailscale.com/client/tailscale/apitype"
2726 "tailscale.com/clientupdate"
@@ -60,6 +59,12 @@ type Server struct {
6059 cgiMode bool
6160 pathPrefix string
6261
62+ // originOverride is the origin that the web UI is accessible from.
63+ // This value is used in the fallback CSRF checks when Sec-Fetch-Site is not
64+ // available. In this case the application will compare Host and Origin
65+ // header values to determine if the request is from the same origin.
66+ originOverride string
67+
6368 apiHandler http.Handler // serves api endpoints; csrf-protected
6469 assetsHandler http.Handler // serves frontend assets
6570 assetsCleanup func () // called from Server.Shutdown
@@ -150,6 +155,9 @@ type ServerOpts struct {
150155 // as completed.
151156 // This field is required for ManageServerMode mode.
152157 WaitAuthURL func (ctx context.Context , id string , src tailcfg.NodeID ) (* tailcfg.WebClientAuthResponse , error )
158+
159+ // OriginOverride specifies the origin that the web UI will be accessible from if hosted behind a reverse proxy or CGI.
160+ OriginOverride string
153161}
154162
155163// NewServer constructs a new Tailscale web client server.
@@ -169,15 +177,16 @@ func NewServer(opts ServerOpts) (s *Server, err error) {
169177 opts .LocalClient = & local.Client {}
170178 }
171179 s = & Server {
172- mode : opts .Mode ,
173- logf : opts .Logf ,
174- devMode : envknob .Bool ("TS_DEBUG_WEB_CLIENT_DEV" ),
175- lc : opts .LocalClient ,
176- cgiMode : opts .CGIMode ,
177- pathPrefix : opts .PathPrefix ,
178- timeNow : opts .TimeNow ,
179- newAuthURL : opts .NewAuthURL ,
180- waitAuthURL : opts .WaitAuthURL ,
180+ mode : opts .Mode ,
181+ logf : opts .Logf ,
182+ devMode : envknob .Bool ("TS_DEBUG_WEB_CLIENT_DEV" ),
183+ lc : opts .LocalClient ,
184+ cgiMode : opts .CGIMode ,
185+ pathPrefix : opts .PathPrefix ,
186+ timeNow : opts .TimeNow ,
187+ newAuthURL : opts .NewAuthURL ,
188+ waitAuthURL : opts .WaitAuthURL ,
189+ originOverride : opts .OriginOverride ,
181190 }
182191 if opts .PathPrefix != "" {
183192 // Enforce that path prefix always has a single leading '/'
@@ -205,7 +214,7 @@ func NewServer(opts ServerOpts) (s *Server, err error) {
205214
206215 var metric string
207216 s .apiHandler , metric = s .modeAPIHandler (s .mode )
208- s .apiHandler = s .withCSRF (s .apiHandler )
217+ s .apiHandler = s .csrfProtect (s .apiHandler )
209218
210219 // Don't block startup on reporting metric.
211220 // Report in separate go routine with 5 second timeout.
@@ -218,23 +227,64 @@ func NewServer(opts ServerOpts) (s *Server, err error) {
218227 return s , nil
219228}
220229
221- func (s * Server ) withCSRF (h http.Handler ) http.Handler {
222- csrfProtect := csrf .Protect (s .csrfKey (), csrf .Secure (false ))
230+ func (s * Server ) csrfProtect (h http.Handler ) http.Handler {
231+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
232+ // CSRF is not required for GET, HEAD, or OPTIONS requests.
233+ if slices .Contains ([]string {"GET" , "HEAD" , "OPTIONS" }, r .Method ) {
234+ h .ServeHTTP (w , r )
235+ return
236+ }
223237
224- // ref https://github.com/tailscale/tailscale/pull/14822
225- // signal to the CSRF middleware that the request is being served over
226- // plaintext HTTP to skip TLS-only header checks.
227- withSetPlaintext := func (h http.Handler ) http.Handler {
228- return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
229- r = csrf .PlaintextHTTPRequest (r )
238+ // first attempt to use Sec-Fetch-Site header (sent by all modern
239+ // browsers to "potentially trustworthy" origins i.e. localhost or those
240+ // served over HTTPS)
241+ secFetchSite := r .Header .Get ("Sec-Fetch-Site" )
242+ if secFetchSite == "same-origin" {
230243 h .ServeHTTP (w , r )
231- })
232- }
244+ return
245+ } else if secFetchSite != "" {
246+ http .Error (w , fmt .Sprintf ("CSRF request denied with Sec-Fetch-Site %q" , secFetchSite ), http .StatusForbidden )
247+ return
248+ }
249+
250+ // if Sec-Fetch-Site is not available we presume we are operating over HTTP.
251+ // We fall back to comparing the Origin & Host headers.
252+
253+ // use the Host header to determine the expected origin
254+ // (use the override if set to allow for reverse proxying)
255+ host := r .Host
256+ if host == "" {
257+ http .Error (w , "CSRF request denied with no Host header" , http .StatusForbidden )
258+ return
259+ }
260+ if s .originOverride != "" {
261+ host = s .originOverride
262+ }
263+
264+ originHeader := r .Header .Get ("Origin" )
265+ if originHeader == "" {
266+ http .Error (w , "CSRF request denied with no Origin header" , http .StatusForbidden )
267+ return
268+ }
269+ parsedOrigin , err := url .Parse (originHeader )
270+ if err != nil {
271+ http .Error (w , fmt .Sprintf ("CSRF request denied with invalid Origin %q" , r .Header .Get ("Origin" )), http .StatusForbidden )
272+ return
273+ }
274+ origin := parsedOrigin .Host
275+ if origin == "" {
276+ http .Error (w , "CSRF request denied with no host in the Origin header" , http .StatusForbidden )
277+ return
278+ }
279+
280+ if origin != host {
281+ http .Error (w , fmt .Sprintf ("CSRF request denied with mismatched Origin %q and Host %q" , origin , host ), http .StatusForbidden )
282+ return
283+ }
284+
285+ h .ServeHTTP (w , r )
233286
234- // NB: the order of the withSetPlaintext and csrfProtect calls is important
235- // to ensure that we signal to the CSRF middleware that the request is being
236- // served over plaintext HTTP and not over TLS as it presumes by default.
237- return withSetPlaintext (csrfProtect (h ))
287+ })
238288}
239289
240290func (s * Server ) modeAPIHandler (mode ServerMode ) (http.Handler , string ) {
@@ -452,7 +502,6 @@ func (s *Server) authorizeRequest(w http.ResponseWriter, r *http.Request) (ok bo
452502// It should only be called by Server.ServeHTTP, via Server.apiHandler,
453503// which protects the handler using gorilla csrf.
454504func (s * Server ) serveLoginAPI (w http.ResponseWriter , r * http.Request ) {
455- w .Header ().Set ("X-CSRF-Token" , csrf .Token (r ))
456505 switch {
457506 case r .URL .Path == "/api/data" && r .Method == httpm .GET :
458507 s .serveGetNodeData (w , r )
@@ -575,7 +624,6 @@ func (s *Server) serveAPI(w http.ResponseWriter, r *http.Request) {
575624 }
576625 }
577626
578- w .Header ().Set ("X-CSRF-Token" , csrf .Token (r ))
579627 path := strings .TrimPrefix (r .URL .Path , "/api" )
580628 switch {
581629 case path == "/data" && r .Method == httpm .GET :
@@ -1276,37 +1324,6 @@ func (s *Server) proxyRequestToLocalAPI(w http.ResponseWriter, r *http.Request)
12761324 }
12771325}
12781326
1279- // csrfKey returns a key that can be used for CSRF protection.
1280- // If an error occurs during key creation, the error is logged and the active process terminated.
1281- // If the server is running in CGI mode, the key is cached to disk and reused between requests.
1282- // If an error occurs during key storage, the error is logged and the active process terminated.
1283- func (s * Server ) csrfKey () []byte {
1284- csrfFile := filepath .Join (os .TempDir (), "tailscale-web-csrf.key" )
1285-
1286- // if running in CGI mode, try to read from disk, but ignore errors
1287- if s .cgiMode {
1288- key , _ := os .ReadFile (csrfFile )
1289- if len (key ) == 32 {
1290- return key
1291- }
1292- }
1293-
1294- // create a new key
1295- key := make ([]byte , 32 )
1296- if _ , err := rand .Read (key ); err != nil {
1297- log .Fatalf ("error generating CSRF key: %v" , err )
1298- }
1299-
1300- // if running in CGI mode, try to write the newly created key to disk, and exit if it fails.
1301- if s .cgiMode {
1302- if err := os .WriteFile (csrfFile , key , 0600 ); err != nil {
1303- log .Fatalf ("unable to store CSRF key: %v" , err )
1304- }
1305- }
1306-
1307- return key
1308- }
1309-
13101327// enforcePrefix returns a HandlerFunc that enforces a given path prefix is used in requests,
13111328// then strips it before invoking h.
13121329// Unlike http.StripPrefix, it does not return a 404 if the prefix is not present.
0 commit comments