11package cli
22
33import (
4+ "encoding/json"
45 "fmt"
56 "io"
67 "os"
8+ "strings"
79
810 v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
911 "github.com/juanfont/headscale/hscontrol/db"
1012 "github.com/juanfont/headscale/hscontrol/policy"
13+ policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
1114 "github.com/juanfont/headscale/hscontrol/types"
1215 "github.com/juanfont/headscale/hscontrol/util"
1316 "github.com/rs/zerolog/log"
@@ -16,7 +19,10 @@ import (
1619)
1720
1821const (
19- bypassFlag = "bypass-grpc-and-access-database-directly"
22+ bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // Not credentials
23+ separatorWidth = 50
24+ outputFormatJSON = "json"
25+ outputFormatJSONLine = "json-line"
2026)
2127
2228func init () {
@@ -37,6 +43,17 @@ func init() {
3743 log .Fatal ().Err (err ).Msg ("" )
3844 }
3945 policyCmd .AddCommand (checkPolicy )
46+
47+ // Test command flags
48+ testPolicy .Flags ().StringP ("src" , "s" , "" , "Source alias to test from (user, group, tag, host, or IP)" )
49+ testPolicy .Flags ().StringSliceP ("accept" , "a" , nil , "Destinations that should be allowed (repeatable, format: host:port)" )
50+ testPolicy .Flags ().StringSliceP ("deny" , "d" , nil , "Destinations that should be denied (repeatable, format: host:port)" )
51+ testPolicy .Flags ().StringP ("proto" , "p" , "" , "Protocol to test (tcp, udp, icmp)" )
52+ testPolicy .Flags ().StringP ("file" , "f" , "" , "Path to a JSON file with test definitions" )
53+ testPolicy .Flags ().StringP ("policy-file" , "" , "" , "Test against a proposed policy file instead of current policy" )
54+ testPolicy .Flags ().BoolP ("embedded" , "e" , false , "Run tests embedded in the current policy" )
55+ testPolicy .Flags ().BoolP (bypassFlag , "" , false , "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running" )
56+ policyCmd .AddCommand (testPolicy )
4057}
4158
4259var policyCmd = & cobra.Command {
@@ -210,3 +227,340 @@ var checkPolicy = &cobra.Command{
210227 SuccessOutput (nil , "Policy is valid" , "" )
211228 },
212229}
230+
231+ var testPolicy = & cobra.Command {
232+ Use : "test" ,
233+ Short : "Test ACL rules" ,
234+ Long : `Test ACL rules to verify access between sources and destinations.
235+
236+ Examples:
237+ # Test if user can access server
238+ headscale policy test --src "alice@example.com" --accept "tag:server:22"
239+
240+ # Test with deny rules
241+ headscale policy test --src "alice@" --accept "10.0.0.1:80" --deny "10.0.0.2:443"
242+
243+ # Run tests from a JSON file
244+ headscale policy test --file tests.json
245+
246+ # Run embedded tests from current policy
247+ headscale policy test --embedded
248+
249+ # Test against a proposed policy file
250+ headscale policy test --src "alice@" --accept "10.0.0.1:22" --policy-file new-policy.json` ,
251+ Run : func (cmd * cobra.Command , args []string ) {
252+ output , _ := cmd .Flags ().GetString ("output" )
253+
254+ // Collect tests from various sources
255+ var tests []policyv2.ACLTest
256+
257+ // Get flags
258+ src , _ := cmd .Flags ().GetString ("src" )
259+ accept , _ := cmd .Flags ().GetStringSlice ("accept" )
260+ deny , _ := cmd .Flags ().GetStringSlice ("deny" )
261+ proto , _ := cmd .Flags ().GetString ("proto" )
262+ testFile , _ := cmd .Flags ().GetString ("file" )
263+ policyFile , _ := cmd .Flags ().GetString ("policy-file" )
264+ embedded , _ := cmd .Flags ().GetBool ("embedded" )
265+ bypass , _ := cmd .Flags ().GetBool (bypassFlag )
266+
267+ // Build test from command line flags if src is provided
268+ if src != "" {
269+ tests = append (tests , policyv2.ACLTest {
270+ Src : src ,
271+ Proto : policyv2 .Protocol (proto ),
272+ Accept : accept ,
273+ Deny : deny ,
274+ })
275+ }
276+
277+ // Load tests from file if provided
278+ if testFile != "" {
279+ fileTests , err := loadTestsFromFile (testFile )
280+ if err != nil {
281+ ErrorOutput (err , fmt .Sprintf ("Error loading tests from file: %s" , err ), output )
282+ return
283+ }
284+ tests = append (tests , fileTests ... )
285+ }
286+
287+ // Read policy file if provided (for testing against proposed policy)
288+ var policyBytes []byte
289+ if policyFile != "" {
290+ f , err := os .Open (policyFile )
291+ if err != nil {
292+ ErrorOutput (err , fmt .Sprintf ("Error opening policy file: %s" , err ), output )
293+ return
294+ }
295+ defer f .Close ()
296+
297+ policyBytes , err = io .ReadAll (f )
298+ if err != nil {
299+ ErrorOutput (err , fmt .Sprintf ("Error reading policy file: %s" , err ), output )
300+ return
301+ }
302+ }
303+
304+ var results policyv2.ACLTestResults
305+
306+ if bypass {
307+ results = runTestsBypass (cmd , output , tests , policyBytes , embedded )
308+ } else {
309+ results = runTestsGRPC (cmd , output , tests , policyBytes , embedded )
310+ }
311+
312+ // Output results
313+ if output == outputFormatJSON || output == outputFormatJSONLine {
314+ SuccessOutput (results , "" , output )
315+ } else {
316+ printHumanReadableResults (results )
317+ }
318+ },
319+ }
320+
321+ func loadTestsFromFile (path string ) ([]policyv2.ACLTest , error ) {
322+ f , err := os .Open (path )
323+ if err != nil {
324+ return nil , err
325+ }
326+ defer f .Close ()
327+
328+ var tests []policyv2.ACLTest
329+
330+ decoder := json .NewDecoder (f )
331+
332+ err = decoder .Decode (& tests )
333+ if err != nil {
334+ return nil , err
335+ }
336+
337+ return tests , nil
338+ }
339+
340+ func runTestsBypass (cmd * cobra.Command , output string , tests []policyv2.ACLTest , policyBytes []byte , embedded bool ) policyv2.ACLTestResults {
341+ confirm := false
342+
343+ force , _ := cmd .Flags ().GetBool ("force" )
344+ if ! force {
345+ confirm = util .YesNo ("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?" )
346+ }
347+
348+ if ! confirm && ! force {
349+ ErrorOutput (nil , "Aborting command" , output )
350+ return policyv2.ACLTestResults {}
351+ }
352+
353+ cfg , err := types .LoadServerConfig ()
354+ if err != nil {
355+ ErrorOutput (err , fmt .Sprintf ("Failed loading config: %s" , err ), output )
356+ return policyv2.ACLTestResults {}
357+ }
358+
359+ d , err := db .NewHeadscaleDatabase (
360+ cfg .Database ,
361+ cfg .BaseDomain ,
362+ nil ,
363+ )
364+ if err != nil {
365+ ErrorOutput (err , fmt .Sprintf ("Failed to open database: %s" , err ), output )
366+ return policyv2.ACLTestResults {}
367+ }
368+
369+ users , err := d .ListUsers ()
370+ if err != nil {
371+ ErrorOutput (err , fmt .Sprintf ("Failed to load users: %s" , err ), output )
372+ return policyv2.ACLTestResults {}
373+ }
374+
375+ nodes , err := d .ListNodes ()
376+ if err != nil {
377+ ErrorOutput (err , fmt .Sprintf ("Failed to load nodes: %s" , err ), output )
378+ return policyv2.ACLTestResults {}
379+ }
380+
381+ // Convert nodes to NodeView slice
382+ nodeViews := make ([]types.NodeView , len (nodes ))
383+ for i , n := range nodes {
384+ nodeViews [i ] = n .View ()
385+ }
386+
387+ // Determine which policy to test against
388+ var polBytes []byte
389+ if len (policyBytes ) > 0 {
390+ polBytes = policyBytes
391+ } else {
392+ pol , err := d .GetPolicy ()
393+ if err != nil {
394+ ErrorOutput (err , fmt .Sprintf ("Failed to load policy: %s" , err ), output )
395+ return policyv2.ACLTestResults {}
396+ }
397+
398+ polBytes = []byte (pol .Data )
399+ }
400+
401+ pm , err := policyv2 .NewPolicyManager (polBytes , users , views .SliceOf (nodeViews ))
402+ if err != nil {
403+ ErrorOutput (err , fmt .Sprintf ("Failed to parse policy: %s" , err ), output )
404+ return policyv2.ACLTestResults {}
405+ }
406+
407+ // If embedded flag is set, get tests from the policy
408+ if embedded {
409+ pol := pm .Policy ()
410+ if pol != nil && len (pol .Tests ) > 0 {
411+ tests = append (tests , pol .Tests ... )
412+ }
413+ }
414+
415+ if len (tests ) == 0 {
416+ ErrorOutput (nil , "No tests to run. Use --src, --file, or --embedded to specify tests." , output )
417+ return policyv2.ACLTestResults {}
418+ }
419+
420+ return pm .RunTests (tests )
421+ }
422+
423+ func runTestsGRPC (_ * cobra.Command , output string , tests []policyv2.ACLTest , policyBytes []byte , embedded bool ) policyv2.ACLTestResults {
424+ ctx , client , conn , cancel := newHeadscaleCLIWithConfig ()
425+ defer cancel ()
426+ defer conn .Close ()
427+
428+ // If embedded, get tests from current policy first
429+ if embedded {
430+ policyResp , err := client .GetPolicy (ctx , & v1.GetPolicyRequest {})
431+ if err != nil {
432+ ErrorOutput (err , fmt .Sprintf ("Failed to get current policy: %s" , err ), output )
433+ return policyv2.ACLTestResults {}
434+ }
435+
436+ // Parse policy to extract embedded tests
437+ pm , err := policyv2 .NewPolicyManager ([]byte (policyResp .GetPolicy ()), nil , views.Slice [types.NodeView ]{})
438+ if err != nil {
439+ ErrorOutput (err , fmt .Sprintf ("Failed to parse policy: %s" , err ), output )
440+ return policyv2.ACLTestResults {}
441+ }
442+
443+ pol := pm .Policy ()
444+ if pol != nil && len (pol .Tests ) > 0 {
445+ tests = append (tests , pol .Tests ... )
446+ }
447+ }
448+
449+ if len (tests ) == 0 {
450+ ErrorOutput (nil , "No tests to run. Use --src, --file, or --embedded to specify tests." , output )
451+ return policyv2.ACLTestResults {}
452+ }
453+
454+ // Convert tests to proto format
455+ protoTests := make ([]* v1.ACLTest , len (tests ))
456+ for i , t := range tests {
457+ protoTests [i ] = & v1.ACLTest {
458+ Src : t .Src ,
459+ Proto : string (t .Proto ),
460+ Accept : t .Accept ,
461+ Deny : t .Deny ,
462+ }
463+ }
464+
465+ request := & v1.TestACLRequest {
466+ Tests : protoTests ,
467+ Policy : string (policyBytes ),
468+ }
469+
470+ response , err := client .TestACL (ctx , request )
471+ if err != nil {
472+ ErrorOutput (err , fmt .Sprintf ("Failed to run ACL tests: %s" , err ), output )
473+ return policyv2.ACLTestResults {}
474+ }
475+
476+ // Convert proto response to internal format
477+ results := policyv2.ACLTestResults {
478+ AllPassed : response .GetAllPassed (),
479+ Results : make ([]policyv2.ACLTestResult , len (response .GetResults ())),
480+ }
481+
482+ for i , r := range response .GetResults () {
483+ results .Results [i ] = policyv2.ACLTestResult {
484+ Src : r .GetSrc (),
485+ Passed : r .GetPassed (),
486+ Errors : r .GetErrors (),
487+ AcceptOK : r .GetAcceptOk (),
488+ AcceptFail : r .GetAcceptFail (),
489+ DenyOK : r .GetDenyOk (),
490+ DenyFail : r .GetDenyFail (),
491+ }
492+ }
493+
494+ return results
495+ }
496+
497+ func printHumanReadableResults (results policyv2.ACLTestResults ) {
498+ fmt .Println ("ACL Test Results" )
499+ fmt .Println (strings .Repeat ("=" , separatorWidth ))
500+ fmt .Println ()
501+
502+ passedCount := 0
503+ totalCount := len (results .Results )
504+
505+ for _ , result := range results .Results {
506+ fmt .Printf ("Source: %s\n " , result .Src )
507+ fmt .Println ()
508+
509+ if len (result .Errors ) > 0 {
510+ fmt .Println (" Errors:" )
511+
512+ for _ , e := range result .Errors {
513+ fmt .Printf (" ! %s\n " , e )
514+ }
515+
516+ fmt .Println ()
517+ }
518+
519+ if len (result .AcceptOK ) > 0 || len (result .AcceptFail ) > 0 {
520+ fmt .Println (" Accept Tests:" )
521+
522+ for _ , dest := range result .AcceptOK {
523+ fmt .Printf (" [PASS] %s - ALLOWED (expected)\n " , dest )
524+ }
525+
526+ for _ , dest := range result .AcceptFail {
527+ fmt .Printf (" [FAIL] %s - DENIED (expected ALLOWED)\n " , dest )
528+ }
529+
530+ fmt .Println ()
531+ }
532+
533+ if len (result .DenyOK ) > 0 || len (result .DenyFail ) > 0 {
534+ fmt .Println (" Deny Tests:" )
535+
536+ for _ , dest := range result .DenyOK {
537+ fmt .Printf (" [PASS] %s - DENIED (expected)\n " , dest )
538+ }
539+
540+ for _ , dest := range result .DenyFail {
541+ fmt .Printf (" [FAIL] %s - ALLOWED (expected DENIED)\n " , dest )
542+ }
543+
544+ fmt .Println ()
545+ }
546+
547+ if result .Passed {
548+ passedCount ++
549+
550+ fmt .Println (" Result: PASSED" )
551+ } else {
552+ fmt .Println (" Result: FAILED" )
553+ }
554+
555+ fmt .Println ()
556+ fmt .Println (strings .Repeat ("-" , separatorWidth ))
557+ fmt .Println ()
558+ }
559+
560+ // Summary
561+ if results .AllPassed {
562+ fmt .Printf ("Overall: PASSED (%d/%d tests passed)\n " , passedCount , totalCount )
563+ } else {
564+ fmt .Printf ("Overall: FAILED (%d/%d tests passed)\n " , passedCount , totalCount )
565+ }
566+ }
0 commit comments