@@ -21,18 +21,18 @@ type PolicyMap map[cedar.PolicyID]*Policy
2121// PolicySet is a set of named policies against which a request can be authorized.
2222type PolicySet struct {
2323 // policies are stored internally so we can handle performance, concurrency bookkeeping however we want
24- policies PolicyMap
24+ staticPolicies PolicyMap
25+ linkedPolicies map [cedar.PolicyID ]* LinkedPolicy
2526
2627 templates map [cedar.PolicyID ]* Template
27- links map [cedar.PolicyID ]* LinkedPolicy
2828}
2929
3030// NewPolicySet creates a new, empty PolicySet
3131func NewPolicySet () * PolicySet {
3232 return & PolicySet {
33- policies : PolicyMap {},
34- templates : make (map [cedar.PolicyID ]* Template ),
35- links : make (map [cedar.PolicyID ]* LinkedPolicy ),
33+ staticPolicies : PolicyMap {},
34+ templates : make (map [cedar.PolicyID ]* Template ),
35+ linkedPolicies : make (map [cedar.PolicyID ]* LinkedPolicy ),
3636 }
3737}
3838
@@ -58,55 +58,58 @@ func NewPolicySetFromBytes(fileName string, document []byte) (*PolicySet, error)
5858 templateMap [policyID ] = p
5959 }
6060
61- return & PolicySet {policies : policyMap , templates : templateMap , links : make (map [cedar.PolicyID ]* LinkedPolicy )}, nil
61+ return & PolicySet {staticPolicies : policyMap , templates : templateMap , linkedPolicies : make (map [cedar.PolicyID ]* LinkedPolicy )}, nil
6262}
6363
6464// Get returns the Policy with the given ID. If a policy with the given ID
6565// does not exist, nil is returned.
6666func (p * PolicySet ) Get (policyID cedar.PolicyID ) * Policy {
67- return p .policies [policyID ]
67+ return p .staticPolicies [policyID ]
6868}
6969
7070// Add inserts or updates a policy with the given ID. Returns true if a policy
7171// with the given ID did not already exist in the set.
7272func (p * PolicySet ) Add (policyID cedar.PolicyID , policy * Policy ) bool {
73- _ , exists := p .policies [policyID ]
74- p .policies [policyID ] = policy
73+ _ , exists := p .staticPolicies [policyID ]
74+ p .staticPolicies [policyID ] = policy
7575 return ! exists
7676}
7777
78- // todo: check to see if it's a static policy or a linked policy
7978// Remove removes a policy from the PolicySet. Returns true if a policy with
8079// the given ID already existed in the set.
8180func (p * PolicySet ) Remove (policyID cedar.PolicyID ) bool {
82- _ , exists := p .policies [policyID ]
83- delete (p .policies , policyID )
84- return exists
81+ _ , staticExists := p .staticPolicies [policyID ]
82+ delete (p .staticPolicies , policyID )
83+
84+ _ , linkExists := p .linkedPolicies [policyID ]
85+ delete (p .linkedPolicies , policyID )
86+
87+ return staticExists || linkExists
8588}
8689
8790// Map returns a new PolicyMap instance of the policies in the PolicySet.
8891//
8992// Deprecated: use the iterator returned by All() like so: maps.Collect(ps.All())
9093func (p * PolicySet ) Map () PolicyMap {
91- return maps .Clone (p .policies )
94+ return maps .Clone (p .staticPolicies )
9295}
9396
9497// MarshalCedar emits a concatenated Cedar representation of a PolicySet. The policy names are stripped, but policies
9598// are emitted in lexicographical order by ID.
9699func (p * PolicySet ) MarshalCedar () []byte {
97- ids := make ([]cedar.PolicyID , 0 , len (p .policies ))
98- for k := range p .policies {
100+ ids := make ([]cedar.PolicyID , 0 , len (p .staticPolicies ))
101+ for k := range p .staticPolicies {
99102 ids = append (ids , k )
100103 }
101104 slices .Sort (ids )
102105
103106 var buf bytes.Buffer
104107 i := 0
105108 for _ , id := range ids {
106- policy := p .policies [id ]
109+ policy := p .staticPolicies [id ]
107110 buf .Write (policy .MarshalCedar ())
108111
109- if i < len (p .policies )- 1 {
112+ if i < len (p .staticPolicies )- 1 {
110113 buf .WriteString ("\n \n " )
111114 }
112115 i ++
@@ -119,9 +122,9 @@ func (p *PolicySet) MarshalCedar() []byte {
119122// [Cedar documentation]: https://docs.cedarpolicy.com/policies/json-format.html
120123func (p * PolicySet ) MarshalJSON () ([]byte , error ) {
121124 jsonPolicySet := internaljson.PolicySetJSON {
122- StaticPolicies : make (internaljson.PolicySet , len (p .policies )),
125+ StaticPolicies : make (internaljson.PolicySet , len (p .staticPolicies )),
123126 }
124- for k , v := range p .policies {
127+ for k , v := range p .staticPolicies {
125128 jsonPolicySet .StaticPolicies [string (k )] = (* internaljson .Policy )(v .AST ())
126129 }
127130 return json .Marshal (jsonPolicySet )
@@ -136,24 +139,24 @@ func (p *PolicySet) UnmarshalJSON(b []byte) error {
136139 return err
137140 }
138141 * p = PolicySet {
139- policies : make (PolicyMap , len (jsonPolicySet .StaticPolicies )),
142+ staticPolicies : make (PolicyMap , len (jsonPolicySet .StaticPolicies )),
140143 }
141144 for k , v := range jsonPolicySet .StaticPolicies {
142- p .policies [cedar .PolicyID (k )] = newPolicy ((* internalast .Policy )(v )) // NewPolicyFromAST((*ast.Policy)(v))
145+ p .staticPolicies [cedar .PolicyID (k )] = newPolicy ((* internalast .Policy )(v )) // NewPolicyFromAST((*ast.Policy)(v))
143146 }
144147 return nil
145148}
146149
147150// All returns an iterator over the (PolicyID, *Policy) tuples in the PolicySet
148151func (p * PolicySet ) All () iter.Seq2 [cedar.PolicyID , * Policy ] {
149152 return func (yield func (cedar.PolicyID , * Policy ) bool ) {
150- for k , v := range p .policies {
153+ for k , v := range p .staticPolicies {
151154 if ! yield (k , v ) {
152155 break
153156 }
154157 }
155158
156- for k , v := range p .links {
159+ for k , v := range p .linkedPolicies {
157160 // Render links on read to make template changes propagate
158161 policy , err := p .render (* v )
159162 if err != nil { //todo: think how to propagate this error
@@ -241,6 +244,11 @@ func (l *LinkedPolicy) LinkID() cedar.PolicyID {
241244//
242245// Returns a LinkedPolicy that can be rendered into a concrete Policy.
243246func (p * PolicySet ) LinkTemplate (templateID cedar.PolicyID , linkID cedar.PolicyID , slotEnv map [types.SlotID ]types.EntityUID ) error {
247+ _ , exists := p .staticPolicies [linkID ]
248+ if exists {
249+ return fmt .Errorf ("link ID %s already exists in the policy set" , linkID )
250+ }
251+
244252 template := p .GetTemplate (templateID )
245253 if template == nil {
246254 return fmt .Errorf ("template %s not found" , templateID )
@@ -257,15 +265,15 @@ func (p *PolicySet) LinkTemplate(templateID cedar.PolicyID, linkID cedar.PolicyI
257265 }
258266
259267 link := LinkedPolicy {templateID , linkID , slotEnv }
260- p .links [linkID ] = & link
268+ p .linkedPolicies [linkID ] = & link
261269
262270 return nil
263271}
264272
265273// GetLinkedPolicy returns the LinkedPolicy associated with the given link ID.
266274// If the linked policy does not exist, it returns nil.
267275func (p * PolicySet ) GetLinkedPolicy (linkID cedar.PolicyID ) * LinkedPolicy {
268- return p .links [linkID ]
276+ return p .linkedPolicies [linkID ]
269277}
270278
271279// GetTemplate returns the Template with the given ID.
@@ -288,9 +296,9 @@ func (p *PolicySet) RemoveTemplate(templateID cedar.PolicyID) bool {
288296 _ , exists := p .templates [templateID ]
289297 if exists {
290298 // Remove all linked policies that reference this template
291- for linkID , link := range p .links {
299+ for linkID , link := range p .linkedPolicies {
292300 if link .templateID == templateID {
293- delete (p .links , linkID )
301+ delete (p .linkedPolicies , linkID )
294302 }
295303 }
296304 }
0 commit comments