Skip to content

Commit 39bdf0e

Browse files
feat: add support to remove links and ensure that link ids are unique
Signed-off-by: Caio Ferreira <caiorcferreira@gmail.com>
1 parent 6eccb26 commit 39bdf0e

File tree

2 files changed

+95
-28
lines changed

2 files changed

+95
-28
lines changed

x/exp/templates/policy_set.go

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@ type PolicyMap map[cedar.PolicyID]*Policy
2121
// PolicySet is a set of named policies against which a request can be authorized.
2222
type PolicySet struct {
2323
// policies are stored internally so we can handle performance, concurrency bookkeeping however we want
24-
policies PolicyMap
24+
staticPolicies PolicyMap
25+
linkedPolicies map[cedar.PolicyID]*LinkedPolicy
2526

2627
templates map[cedar.PolicyID]*Template
27-
links map[cedar.PolicyID]*LinkedPolicy
2828
}
2929

3030
// NewPolicySet creates a new, empty PolicySet
3131
func NewPolicySet() *PolicySet {
3232
return &PolicySet{
33-
policies: PolicyMap{},
34-
templates: make(map[cedar.PolicyID]*Template),
35-
links: make(map[cedar.PolicyID]*LinkedPolicy),
33+
staticPolicies: PolicyMap{},
34+
templates: make(map[cedar.PolicyID]*Template),
35+
linkedPolicies: make(map[cedar.PolicyID]*LinkedPolicy),
3636
}
3737
}
3838

@@ -58,55 +58,58 @@ func NewPolicySetFromBytes(fileName string, document []byte) (*PolicySet, error)
5858
templateMap[policyID] = p
5959
}
6060

61-
return &PolicySet{policies: policyMap, templates: templateMap, links: make(map[cedar.PolicyID]*LinkedPolicy)}, nil
61+
return &PolicySet{staticPolicies: policyMap, templates: templateMap, linkedPolicies: make(map[cedar.PolicyID]*LinkedPolicy)}, nil
6262
}
6363

6464
// Get returns the Policy with the given ID. If a policy with the given ID
6565
// does not exist, nil is returned.
6666
func (p *PolicySet) Get(policyID cedar.PolicyID) *Policy {
67-
return p.policies[policyID]
67+
return p.staticPolicies[policyID]
6868
}
6969

7070
// Add inserts or updates a policy with the given ID. Returns true if a policy
7171
// with the given ID did not already exist in the set.
7272
func (p *PolicySet) Add(policyID cedar.PolicyID, policy *Policy) bool {
73-
_, exists := p.policies[policyID]
74-
p.policies[policyID] = policy
73+
_, exists := p.staticPolicies[policyID]
74+
p.staticPolicies[policyID] = policy
7575
return !exists
7676
}
7777

78-
// todo: check to see if it's a static policy or a linked policy
7978
// Remove removes a policy from the PolicySet. Returns true if a policy with
8079
// the given ID already existed in the set.
8180
func (p *PolicySet) Remove(policyID cedar.PolicyID) bool {
82-
_, exists := p.policies[policyID]
83-
delete(p.policies, policyID)
84-
return exists
81+
_, staticExists := p.staticPolicies[policyID]
82+
delete(p.staticPolicies, policyID)
83+
84+
_, linkExists := p.linkedPolicies[policyID]
85+
delete(p.linkedPolicies, policyID)
86+
87+
return staticExists || linkExists
8588
}
8689

8790
// Map returns a new PolicyMap instance of the policies in the PolicySet.
8891
//
8992
// Deprecated: use the iterator returned by All() like so: maps.Collect(ps.All())
9093
func (p *PolicySet) Map() PolicyMap {
91-
return maps.Clone(p.policies)
94+
return maps.Clone(p.staticPolicies)
9295
}
9396

9497
// MarshalCedar emits a concatenated Cedar representation of a PolicySet. The policy names are stripped, but policies
9598
// are emitted in lexicographical order by ID.
9699
func (p *PolicySet) MarshalCedar() []byte {
97-
ids := make([]cedar.PolicyID, 0, len(p.policies))
98-
for k := range p.policies {
100+
ids := make([]cedar.PolicyID, 0, len(p.staticPolicies))
101+
for k := range p.staticPolicies {
99102
ids = append(ids, k)
100103
}
101104
slices.Sort(ids)
102105

103106
var buf bytes.Buffer
104107
i := 0
105108
for _, id := range ids {
106-
policy := p.policies[id]
109+
policy := p.staticPolicies[id]
107110
buf.Write(policy.MarshalCedar())
108111

109-
if i < len(p.policies)-1 {
112+
if i < len(p.staticPolicies)-1 {
110113
buf.WriteString("\n\n")
111114
}
112115
i++
@@ -119,9 +122,9 @@ func (p *PolicySet) MarshalCedar() []byte {
119122
// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html
120123
func (p *PolicySet) MarshalJSON() ([]byte, error) {
121124
jsonPolicySet := internaljson.PolicySetJSON{
122-
StaticPolicies: make(internaljson.PolicySet, len(p.policies)),
125+
StaticPolicies: make(internaljson.PolicySet, len(p.staticPolicies)),
123126
}
124-
for k, v := range p.policies {
127+
for k, v := range p.staticPolicies {
125128
jsonPolicySet.StaticPolicies[string(k)] = (*internaljson.Policy)(v.AST())
126129
}
127130
return json.Marshal(jsonPolicySet)
@@ -136,24 +139,24 @@ func (p *PolicySet) UnmarshalJSON(b []byte) error {
136139
return err
137140
}
138141
*p = PolicySet{
139-
policies: make(PolicyMap, len(jsonPolicySet.StaticPolicies)),
142+
staticPolicies: make(PolicyMap, len(jsonPolicySet.StaticPolicies)),
140143
}
141144
for k, v := range jsonPolicySet.StaticPolicies {
142-
p.policies[cedar.PolicyID(k)] = newPolicy((*internalast.Policy)(v)) // NewPolicyFromAST((*ast.Policy)(v))
145+
p.staticPolicies[cedar.PolicyID(k)] = newPolicy((*internalast.Policy)(v)) // NewPolicyFromAST((*ast.Policy)(v))
143146
}
144147
return nil
145148
}
146149

147150
// All returns an iterator over the (PolicyID, *Policy) tuples in the PolicySet
148151
func (p *PolicySet) All() iter.Seq2[cedar.PolicyID, *Policy] {
149152
return func(yield func(cedar.PolicyID, *Policy) bool) {
150-
for k, v := range p.policies {
153+
for k, v := range p.staticPolicies {
151154
if !yield(k, v) {
152155
break
153156
}
154157
}
155158

156-
for k, v := range p.links {
159+
for k, v := range p.linkedPolicies {
157160
// Render links on read to make template changes propagate
158161
policy, err := p.render(*v)
159162
if err != nil { //todo: think how to propagate this error
@@ -241,6 +244,11 @@ func (l *LinkedPolicy) LinkID() cedar.PolicyID {
241244
//
242245
// Returns a LinkedPolicy that can be rendered into a concrete Policy.
243246
func (p *PolicySet) LinkTemplate(templateID cedar.PolicyID, linkID cedar.PolicyID, slotEnv map[types.SlotID]types.EntityUID) error {
247+
_, exists := p.staticPolicies[linkID]
248+
if exists {
249+
return fmt.Errorf("link ID %s already exists in the policy set", linkID)
250+
}
251+
244252
template := p.GetTemplate(templateID)
245253
if template == nil {
246254
return fmt.Errorf("template %s not found", templateID)
@@ -257,15 +265,15 @@ func (p *PolicySet) LinkTemplate(templateID cedar.PolicyID, linkID cedar.PolicyI
257265
}
258266

259267
link := LinkedPolicy{templateID, linkID, slotEnv}
260-
p.links[linkID] = &link
268+
p.linkedPolicies[linkID] = &link
261269

262270
return nil
263271
}
264272

265273
// GetLinkedPolicy returns the LinkedPolicy associated with the given link ID.
266274
// If the linked policy does not exist, it returns nil.
267275
func (p *PolicySet) GetLinkedPolicy(linkID cedar.PolicyID) *LinkedPolicy {
268-
return p.links[linkID]
276+
return p.linkedPolicies[linkID]
269277
}
270278

271279
// GetTemplate returns the Template with the given ID.
@@ -288,9 +296,9 @@ func (p *PolicySet) RemoveTemplate(templateID cedar.PolicyID) bool {
288296
_, exists := p.templates[templateID]
289297
if exists {
290298
// Remove all linked policies that reference this template
291-
for linkID, link := range p.links {
299+
for linkID, link := range p.linkedPolicies {
292300
if link.templateID == templateID {
293-
delete(p.links, linkID)
301+
delete(p.linkedPolicies, linkID)
294302
}
295303
}
296304
}

x/exp/templates/template_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,33 @@ permit (
7171
testutil.Equals(t, isNew, false)
7272
})
7373

74+
t.Run("cannot use link id already used by static policy", func(t *testing.T) {
75+
templateString := `permit (
76+
principal == ?principal,
77+
action,
78+
resource
79+
);
80+
81+
permit (
82+
principal,
83+
action,
84+
resource
85+
);`
86+
templateID := cedar.PolicyID("template0")
87+
policyID := cedar.PolicyID("policy0")
88+
89+
policySet, err := templates.NewPolicySetFromBytes("test.cedar", []byte(templateString))
90+
testutil.OK(t, err)
91+
92+
// Link a policy to the template
93+
//linkID := cedar.PolicyID("linked_policy_id")
94+
env := map[types.SlotID]types.EntityUID{
95+
"?principal": types.NewEntityUID("User", "alice"),
96+
}
97+
err = policySet.LinkTemplate(templateID, policyID, env)
98+
testutil.Error(t, err)
99+
})
100+
74101
t.Run("removing template removes linked policies", func(t *testing.T) {
75102
templateString := `permit (
76103
principal == ?principal,
@@ -102,6 +129,38 @@ permit (
102129
linkedPolicyAfterRemoval := policySet.GetLinkedPolicy(linkID)
103130
testutil.Equals(t, linkedPolicyAfterRemoval == nil, true)
104131
})
132+
133+
t.Run("remove method can also remove linked policy", func(t *testing.T) {
134+
templateString := `permit (
135+
principal == ?principal,
136+
action,
137+
resource
138+
);`
139+
templateID := cedar.PolicyID("template0")
140+
141+
policySet, err := templates.NewPolicySetFromBytes("test.cedar", []byte(templateString))
142+
testutil.OK(t, err)
143+
144+
// Link a policy to the template
145+
linkID := cedar.PolicyID("linked_policy_id")
146+
env := map[types.SlotID]types.EntityUID{
147+
"?principal": types.NewEntityUID("User", "alice"),
148+
}
149+
err = policySet.LinkTemplate(templateID, linkID, env)
150+
testutil.OK(t, err)
151+
152+
// Ensure the linked policy exists
153+
linkedPolicy := policySet.GetLinkedPolicy(linkID)
154+
testutil.Equals(t, linkedPolicy != nil, true)
155+
156+
// Remove the template
157+
removed := policySet.Remove(linkID)
158+
testutil.Equals(t, removed, true)
159+
160+
// The linked policy should also be removed
161+
linkedPolicyAfterRemoval := policySet.GetLinkedPolicy(linkID)
162+
testutil.Equals(t, linkedPolicyAfterRemoval == nil, true)
163+
})
105164
}
106165

107166
func TestLinkTemplateToPolicy(t *testing.T) {

0 commit comments

Comments
 (0)