Skip to content

Commit 48199c2

Browse files
authored
Refactor graph creation logic for improved clarity and maintainability (#585)
1 parent 3d45448 commit 48199c2

File tree

3 files changed

+661
-256
lines changed

3 files changed

+661
-256
lines changed

pkg/sync/expand/expander.go

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
package expand
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"errors"
7+
"fmt"
8+
"os"
9+
"strconv"
10+
11+
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
12+
"go.uber.org/zap"
13+
14+
v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
15+
reader_v2 "github.com/conductorone/baton-sdk/pb/c1/reader/v2"
16+
"github.com/conductorone/baton-sdk/pkg/annotations"
17+
)
18+
19+
const defaultMaxDepth int64 = 20
20+
21+
var maxDepth, _ = strconv.ParseInt(os.Getenv("BATON_GRAPH_EXPAND_MAX_DEPTH"), 10, 64)
22+
23+
// ErrMaxDepthExceeded is returned when the expansion graph exceeds the maximum allowed depth.
24+
var ErrMaxDepthExceeded = errors.New("max depth exceeded")
25+
26+
// ExpanderStore defines the minimal store interface needed for grant expansion.
27+
// This interface can be implemented by the connectorstore or by a mock for testing.
28+
type ExpanderStore interface {
29+
GetEntitlement(ctx context.Context, req *reader_v2.EntitlementsReaderServiceGetEntitlementRequest) (*reader_v2.EntitlementsReaderServiceGetEntitlementResponse, error)
30+
ListGrantsForEntitlement(ctx context.Context, req *reader_v2.GrantsReaderServiceListGrantsForEntitlementRequest) (*reader_v2.GrantsReaderServiceListGrantsForEntitlementResponse, error)
31+
PutGrants(ctx context.Context, grants ...*v2.Grant) error
32+
}
33+
34+
// Expander handles the grant expansion algorithm.
35+
// It can be used standalone for testing or called from the syncer.
36+
type Expander struct {
37+
store ExpanderStore
38+
graph *EntitlementGraph
39+
}
40+
41+
// NewExpander creates a new Expander with the given store and graph.
42+
func NewExpander(store ExpanderStore, graph *EntitlementGraph) *Expander {
43+
return &Expander{
44+
store: store,
45+
graph: graph,
46+
}
47+
}
48+
49+
// Graph returns the entitlement graph.
50+
func (e *Expander) Graph() *EntitlementGraph {
51+
return e.graph
52+
}
53+
54+
// Run executes the complete expansion algorithm until the graph is fully expanded.
55+
// This is useful for testing where you want to run the entire expansion in one call.
56+
func (e *Expander) Run(ctx context.Context) error {
57+
for {
58+
err := e.RunSingleStep(ctx)
59+
if err != nil {
60+
return err
61+
}
62+
if e.IsDone(ctx) {
63+
return nil
64+
}
65+
}
66+
}
67+
68+
// RunSingleStep executes one step of the expansion algorithm.
69+
// Returns true when the graph is fully expanded, false if more work is needed.
70+
// This matches the syncer's step-by-step execution model.
71+
func (e *Expander) RunSingleStep(ctx context.Context) error {
72+
l := ctxzap.Extract(ctx)
73+
l = l.With(zap.Int("depth", e.graph.Depth))
74+
l.Debug("expander: starting step")
75+
76+
// Process current action if any
77+
if len(e.graph.Actions) > 0 {
78+
action := e.graph.Actions[0]
79+
nextPageToken, err := e.runAction(ctx, action)
80+
if err != nil {
81+
l.Error("expander: error running graph action", zap.Error(err), zap.Any("action", action))
82+
_ = e.graph.DeleteEdge(ctx, action.SourceEntitlementID, action.DescendantEntitlementID)
83+
if errors.Is(err, sql.ErrNoRows) {
84+
// Skip action and delete the edge that caused the error.
85+
e.graph.Actions = e.graph.Actions[1:]
86+
return nil
87+
}
88+
return err
89+
}
90+
91+
if nextPageToken != "" {
92+
// More pages to process
93+
action.PageToken = nextPageToken
94+
} else {
95+
// Action is complete - mark edge expanded and remove from queue
96+
e.graph.MarkEdgeExpanded(action.SourceEntitlementID, action.DescendantEntitlementID)
97+
e.graph.Actions = e.graph.Actions[1:]
98+
}
99+
}
100+
101+
// If there are still actions remaining, continue processing
102+
if len(e.graph.Actions) > 0 {
103+
return nil
104+
}
105+
106+
// Check max depth
107+
depth := maxDepth
108+
if depth == 0 {
109+
depth = defaultMaxDepth
110+
}
111+
112+
if int64(e.graph.Depth) > depth {
113+
l.Error("expander: exceeded max depth", zap.Int64("max_depth", depth))
114+
return fmt.Errorf("expander: %w (%d)", ErrMaxDepthExceeded, depth)
115+
}
116+
117+
// Generate new actions from expandable entitlements
118+
for sourceEntitlementID := range e.graph.GetExpandableEntitlements(ctx) {
119+
for descendantEntitlementID, grantInfo := range e.graph.GetExpandableDescendantEntitlements(ctx, sourceEntitlementID) {
120+
e.graph.Actions = append(e.graph.Actions, &EntitlementGraphAction{
121+
SourceEntitlementID: sourceEntitlementID,
122+
DescendantEntitlementID: descendantEntitlementID,
123+
PageToken: "",
124+
Shallow: grantInfo.IsShallow,
125+
ResourceTypeIDs: grantInfo.ResourceTypeIDs,
126+
})
127+
}
128+
}
129+
130+
e.graph.Depth++
131+
l.Debug("expander: graph is not expanded, incrementing depth")
132+
return nil
133+
}
134+
135+
func (e *Expander) IsDone(ctx context.Context) bool {
136+
return e.graph.IsExpanded()
137+
}
138+
139+
// runAction processes a single action and returns the next page token.
140+
// If the returned page token is empty, the action is complete.
141+
func (e *Expander) runAction(ctx context.Context, action *EntitlementGraphAction) (string, error) {
142+
l := ctxzap.Extract(ctx)
143+
l = l.With(
144+
zap.Int("depth", e.graph.Depth),
145+
zap.String("source_entitlement_id", action.SourceEntitlementID),
146+
zap.String("descendant_entitlement_id", action.DescendantEntitlementID),
147+
)
148+
149+
// Fetch source and descendant entitlement
150+
sourceEntitlement, err := e.store.GetEntitlement(ctx, reader_v2.EntitlementsReaderServiceGetEntitlementRequest_builder{
151+
EntitlementId: action.SourceEntitlementID,
152+
}.Build())
153+
if err != nil {
154+
l.Error("runAction: error fetching source entitlement", zap.Error(err))
155+
return "", fmt.Errorf("runAction: error fetching source entitlement: %w", err)
156+
}
157+
158+
descendantEntitlement, err := e.store.GetEntitlement(ctx, reader_v2.EntitlementsReaderServiceGetEntitlementRequest_builder{
159+
EntitlementId: action.DescendantEntitlementID,
160+
}.Build())
161+
if err != nil {
162+
l.Error("runAction: error fetching descendant entitlement", zap.Error(err))
163+
return "", fmt.Errorf("runAction: error fetching descendant entitlement: %w", err)
164+
}
165+
166+
// Fetch a page of source grants
167+
sourceGrants, err := e.store.ListGrantsForEntitlement(ctx, reader_v2.GrantsReaderServiceListGrantsForEntitlementRequest_builder{
168+
Entitlement: sourceEntitlement.GetEntitlement(),
169+
PageToken: action.PageToken,
170+
PrincipalResourceTypeIds: action.ResourceTypeIDs,
171+
}.Build())
172+
if err != nil {
173+
l.Error("runAction: error fetching source grants", zap.Error(err))
174+
return "", fmt.Errorf("runAction: error fetching source grants: %w", err)
175+
}
176+
177+
var newGrants = make([]*v2.Grant, 0)
178+
for _, sourceGrant := range sourceGrants.GetList() {
179+
// If this is a shallow action, then we only want to expand grants that have no sources
180+
// which indicates that it was directly assigned.
181+
if action.Shallow {
182+
sourcesMap := sourceGrant.GetSources().GetSources()
183+
// If we have no sources, this is a direct grant
184+
foundDirectGrant := len(sourcesMap) == 0
185+
// If the source grant has sources, then we need to see if any of them are the source entitlement itself
186+
if sourcesMap[action.SourceEntitlementID] != nil {
187+
foundDirectGrant = true
188+
}
189+
190+
// This is not a direct grant, so skip it since we are a shallow action
191+
if !foundDirectGrant {
192+
continue
193+
}
194+
}
195+
196+
// Unroll all grants for the principal on the descendant entitlement.
197+
pageToken := ""
198+
for {
199+
req := reader_v2.GrantsReaderServiceListGrantsForEntitlementRequest_builder{
200+
Entitlement: descendantEntitlement.GetEntitlement(),
201+
PrincipalId: sourceGrant.GetPrincipal().GetId(),
202+
PageToken: pageToken,
203+
Annotations: nil,
204+
}.Build()
205+
206+
resp, err := e.store.ListGrantsForEntitlement(ctx, req)
207+
if err != nil {
208+
l.Error("runAction: error fetching descendant grants", zap.Error(err))
209+
return "", fmt.Errorf("runAction: error fetching descendant grants: %w", err)
210+
}
211+
descendantGrants := resp.GetList()
212+
213+
// If we have no grants for the principal in the descendant entitlement, make one.
214+
if pageToken == "" && resp.GetNextPageToken() == "" && len(descendantGrants) == 0 {
215+
descendantGrant, err := newExpandedGrant(descendantEntitlement.GetEntitlement(), sourceGrant.GetPrincipal(), action.SourceEntitlementID)
216+
if err != nil {
217+
l.Error("runAction: error creating new grant", zap.Error(err))
218+
return "", fmt.Errorf("runAction: error creating new grant: %w", err)
219+
}
220+
newGrants = append(newGrants, descendantGrant)
221+
newGrants, err = e.putGrantsInChunks(ctx, newGrants, 10000)
222+
if err != nil {
223+
l.Error("runAction: error updating descendant grants", zap.Error(err))
224+
return "", fmt.Errorf("runAction: error updating descendant grants: %w", err)
225+
}
226+
break
227+
}
228+
229+
// Add the source entitlement as a source to all descendant grants.
230+
grantsToUpdate := make([]*v2.Grant, 0)
231+
for _, descendantGrant := range descendantGrants {
232+
sourcesMap := descendantGrant.GetSources().GetSources()
233+
if sourcesMap == nil {
234+
sourcesMap = make(map[string]*v2.GrantSources_GrantSource)
235+
}
236+
237+
updated := false
238+
239+
if len(sourcesMap) == 0 {
240+
// If we are already granted this entitlement, make sure to add ourselves as a source.
241+
sourcesMap[action.DescendantEntitlementID] = &v2.GrantSources_GrantSource{}
242+
updated = true
243+
}
244+
// Include the source grant as a source.
245+
if sourcesMap[action.SourceEntitlementID] == nil {
246+
sourcesMap[action.SourceEntitlementID] = &v2.GrantSources_GrantSource{}
247+
updated = true
248+
}
249+
250+
if updated {
251+
sources := v2.GrantSources_builder{Sources: sourcesMap}.Build()
252+
descendantGrant.SetSources(sources)
253+
grantsToUpdate = append(grantsToUpdate, descendantGrant)
254+
}
255+
}
256+
newGrants = append(newGrants, grantsToUpdate...)
257+
258+
newGrants, err = e.putGrantsInChunks(ctx, newGrants, 10000)
259+
if err != nil {
260+
l.Error("runAction: error updating descendant grants", zap.Error(err))
261+
return "", fmt.Errorf("runAction: error updating descendant grants: %w", err)
262+
}
263+
264+
pageToken = resp.GetNextPageToken()
265+
if pageToken == "" {
266+
break
267+
}
268+
}
269+
}
270+
271+
_, err = e.putGrantsInChunks(ctx, newGrants, 0)
272+
if err != nil {
273+
l.Error("runAction: error updating descendant grants", zap.Error(err))
274+
return "", fmt.Errorf("runAction: error updating descendant grants: %w", err)
275+
}
276+
277+
return sourceGrants.GetNextPageToken(), nil
278+
}
279+
280+
// putGrantsInChunks accumulates grants until the buffer exceeds minChunkSize,
281+
// then writes all grants to the store at once.
282+
func (e *Expander) putGrantsInChunks(ctx context.Context, grants []*v2.Grant, minChunkSize int) ([]*v2.Grant, error) {
283+
if len(grants) < minChunkSize {
284+
return grants, nil
285+
}
286+
287+
err := e.store.PutGrants(ctx, grants...)
288+
if err != nil {
289+
return nil, fmt.Errorf("putGrantsInChunks: error putting grants: %w", err)
290+
}
291+
292+
return make([]*v2.Grant, 0), nil
293+
}
294+
295+
// newExpandedGrant creates a new grant for a principal on a descendant entitlement.
296+
func newExpandedGrant(descEntitlement *v2.Entitlement, principal *v2.Resource, sourceEntitlementID string) (*v2.Grant, error) {
297+
enResource := descEntitlement.GetResource()
298+
if enResource == nil {
299+
return nil, fmt.Errorf("newExpandedGrant: entitlement has no resource")
300+
}
301+
302+
if principal == nil {
303+
return nil, fmt.Errorf("newExpandedGrant: principal is nil")
304+
}
305+
306+
// Add immutable annotation since this function is only called if no direct grant exists
307+
var annos annotations.Annotations
308+
annos.Update(&v2.GrantImmutable{})
309+
310+
var sources *v2.GrantSources
311+
if sourceEntitlementID != "" {
312+
sources = &v2.GrantSources{
313+
Sources: map[string]*v2.GrantSources_GrantSource{
314+
sourceEntitlementID: {},
315+
},
316+
}
317+
}
318+
319+
grant := v2.Grant_builder{
320+
Id: fmt.Sprintf("%s:%s:%s", descEntitlement.GetId(), principal.GetId().GetResourceType(), principal.GetId().GetResource()),
321+
Entitlement: descEntitlement,
322+
Principal: principal,
323+
Sources: sources,
324+
Annotations: annos,
325+
}.Build()
326+
327+
return grant, nil
328+
}

0 commit comments

Comments
 (0)