Skip to content

Commit 5c87062

Browse files
committed
ACL testing (#1803)
1 parent 72fcb93 commit 5c87062

File tree

16 files changed

+2311
-66
lines changed

16 files changed

+2311
-66
lines changed

cmd/headscale/cli/policy.go

Lines changed: 355 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package cli
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"io"
67
"os"
8+
"strings"
79

810
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
911
"github.com/juanfont/headscale/hscontrol/db"
1012
"github.com/juanfont/headscale/hscontrol/policy"
13+
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
1114
"github.com/juanfont/headscale/hscontrol/types"
1215
"github.com/juanfont/headscale/hscontrol/util"
1316
"github.com/rs/zerolog/log"
@@ -16,7 +19,10 @@ import (
1619
)
1720

1821
const (
19-
bypassFlag = "bypass-grpc-and-access-database-directly"
22+
bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // Not credentials
23+
separatorWidth = 50
24+
outputFormatJSON = "json"
25+
outputFormatJSONLine = "json-line"
2026
)
2127

2228
func init() {
@@ -37,6 +43,17 @@ func init() {
3743
log.Fatal().Err(err).Msg("")
3844
}
3945
policyCmd.AddCommand(checkPolicy)
46+
47+
// Test command flags
48+
testPolicy.Flags().StringP("src", "s", "", "Source alias to test from (user, group, tag, host, or IP)")
49+
testPolicy.Flags().StringSliceP("accept", "a", nil, "Destinations that should be allowed (repeatable, format: host:port)")
50+
testPolicy.Flags().StringSliceP("deny", "d", nil, "Destinations that should be denied (repeatable, format: host:port)")
51+
testPolicy.Flags().StringP("proto", "p", "", "Protocol to test (tcp, udp, icmp)")
52+
testPolicy.Flags().StringP("file", "f", "", "Path to a JSON file with test definitions")
53+
testPolicy.Flags().StringP("policy-file", "", "", "Test against a proposed policy file instead of current policy")
54+
testPolicy.Flags().BoolP("embedded", "e", false, "Run tests embedded in the current policy")
55+
testPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
56+
policyCmd.AddCommand(testPolicy)
4057
}
4158

4259
var policyCmd = &cobra.Command{
@@ -210,3 +227,340 @@ var checkPolicy = &cobra.Command{
210227
SuccessOutput(nil, "Policy is valid", "")
211228
},
212229
}
230+
231+
var testPolicy = &cobra.Command{
232+
Use: "test",
233+
Short: "Test ACL rules",
234+
Long: `Test ACL rules to verify access between sources and destinations.
235+
236+
Examples:
237+
# Test if user can access server
238+
headscale policy test --src "alice@example.com" --accept "tag:server:22"
239+
240+
# Test with deny rules
241+
headscale policy test --src "alice@" --accept "10.0.0.1:80" --deny "10.0.0.2:443"
242+
243+
# Run tests from a JSON file
244+
headscale policy test --file tests.json
245+
246+
# Run embedded tests from current policy
247+
headscale policy test --embedded
248+
249+
# Test against a proposed policy file
250+
headscale policy test --src "alice@" --accept "10.0.0.1:22" --policy-file new-policy.json`,
251+
Run: func(cmd *cobra.Command, args []string) {
252+
output, _ := cmd.Flags().GetString("output")
253+
254+
// Collect tests from various sources
255+
var tests []policyv2.ACLTest
256+
257+
// Get flags
258+
src, _ := cmd.Flags().GetString("src")
259+
accept, _ := cmd.Flags().GetStringSlice("accept")
260+
deny, _ := cmd.Flags().GetStringSlice("deny")
261+
proto, _ := cmd.Flags().GetString("proto")
262+
testFile, _ := cmd.Flags().GetString("file")
263+
policyFile, _ := cmd.Flags().GetString("policy-file")
264+
embedded, _ := cmd.Flags().GetBool("embedded")
265+
bypass, _ := cmd.Flags().GetBool(bypassFlag)
266+
267+
// Build test from command line flags if src is provided
268+
if src != "" {
269+
tests = append(tests, policyv2.ACLTest{
270+
Src: src,
271+
Proto: policyv2.Protocol(proto),
272+
Accept: accept,
273+
Deny: deny,
274+
})
275+
}
276+
277+
// Load tests from file if provided
278+
if testFile != "" {
279+
fileTests, err := loadTestsFromFile(testFile)
280+
if err != nil {
281+
ErrorOutput(err, fmt.Sprintf("Error loading tests from file: %s", err), output)
282+
return
283+
}
284+
tests = append(tests, fileTests...)
285+
}
286+
287+
// Read policy file if provided (for testing against proposed policy)
288+
var policyBytes []byte
289+
if policyFile != "" {
290+
f, err := os.Open(policyFile)
291+
if err != nil {
292+
ErrorOutput(err, fmt.Sprintf("Error opening policy file: %s", err), output)
293+
return
294+
}
295+
defer f.Close()
296+
297+
policyBytes, err = io.ReadAll(f)
298+
if err != nil {
299+
ErrorOutput(err, fmt.Sprintf("Error reading policy file: %s", err), output)
300+
return
301+
}
302+
}
303+
304+
var results policyv2.ACLTestResults
305+
306+
if bypass {
307+
results = runTestsBypass(cmd, output, tests, policyBytes, embedded)
308+
} else {
309+
results = runTestsGRPC(cmd, output, tests, policyBytes, embedded)
310+
}
311+
312+
// Output results
313+
if output == outputFormatJSON || output == outputFormatJSONLine {
314+
SuccessOutput(results, "", output)
315+
} else {
316+
printHumanReadableResults(results)
317+
}
318+
},
319+
}
320+
321+
func loadTestsFromFile(path string) ([]policyv2.ACLTest, error) {
322+
f, err := os.Open(path)
323+
if err != nil {
324+
return nil, err
325+
}
326+
defer f.Close()
327+
328+
var tests []policyv2.ACLTest
329+
330+
decoder := json.NewDecoder(f)
331+
332+
err = decoder.Decode(&tests)
333+
if err != nil {
334+
return nil, err
335+
}
336+
337+
return tests, nil
338+
}
339+
340+
func runTestsBypass(cmd *cobra.Command, output string, tests []policyv2.ACLTest, policyBytes []byte, embedded bool) policyv2.ACLTestResults {
341+
confirm := false
342+
343+
force, _ := cmd.Flags().GetBool("force")
344+
if !force {
345+
confirm = util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?")
346+
}
347+
348+
if !confirm && !force {
349+
ErrorOutput(nil, "Aborting command", output)
350+
return policyv2.ACLTestResults{}
351+
}
352+
353+
cfg, err := types.LoadServerConfig()
354+
if err != nil {
355+
ErrorOutput(err, fmt.Sprintf("Failed loading config: %s", err), output)
356+
return policyv2.ACLTestResults{}
357+
}
358+
359+
d, err := db.NewHeadscaleDatabase(
360+
cfg.Database,
361+
cfg.BaseDomain,
362+
nil,
363+
)
364+
if err != nil {
365+
ErrorOutput(err, fmt.Sprintf("Failed to open database: %s", err), output)
366+
return policyv2.ACLTestResults{}
367+
}
368+
369+
users, err := d.ListUsers()
370+
if err != nil {
371+
ErrorOutput(err, fmt.Sprintf("Failed to load users: %s", err), output)
372+
return policyv2.ACLTestResults{}
373+
}
374+
375+
nodes, err := d.ListNodes()
376+
if err != nil {
377+
ErrorOutput(err, fmt.Sprintf("Failed to load nodes: %s", err), output)
378+
return policyv2.ACLTestResults{}
379+
}
380+
381+
// Convert nodes to NodeView slice
382+
nodeViews := make([]types.NodeView, len(nodes))
383+
for i, n := range nodes {
384+
nodeViews[i] = n.View()
385+
}
386+
387+
// Determine which policy to test against
388+
var polBytes []byte
389+
if len(policyBytes) > 0 {
390+
polBytes = policyBytes
391+
} else {
392+
pol, err := d.GetPolicy()
393+
if err != nil {
394+
ErrorOutput(err, fmt.Sprintf("Failed to load policy: %s", err), output)
395+
return policyv2.ACLTestResults{}
396+
}
397+
398+
polBytes = []byte(pol.Data)
399+
}
400+
401+
pm, err := policyv2.NewPolicyManager(polBytes, users, views.SliceOf(nodeViews))
402+
if err != nil {
403+
ErrorOutput(err, fmt.Sprintf("Failed to parse policy: %s", err), output)
404+
return policyv2.ACLTestResults{}
405+
}
406+
407+
// If embedded flag is set, get tests from the policy
408+
if embedded {
409+
pol := pm.Policy()
410+
if pol != nil && len(pol.Tests) > 0 {
411+
tests = append(tests, pol.Tests...)
412+
}
413+
}
414+
415+
if len(tests) == 0 {
416+
ErrorOutput(nil, "No tests to run. Use --src, --file, or --embedded to specify tests.", output)
417+
return policyv2.ACLTestResults{}
418+
}
419+
420+
return pm.RunTests(tests)
421+
}
422+
423+
func runTestsGRPC(_ *cobra.Command, output string, tests []policyv2.ACLTest, policyBytes []byte, embedded bool) policyv2.ACLTestResults {
424+
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
425+
defer cancel()
426+
defer conn.Close()
427+
428+
// If embedded, get tests from current policy first
429+
if embedded {
430+
policyResp, err := client.GetPolicy(ctx, &v1.GetPolicyRequest{})
431+
if err != nil {
432+
ErrorOutput(err, fmt.Sprintf("Failed to get current policy: %s", err), output)
433+
return policyv2.ACLTestResults{}
434+
}
435+
436+
// Parse policy to extract embedded tests
437+
pm, err := policyv2.NewPolicyManager([]byte(policyResp.GetPolicy()), nil, views.Slice[types.NodeView]{})
438+
if err != nil {
439+
ErrorOutput(err, fmt.Sprintf("Failed to parse policy: %s", err), output)
440+
return policyv2.ACLTestResults{}
441+
}
442+
443+
pol := pm.Policy()
444+
if pol != nil && len(pol.Tests) > 0 {
445+
tests = append(tests, pol.Tests...)
446+
}
447+
}
448+
449+
if len(tests) == 0 {
450+
ErrorOutput(nil, "No tests to run. Use --src, --file, or --embedded to specify tests.", output)
451+
return policyv2.ACLTestResults{}
452+
}
453+
454+
// Convert tests to proto format
455+
protoTests := make([]*v1.ACLTest, len(tests))
456+
for i, t := range tests {
457+
protoTests[i] = &v1.ACLTest{
458+
Src: t.Src,
459+
Proto: string(t.Proto),
460+
Accept: t.Accept,
461+
Deny: t.Deny,
462+
}
463+
}
464+
465+
request := &v1.TestACLRequest{
466+
Tests: protoTests,
467+
Policy: string(policyBytes),
468+
}
469+
470+
response, err := client.TestACL(ctx, request)
471+
if err != nil {
472+
ErrorOutput(err, fmt.Sprintf("Failed to run ACL tests: %s", err), output)
473+
return policyv2.ACLTestResults{}
474+
}
475+
476+
// Convert proto response to internal format
477+
results := policyv2.ACLTestResults{
478+
AllPassed: response.GetAllPassed(),
479+
Results: make([]policyv2.ACLTestResult, len(response.GetResults())),
480+
}
481+
482+
for i, r := range response.GetResults() {
483+
results.Results[i] = policyv2.ACLTestResult{
484+
Src: r.GetSrc(),
485+
Passed: r.GetPassed(),
486+
Errors: r.GetErrors(),
487+
AcceptOK: r.GetAcceptOk(),
488+
AcceptFail: r.GetAcceptFail(),
489+
DenyOK: r.GetDenyOk(),
490+
DenyFail: r.GetDenyFail(),
491+
}
492+
}
493+
494+
return results
495+
}
496+
497+
func printHumanReadableResults(results policyv2.ACLTestResults) {
498+
fmt.Println("ACL Test Results")
499+
fmt.Println(strings.Repeat("=", separatorWidth))
500+
fmt.Println()
501+
502+
passedCount := 0
503+
totalCount := len(results.Results)
504+
505+
for _, result := range results.Results {
506+
fmt.Printf("Source: %s\n", result.Src)
507+
fmt.Println()
508+
509+
if len(result.Errors) > 0 {
510+
fmt.Println(" Errors:")
511+
512+
for _, e := range result.Errors {
513+
fmt.Printf(" ! %s\n", e)
514+
}
515+
516+
fmt.Println()
517+
}
518+
519+
if len(result.AcceptOK) > 0 || len(result.AcceptFail) > 0 {
520+
fmt.Println(" Accept Tests:")
521+
522+
for _, dest := range result.AcceptOK {
523+
fmt.Printf(" [PASS] %s - ALLOWED (expected)\n", dest)
524+
}
525+
526+
for _, dest := range result.AcceptFail {
527+
fmt.Printf(" [FAIL] %s - DENIED (expected ALLOWED)\n", dest)
528+
}
529+
530+
fmt.Println()
531+
}
532+
533+
if len(result.DenyOK) > 0 || len(result.DenyFail) > 0 {
534+
fmt.Println(" Deny Tests:")
535+
536+
for _, dest := range result.DenyOK {
537+
fmt.Printf(" [PASS] %s - DENIED (expected)\n", dest)
538+
}
539+
540+
for _, dest := range result.DenyFail {
541+
fmt.Printf(" [FAIL] %s - ALLOWED (expected DENIED)\n", dest)
542+
}
543+
544+
fmt.Println()
545+
}
546+
547+
if result.Passed {
548+
passedCount++
549+
550+
fmt.Println(" Result: PASSED")
551+
} else {
552+
fmt.Println(" Result: FAILED")
553+
}
554+
555+
fmt.Println()
556+
fmt.Println(strings.Repeat("-", separatorWidth))
557+
fmt.Println()
558+
}
559+
560+
// Summary
561+
if results.AllPassed {
562+
fmt.Printf("Overall: PASSED (%d/%d tests passed)\n", passedCount, totalCount)
563+
} else {
564+
fmt.Printf("Overall: FAILED (%d/%d tests passed)\n", passedCount, totalCount)
565+
}
566+
}

0 commit comments

Comments
 (0)