diff --git a/authorizer.go b/authorizer.go index 5266f7c..011cd32 100644 --- a/authorizer.go +++ b/authorizer.go @@ -1,6 +1,7 @@ package biscuit import ( + "context" "errors" "fmt" "strings" @@ -23,8 +24,8 @@ type Authorizer interface { AddRule(rule Rule) AddCheck(check Check) AddPolicy(policy Policy) - Authorize() error - Query(rule Rule) (FactSet, error) + Authorize(ctx context.Context) error + Query(ctx context.Context, rule Rule) (FactSet, error) Biscuit() *Biscuit Reset() PrintWorld() string @@ -111,7 +112,7 @@ func (v *authorizer) AddPolicy(policy Policy) { v.policies = append(v.policies, policy) } -func (v *authorizer) Authorize() error { +func (v *authorizer) Authorize(ctx context.Context) error { // if we load facts from the verifier before // the token's fact and rules, we might get inconsistent symbols // token ements should first be converted to builder elements @@ -133,7 +134,7 @@ func (v *authorizer) Authorize() error { v.world.AddRule(r.convert(v.symbols)) } - if err := v.world.Run(v.symbols); err != nil { + if err := v.world.Run(ctx, v.symbols); err != nil { return err } v.dirty = true @@ -226,7 +227,7 @@ func (v *authorizer) Authorize() error { block_world.AddRule(r.convert(v.symbols)) } - if err := block_world.Run(v.symbols); err != nil { + if err := block_world.Run(ctx, v.symbols); err != nil { return err } @@ -277,8 +278,8 @@ func (v *authorizer) Authorize() error { } } -func (v *authorizer) Query(rule Rule) (FactSet, error) { - if err := v.world.Run(v.symbols); err != nil { +func (v *authorizer) Query(ctx context.Context, rule Rule) (FactSet, error) { + if err := v.world.Run(ctx, v.symbols); err != nil { return nil, err } v.dirty = true diff --git a/authorizer_test.go b/authorizer_test.go index ead54ba..c1b0981 100644 --- a/authorizer_test.go +++ b/authorizer_test.go @@ -1,6 +1,7 @@ package biscuit import ( + "context" "crypto/ed25519" "crypto/rand" "testing" @@ -9,6 +10,7 @@ import ( ) func TestVerifierDefaultPolicy(t *testing.T) { + ctx := context.Background() rng := rand.Reader publicRoot, privateRoot, _ := ed25519.GenerateKey(rng) @@ -29,15 +31,16 @@ func TestVerifierDefaultPolicy(t *testing.T) { require.NoError(t, err) v.AddPolicy(DefaultDenyPolicy) - err = v.Authorize() + err = v.Authorize(ctx) require.Equal(t, err, ErrPolicyDenied) v.Reset() v.AddPolicy(DefaultAllowPolicy) - require.NoError(t, v.Authorize()) + require.NoError(t, v.Authorize(ctx)) } func TestVerifierPolicies(t *testing.T) { + ctx := context.Background() rng := rand.Reader publicRoot, privateRoot, _ := ed25519.GenerateKey(rng) @@ -83,7 +86,7 @@ func TestVerifierPolicies(t *testing.T) { IDs: []Term{String("some_file.txt")}, }}) - require.NoError(t, v.Authorize()) + require.NoError(t, v.Authorize(ctx)) v, err = b.Authorizer(publicRoot) require.NoError(t, err) @@ -96,7 +99,7 @@ func TestVerifierPolicies(t *testing.T) { Name: "resource", IDs: []Term{String("some_file.txt")}, }}) - require.Equal(t, v.Authorize(), ErrNoMatchingPolicy) + require.Equal(t, v.Authorize(ctx), ErrNoMatchingPolicy) } func TestVerifierSerializeLoad(t *testing.T) { diff --git a/biscuit.go b/biscuit.go index 2f532e7..1068d8c 100644 --- a/biscuit.go +++ b/biscuit.go @@ -2,6 +2,7 @@ package biscuit import ( "bytes" + "context" "crypto/rand" "encoding/binary" @@ -294,6 +295,37 @@ func (b *Biscuit) Seal(rng io.Reader) (*Biscuit, error) { }, nil } +func (b *Biscuit) Lookup(factName string) []string { + symbols := b.symbols.Clone() + datalogFactName := symbols.Sym(factName) + + if b.authority.facts == nil { + return nil + } + + for _, f := range *b.authority.facts { + if f.Name != datalogFactName { + continue + } + + terms := make([]string, len(f.Terms)) + for i, term := range f.Terms { + switch t := term.(type) { + case datalog.String: + terms[i] = symbols.Str(t) + case datalog.Variable: + terms[i] = symbols.Var(t) + default: + terms[i] = t.String() + } + } + + return terms + } + + return nil +} + type ( // A PublickKeyByIDProjection inspects an optional ID for a public key and returns the // corresponding public key, if any. If it doesn't recognize the ID or can't find the public @@ -568,7 +600,7 @@ func (b *Biscuit) checkRootKey(root ed25519.PublicKey) error { return nil }*/ -func (b *Biscuit) generateWorld(symbols *datalog.SymbolTable) (*datalog.World, error) { +func (b *Biscuit) generateWorld(ctx context.Context, symbols *datalog.SymbolTable) (*datalog.World, error) { world := datalog.NewWorld() for _, fact := range *b.authority.facts { @@ -589,7 +621,7 @@ func (b *Biscuit) generateWorld(symbols *datalog.SymbolTable) (*datalog.World, e } } - if err := world.Run(symbols); err != nil { + if err := world.Run(ctx, symbols); err != nil { return nil, err } diff --git a/biscuit_test.go b/biscuit_test.go index 3edaf09..c6391e6 100644 --- a/biscuit_test.go +++ b/biscuit_test.go @@ -1,6 +1,7 @@ package biscuit import ( + "context" "crypto/ed25519" "crypto/rand" "fmt" @@ -11,6 +12,7 @@ import ( ) func TestBiscuit(t *testing.T) { + ctx := context.Background() rng := rand.Reader const rootKeyID = 123 const contextText = "current_context" @@ -106,21 +108,21 @@ func TestBiscuit(t *testing.T) { v3.AddFact(Fact{Predicate: Predicate{Name: "resource", IDs: []Term{String("/a/file1")}}}) v3.AddFact(Fact{Predicate: Predicate{Name: "operation", IDs: []Term{String("read")}}}) v3.AddPolicy(DefaultAllowPolicy) - require.NoError(t, v3.Authorize()) + require.NoError(t, v3.Authorize(ctx)) v3, err = b3deser.AuthorizerFor(WithSingularRootPublicKey(publicRoot)) require.NoError(t, err) v3.AddFact(Fact{Predicate: Predicate{Name: "resource", IDs: []Term{String("/a/file2")}}}) v3.AddFact(Fact{Predicate: Predicate{Name: "operation", IDs: []Term{String("read")}}}) v3.AddPolicy(DefaultAllowPolicy) - require.Error(t, v3.Authorize()) + require.Error(t, v3.Authorize(ctx)) v3, err = b3deser.AuthorizerFor(WithSingularRootPublicKey(publicRoot)) require.NoError(t, err) v3.AddFact(Fact{Predicate: Predicate{Name: "resource", IDs: []Term{String("/a/file1")}}}) v3.AddFact(Fact{Predicate: Predicate{Name: "operation", IDs: []Term{String("write")}}}) v3.AddPolicy(DefaultAllowPolicy) - require.Error(t, v3.Authorize()) + require.Error(t, v3.Authorize(ctx)) } func TestSealedBiscuit(t *testing.T) { @@ -181,6 +183,7 @@ func TestSealedBiscuit(t *testing.T) { } func TestBiscuitRules(t *testing.T) { + ctx := context.Background() rng := rand.Reader publicRoot, privateRoot, _ := ed25519.GenerateKey(rng) @@ -222,7 +225,7 @@ func TestBiscuitRules(t *testing.T) { // b1 should allow alice & bob only // v, err := b1.Verify(publicRoot) // require.NoError(t, err) - verifyOwner(t, *b1, publicRoot, map[string]bool{"alice": true, "bob": true, "eve": false}) + verifyOwner(ctx, t, *b1, publicRoot, map[string]bool{"alice": true, "bob": true, "eve": false}) block := b1.CreateBlock() block.AddCheck(Check{ @@ -255,10 +258,10 @@ func TestBiscuitRules(t *testing.T) { // b2 should now only allow alice // v, err = b2.Verify(publicRoot) // require.NoError(t, err) - verifyOwner(t, *b2, publicRoot, map[string]bool{"alice": true, "bob": false, "eve": false}) + verifyOwner(ctx, t, *b2, publicRoot, map[string]bool{"alice": true, "bob": false, "eve": false}) } -func verifyOwner(t *testing.T, b Biscuit, publicRoot ed25519.PublicKey, owners map[string]bool) { +func verifyOwner(ctx context.Context, t *testing.T, b Biscuit, publicRoot ed25519.PublicKey, owners map[string]bool) { for user, valid := range owners { v, err := b.AuthorizerFor(WithSingularRootPublicKey(publicRoot)) require.NoError(t, err) @@ -278,9 +281,9 @@ func verifyOwner(t *testing.T, b Biscuit, publicRoot ed25519.PublicKey, owners m v.AddPolicy(DefaultAllowPolicy) if valid { - require.NoError(t, v.Authorize()) + require.NoError(t, v.Authorize(ctx)) } else { - require.Error(t, v.Authorize()) + require.Error(t, v.Authorize(ctx)) } }) } @@ -317,6 +320,7 @@ func TestCheckRootKey(t *testing.T) { } func TestGenerateWorld(t *testing.T) { + ctx := context.Background() rng := rand.Reader _, privateRoot, _ := ed25519.GenerateKey(rng) @@ -349,7 +353,7 @@ func TestGenerateWorld(t *testing.T) { require.NoError(t, err) StringTable := (build.(*builderOptions)).symbols - world, err := b.generateWorld(defaultSymbolTable.Clone()) + world, err := b.generateWorld(ctx, defaultSymbolTable.Clone()) require.NoError(t, err) expectedWorld := datalog.NewWorld() @@ -376,7 +380,7 @@ func TestGenerateWorld(t *testing.T) { require.NoError(t, err) allStrings := append(*StringTable, *(blockBuild.(*blockBuilder)).symbols...) - world, err = b2.generateWorld(&allStrings) + world, err = b2.generateWorld(ctx, &allStrings) require.NoError(t, err) expectedWorld = datalog.NewWorld() @@ -575,6 +579,7 @@ func TestGetBlockID(t *testing.T) { } func TestInvalidRuleGeneration(t *testing.T) { + ctx := context.Background() rng := rand.Reader publicRoot, privateRoot, _ := ed25519.GenerateKey(rng) builder := NewBuilder(privateRoot) @@ -612,7 +617,7 @@ func TestInvalidRuleGeneration(t *testing.T) { IDs: []Term{String("write")}, }}) - err = verifier.Authorize() + err = verifier.Authorize(ctx) t.Log(verifier.PrintWorld()) require.Error(t, err) } diff --git a/datalog/datalog.go b/datalog/datalog.go index 72fc29b..42a97f2 100644 --- a/datalog/datalog.go +++ b/datalog/datalog.go @@ -357,9 +357,9 @@ func (w *World) Rules() []Rule { return w.rules } -func (w *World) Run(syms *SymbolTable) error { +func (w *World) Run(ctx context.Context, syms *SymbolTable) error { done := make(chan error) - ctx, cancel := context.WithTimeout(context.Background(), w.runLimits.maxDuration) + ctx, cancel := context.WithTimeout(ctx, w.runLimits.maxDuration) defer cancel() go func() { diff --git a/datalog/datalog_test.go b/datalog/datalog_test.go index 22eb6a0..151415c 100644 --- a/datalog/datalog_test.go +++ b/datalog/datalog_test.go @@ -1,6 +1,7 @@ package datalog import ( + "context" "crypto/rand" "crypto/sha256" "testing" @@ -19,6 +20,8 @@ func hashVar(s string) Variable { } func TestFamily(t *testing.T) { + ctx := context.Background() + w := NewWorld() syms := &SymbolTable{} dbg := SymbolDebugger{syms} @@ -57,12 +60,12 @@ func TestFamily(t *testing.T) { t.Logf("adding r2: %s", dbg.Rule(r2)) w.AddRule(r2) - if err := w.Run(syms); err != nil { + if err := w.Run(ctx, syms); err != nil { t.Error(err) } w.AddFact(Fact{Predicate{parent, []Term{c, e}}}) - if err := w.Run(syms); err != nil { + if err := w.Run(ctx, syms); err != nil { t.Error(err) } @@ -480,6 +483,7 @@ func TestSetEqual(t *testing.T) { } func TestWorldRunLimits(t *testing.T) { + ctx := context.Background() syms := &SymbolTable{} a := syms.Insert("A") b := syms.Insert("B") @@ -550,6 +554,6 @@ func TestWorldRunLimits(t *testing.T) { } w.AddRule(r1) - require.Equal(t, tc.expectedErr, w.Run(syms)) + require.Equal(t, tc.expectedErr, w.Run(ctx, syms)) } } diff --git a/example_test.go b/example_test.go index 07aeb31..aea2233 100644 --- a/example_test.go +++ b/example_test.go @@ -1,6 +1,7 @@ package biscuit_test import ( + "context" "crypto/ed25519" "crypto/rand" "fmt" @@ -10,6 +11,7 @@ import ( ) func ExampleBiscuit() { + ctx := context.Background() rng := rand.Reader publicRoot, privateRoot, _ := ed25519.GenerateKey(rng) @@ -88,7 +90,7 @@ func ExampleBiscuit() { } v1.AddAuthorizer(authorizer) - if err := v1.Authorize(); err != nil { + if err := v1.Authorize(ctx); err != nil { // fmt.Println(v1.PrintWorld()) fmt.Println("forbidden to read /a/file1.txt") @@ -111,7 +113,7 @@ func ExampleBiscuit() { } v1.AddAuthorizer(authorizer) - if err := v1.Authorize(); err != nil { + if err := v1.Authorize(ctx); err != nil { fmt.Println("forbidden to write /a/file1.txt") } else { fmt.Println("allowed to write /a/file1.txt") diff --git a/samples/samples_test.go b/samples/samples_test.go index 78bb2a5..1345cd1 100644 --- a/samples/samples_test.go +++ b/samples/samples_test.go @@ -1,6 +1,7 @@ package biscuittest import ( + "context" "crypto/ed25519" "crypto/rand" "encoding/hex" @@ -139,7 +140,7 @@ type Validation struct { RevocationIds []string `json:"revocation_ids"` } -func CheckSample(root_key ed25519.PublicKey, c TestCase, t *testing.T) { +func CheckSample(ctx context.Context, root_key ed25519.PublicKey, c TestCase, t *testing.T) { // all these contain v4 blocks, which are not supported yet if c.Filename == "test024_third_party.bc" || c.Filename == "test025_check_all.bc" || @@ -161,7 +162,7 @@ func CheckSample(root_key ed25519.PublicKey, c TestCase, t *testing.T) { } for _, v := range c.Validations { - CompareResult(root_key, c.Filename, *token, v, t) + CompareResult(ctx, root_key, c.Filename, *token, v, t) } } else { @@ -200,7 +201,7 @@ func CompareBlocks(token biscuit.Biscuit, blocks []Block, t *testing.T) { require.Equal(t, sample, rebuilt.Code()) } -func CompareResult(root_key ed25519.PublicKey, filename string, token biscuit.Biscuit, v Validation, t *testing.T) { +func CompareResult(ctx context.Context, root_key ed25519.PublicKey, filename string, token biscuit.Biscuit, v Validation, t *testing.T) { p := parser.New() authorizer_code, err := p.Authorizer(v.AuthorizerCode, nil) require.NoError(t, err) @@ -210,7 +211,7 @@ func CompareResult(root_key ed25519.PublicKey, filename string, token biscuit.Bi CompareError(err, v.Result.Err, t) } else { authorizer.AddAuthorizer(authorizer_code) - err = authorizer.Authorize() + err = authorizer.Authorize(ctx) if err != nil { CompareError(err, v.Result.Err, t) } else { @@ -241,6 +242,7 @@ func CompareError(authorization_error error, sample_error *BiscuitError, t *test } func TestReadSamples(t *testing.T) { + ctx := context.Background() b, err := os.ReadFile("./data/current/samples.json") require.NoError(t, err) var samples Samples @@ -251,7 +253,7 @@ func TestReadSamples(t *testing.T) { require.NoError(t, err) fmt.Printf("Checking %d samples\n", len(samples.TestCases)) for _, v := range samples.TestCases { - t.Run(v.Filename, func(t *testing.T) { CheckSample(root_key, v, t) }) + t.Run(v.Filename, func(t *testing.T) { CheckSample(ctx, root_key, v, t) }) } } diff --git a/website_test.go b/website_test.go index f3a275c..2212a60 100644 --- a/website_test.go +++ b/website_test.go @@ -1,6 +1,7 @@ package biscuit_test import ( + "context" "crypto/ed25519" "crypto/rand" "fmt" @@ -48,7 +49,7 @@ func CreateToken(root *ed25519.PrivateKey) (*biscuit.Biscuit, error) { return token, nil } -func Authorize(token *biscuit.Biscuit, root *ed25519.PublicKey) error { +func Authorize(ctx context.Context, token *biscuit.Biscuit, root *ed25519.PublicKey) error { authorizer, err := token.Authorizer(*root) if err != nil { return fmt.Errorf("failed to create verifier: %v", err) @@ -72,7 +73,7 @@ func Authorize(token *biscuit.Biscuit, root *ed25519.PublicKey) error { } authorizer.AddPolicy(policy) - return authorizer.Authorize() + return authorizer.Authorize(ctx) } func Attenuate(serializedToken []byte, root *ed25519.PublicKey) ([]byte, error) { @@ -105,11 +106,11 @@ func Seal(b *biscuit.Biscuit, rng io.Reader) (*biscuit.Biscuit, error) { return b.Seal(rng) } -func Query(authorizer biscuit.Authorizer) (biscuit.FactSet, error) { +func Query(ctx context.Context, authorizer biscuit.Authorizer) (biscuit.FactSet, error) { rule, err := parser.FromStringRule(`data($name, $id) <- user($name, $id`) if err != nil { return nil, fmt.Errorf("failed to parse check: %v", err) } - return authorizer.Query(rule) + return authorizer.Query(ctx, rule) }