Skip to content

Commit 3fa88f3

Browse files
authored
Fix composable schema interactions with expiration feature (#2780)
1 parent 7519ee2 commit 3fa88f3

File tree

10 files changed

+96
-59
lines changed

10 files changed

+96
-59
lines changed

pkg/composableschemadsl/compiler/compiler.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"github.com/authzed/spicedb/pkg/composableschemadsl/input"
1212
"github.com/authzed/spicedb/pkg/composableschemadsl/parser"
1313
"github.com/authzed/spicedb/pkg/genutil/mapz"
14-
"github.com/authzed/spicedb/pkg/genutil/slicez"
1514
core "github.com/authzed/spicedb/pkg/proto/core/v1"
1615
)
1716

@@ -55,7 +54,7 @@ func (cs CompiledSchema) SourcePositionToRunePosition(source input.Source, posit
5554
type config struct {
5655
skipValidation bool
5756
objectTypePrefix *string
58-
allowedFlags []string
57+
allowedFlags *mapz.Set[string]
5958
caveatTypeSet *caveattypes.TypeSet
6059

6160
// In an import context, this is the folder containing
@@ -90,9 +89,7 @@ const expirationFlag = "expiration"
9089

9190
func DisallowExpirationFlag() Option {
9291
return func(cfg *config) {
93-
cfg.allowedFlags = slicez.Filter(cfg.allowedFlags, func(s string) bool {
94-
return s != expirationFlag
95-
})
92+
cfg.allowedFlags.Delete(expirationFlag)
9693
}
9794
}
9895

@@ -109,11 +106,11 @@ type ObjectPrefixOption func(*config)
109106
// Compile compilers the input schema into a set of namespace definition protos.
110107
func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
111108
cfg := &config{
112-
allowedFlags: make([]string, 0, 1),
109+
allowedFlags: mapz.NewSet[string](),
113110
}
114111

115112
// Enable `expiration` flag by default.
116-
cfg.allowedFlags = append(cfg.allowedFlags, expirationFlag)
113+
cfg.allowedFlags.Add(expirationFlag)
117114

118115
prefix(cfg) // required option
119116

@@ -146,6 +143,7 @@ func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*Co
146143
schemaString: schema.SchemaString,
147144
skipValidate: cfg.skipValidation,
148145
allowedFlags: cfg.allowedFlags,
146+
enabledFlags: mapz.NewSet[string](),
149147
existingNames: mapz.NewSet[string](),
150148
compiledPartials: initialCompiledPartials,
151149
unresolvedPartials: mapz.NewMultiMap[string, *dslNode](),

pkg/composableschemadsl/compiler/compiler_test.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,17 +1254,23 @@ func TestCompile(t *testing.T) {
12541254
},
12551255
},
12561256
{
1257-
"duplicate use pragmas",
1257+
"wildcard relation with expiration trait",
12581258
withTenantPrefix,
1259-
`
1260-
use expiration
1261-
use expiration
1262-
1259+
`use expiration
1260+
12631261
definition simple {
1264-
relation viewer: user with expiration
1262+
relation viewer: user:* with expiration
12651263
}`,
1266-
`found duplicate use flag`,
1267-
[]SchemaDefinition{},
1264+
"",
1265+
[]SchemaDefinition{
1266+
namespace.Namespace("sometenant/simple",
1267+
namespace.MustRelation("viewer", nil,
1268+
namespace.WithExpiration(
1269+
namespace.AllowedPublicNamespace("sometenant/user"),
1270+
),
1271+
),
1272+
),
1273+
},
12681274
},
12691275
{
12701276
"expiration use without use expiration",
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
use expiration
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
use expiration
2+
3+
definition test/user {}
4+
5+
definition test/document {
6+
relation viewer: test/user with expiration | test/user
7+
relation public_reader: test/user:* with expiration | test/user:*
8+
permission can_view = viewer + public_reader
9+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import "./dependency.zed"
2+
3+
definition test/user {}
4+
5+
definition test/document {
6+
relation viewer: test/user with expiration | test/user
7+
relation public_reader: test/user:* with expiration | test/user:*
8+
permission can_view = viewer + public_reader
9+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
use expiration
2+
3+
definition user {
4+
relation timeout: user with expiration
5+
}
6+
7+
definition resource {
8+
relation tempuser: user with expiration
9+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
use expiration
2+
3+
import "./user.zed"
4+
5+
definition resource {
6+
relation tempuser: user with expiration
7+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
use expiration
2+
3+
definition user {
4+
relation timeout: user with expiration
5+
}

pkg/composableschemadsl/compiler/importer_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ func TestImporter(t *testing.T) {
6161
{"nested local import with transitive hop", "nested-local-with-hop"},
6262
{"nested local two layers deep import", "nested-two-layer-local"},
6363
{"diamond-shaped imports are fine", "diamond-shaped"},
64+
{"multiple use directives are fine", "multiple-use-directives"},
65+
{"expiration works correctly across multiple files", "expiration-usage"},
6466
}
6567

6668
for _, test := range importerTests {

pkg/composableschemadsl/compiler/translator.go

Lines changed: 35 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"errors"
77
"fmt"
88
"path/filepath"
9-
"slices"
109
"strings"
1110

1211
"github.com/ccoveille/go-safecast/v2"
@@ -27,8 +26,8 @@ type translationContext struct {
2726
mapper input.PositionMapper
2827
schemaString string
2928
skipValidate bool
30-
allowedFlags []string
31-
enabledFlags []string
29+
allowedFlags *mapz.Set[string]
30+
enabledFlags *mapz.Set[string]
3231
existingNames *mapz.Set[string]
3332
caveatTypeSet *caveattypes.TypeSet
3433

@@ -668,29 +667,6 @@ func translateSpecificTypeReference(tctx *translationContext, typeRefNode *dslNo
668667
return nil, typeRefNode.Errorf("%w", err)
669668
}
670669

671-
if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateWildcard) {
672-
ref := &core.AllowedRelation{
673-
Namespace: nspath,
674-
RelationOrWildcard: &core.AllowedRelation_PublicWildcard_{
675-
PublicWildcard: &core.AllowedRelation_PublicWildcard{},
676-
},
677-
}
678-
679-
err = addWithCaveats(tctx, typeRefNode, ref)
680-
if err != nil {
681-
return nil, typeRefNode.Errorf("invalid caveat: %w", err)
682-
}
683-
684-
if !tctx.skipValidate {
685-
if err := ref.Validate(); err != nil {
686-
return nil, typeRefNode.Errorf("invalid type relation: %w", err)
687-
}
688-
}
689-
690-
ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper)
691-
return ref, nil
692-
}
693-
694670
relationName := Ellipsis
695671
if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateRelation) {
696672
relationName, err = typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateRelation)
@@ -706,42 +682,57 @@ func translateSpecificTypeReference(tctx *translationContext, typeRefNode *dslNo
706682
},
707683
}
708684

685+
if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateWildcard) {
686+
ref.RelationOrWildcard = &core.AllowedRelation_PublicWildcard_{
687+
PublicWildcard: &core.AllowedRelation_PublicWildcard{},
688+
}
689+
}
690+
709691
// Add the caveat(s), if any.
710692
err = addWithCaveats(tctx, typeRefNode, ref)
711693
if err != nil {
712694
return nil, typeRefNode.Errorf("invalid caveat: %w", err)
713695
}
714696

715697
// Add the expiration trait, if any.
698+
err = addWithExpiration(tctx, typeRefNode, ref)
699+
if err != nil {
700+
return nil, typeRefNode.Errorf("invalid expiration: %w", err)
701+
}
702+
703+
if !tctx.skipValidate {
704+
if err := ref.Validate(); err != nil {
705+
return nil, typeRefNode.Errorf("invalid type relation: %w", err)
706+
}
707+
}
708+
709+
ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper)
710+
return ref, nil
711+
}
712+
713+
func addWithExpiration(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error {
716714
if traitNode, err := typeRefNode.Lookup(dslshape.NodeSpecificReferencePredicateTrait); err == nil {
717715
traitName, err := traitNode.GetString(dslshape.NodeTraitPredicateTrait)
718716
if err != nil {
719-
return nil, typeRefNode.Errorf("invalid trait: %w", err)
717+
return err
720718
}
721719

722720
if traitName != "expiration" {
723-
return nil, typeRefNode.Errorf("invalid trait: %s", traitName)
721+
return fmt.Errorf("invalid trait: %s", traitName)
724722
}
725723

726-
if !slices.Contains(tctx.allowedFlags, "expiration") {
727-
return nil, typeRefNode.Errorf("expiration trait is not allowed")
724+
if !tctx.allowedFlags.Has("expiration") {
725+
return errors.New("expiration trait is not allowed")
728726
}
729727

730-
if !slices.Contains(tctx.enabledFlags, "expiration") {
731-
return nil, typeRefNode.Errorf("expiration flag is not enabled; add `use expiration` to top of file")
728+
if !tctx.enabledFlags.Has("expiration") {
729+
return errors.New("expiration flag is not enabled; add `use expiration` to top of file")
732730
}
733731

734732
ref.RequiredExpiration = &core.ExpirationTrait{}
735733
}
736734

737-
if !tctx.skipValidate {
738-
if err := ref.Validate(); err != nil {
739-
return nil, typeRefNode.Errorf("invalid type relation: %w", err)
740-
}
741-
}
742-
743-
ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper)
744-
return ref, nil
735+
return nil
745736
}
746737

747738
func addWithCaveats(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error {
@@ -946,9 +937,9 @@ func translateUseFlag(tctx *translationContext, useFlagNode *dslNode) error {
946937
if err != nil {
947938
return err
948939
}
949-
if slices.Contains(tctx.enabledFlags, flagName) {
950-
return useFlagNode.Errorf("found duplicate use flag: %s", flagName)
951-
}
952-
tctx.enabledFlags = append(tctx.enabledFlags, flagName)
940+
// NOTE: we're okay with multiple instances of a given `use` directive in
941+
// composable schemas, because each file may declare it separately
942+
// and that should be valid.
943+
tctx.enabledFlags.Add(flagName)
953944
return nil
954945
}

0 commit comments

Comments
 (0)