Skip to content

Commit 09a9d0a

Browse files
authored
feat(integrations): propagate current org name to policy providers (#1830)
Signed-off-by: Jose I. Paris <[email protected]>
1 parent 8e7f5d5 commit 09a9d0a

File tree

3 files changed

+44
-19
lines changed

3 files changed

+44
-19
lines changed

app/controlplane/internal/service/attestation.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,12 @@ func (s *AttestationService) GetPolicy(ctx context.Context, req *cpAPI.Attestati
404404
return nil, errors.Forbidden("forbidden", "token not found")
405405
}
406406

407-
remotePolicy, err := s.workflowContractUseCase.GetPolicy(req.GetProvider(), req.GetPolicyName(), req.GetOrgName(), token.Token)
407+
org, err := requireCurrentOrg(ctx)
408+
if err != nil {
409+
return nil, errors.Forbidden("forbidden", "organization not found")
410+
}
411+
412+
remotePolicy, err := s.workflowContractUseCase.GetPolicy(req.GetProvider(), req.GetPolicyName(), req.GetOrgName(), org.Name, token.Token)
408413
if err != nil {
409414
return nil, handleUseCaseErr(err, s.log)
410415
}
@@ -421,7 +426,12 @@ func (s *AttestationService) GetPolicyGroup(ctx context.Context, req *cpAPI.Atte
421426
return nil, errors.Forbidden("forbidden", "token not found")
422427
}
423428

424-
remoteGroup, err := s.workflowContractUseCase.GetPolicyGroup(req.GetProvider(), req.GetGroupName(), req.GetOrgName(), token.Token)
429+
org, err := requireCurrentOrg(ctx)
430+
if err != nil {
431+
return nil, errors.Forbidden("forbidden", "organization not found")
432+
}
433+
434+
remoteGroup, err := s.workflowContractUseCase.GetPolicyGroup(req.GetProvider(), req.GetGroupName(), req.GetOrgName(), org.Name, token.Token)
425435
if err != nil {
426436
return nil, handleUseCaseErr(err, s.log)
427437
}

app/controlplane/pkg/biz/workflowcontract.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ func (uc *WorkflowContractUseCase) findAndValidatePolicy(att *schemav1.PolicyAtt
388388
return nil, err
389389
}
390390

391-
remotePolicy, err := uc.GetPolicy(pr.Provider, pr.Name, pr.OrgName, token)
391+
remotePolicy, err := uc.GetPolicy(pr.Provider, pr.Name, pr.OrgName, "", token)
392392
if err != nil {
393393
return nil, err
394394
}
@@ -415,7 +415,7 @@ func (uc *WorkflowContractUseCase) findPolicyGroup(att *schemav1.PolicyGroupAtta
415415
// [chainloop://][provider/]name
416416
if loader.IsProviderScheme(att.GetRef()) {
417417
pr := loader.ProviderParts(att.GetRef())
418-
remoteGroup, err := uc.GetPolicyGroup(pr.Provider, pr.Name, pr.OrgName, token)
418+
remoteGroup, err := uc.GetPolicyGroup(pr.Provider, pr.Name, pr.OrgName, "", token)
419419
if err != nil {
420420
return nil, NewErrValidation(fmt.Errorf("failed to get policy group: %w", err))
421421
}
@@ -492,13 +492,16 @@ type RemotePolicyGroup struct {
492492
}
493493

494494
// GetPolicy retrieves a policy from a policy provider
495-
func (uc *WorkflowContractUseCase) GetPolicy(providerName, policyName, orgName, token string) (*RemotePolicy, error) {
495+
func (uc *WorkflowContractUseCase) GetPolicy(providerName, policyName, policyOrgName, currentOrgName, token string) (*RemotePolicy, error) {
496496
provider, err := uc.findProvider(providerName)
497497
if err != nil {
498498
return nil, err
499499
}
500500

501-
policy, ref, err := provider.Resolve(policyName, orgName, token)
501+
policy, ref, err := provider.Resolve(policyName, policyOrgName, policies.ProviderAuthOpts{
502+
Token: token,
503+
OrgName: currentOrgName,
504+
})
502505
if err != nil {
503506
if errors.Is(err, policies.ErrNotFound) {
504507
return nil, NewErrNotFound(fmt.Sprintf("policy %q", policyName))
@@ -510,13 +513,16 @@ func (uc *WorkflowContractUseCase) GetPolicy(providerName, policyName, orgName,
510513
return &RemotePolicy{Policy: policy, ProviderRef: ref}, nil
511514
}
512515

513-
func (uc *WorkflowContractUseCase) GetPolicyGroup(providerName, groupName, orgName, token string) (*RemotePolicyGroup, error) {
516+
func (uc *WorkflowContractUseCase) GetPolicyGroup(providerName, groupName, groupOrgName, currentOrgName, token string) (*RemotePolicyGroup, error) {
514517
provider, err := uc.findProvider(providerName)
515518
if err != nil {
516519
return nil, err
517520
}
518521

519-
group, ref, err := provider.ResolveGroup(groupName, orgName, token)
522+
group, ref, err := provider.ResolveGroup(groupName, groupOrgName, policies.ProviderAuthOpts{
523+
Token: token,
524+
OrgName: currentOrgName,
525+
})
520526
if err != nil {
521527
if errors.Is(err, policies.ErrNotFound) {
522528
return nil, NewErrNotFound(fmt.Sprintf("policy group %q", groupName))

app/controlplane/pkg/policies/policyprovider.go

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ const (
3535
validateAction = "validate"
3636
groupsEndpoint = "groups"
3737

38-
digestParam = "digest"
39-
orgNameParam = "organization_name"
38+
digestParam = "digest"
39+
orgNameParam = "organization_name"
40+
organizationHeader = "Chainloop-Organization"
4041
)
4142

4243
// PolicyProvider represents an external policy provider
@@ -72,12 +73,17 @@ type PolicyReference struct {
7273
Digest string
7374
}
7475

76+
type ProviderAuthOpts struct {
77+
Token string
78+
OrgName string
79+
}
80+
7581
var ErrNotFound = fmt.Errorf("policy not found")
7682

7783
// Resolve calls the remote provider for retrieving a policy
78-
func (p *PolicyProvider) Resolve(policyName, orgName, token string) (*schemaapi.Policy, *PolicyReference, error) {
79-
if policyName == "" || token == "" {
80-
return nil, nil, fmt.Errorf("both policyname and token are mandatory")
84+
func (p *PolicyProvider) Resolve(policyName, policyOrgName string, authOpts ProviderAuthOpts) (*schemaapi.Policy, *PolicyReference, error) {
85+
if policyName == "" || authOpts.Token == "" {
86+
return nil, nil, fmt.Errorf("both policyname and auth opts are mandatory")
8187
}
8288

8389
// the policy name might include a digest in the form of <name>@sha256:<digest>
@@ -94,7 +100,7 @@ func (p *PolicyProvider) Resolve(policyName, orgName, token string) (*schemaapi.
94100
}
95101
// we want to override the orgName with the one in the response
96102
// since we might have resolved it implicitly
97-
providerDigest, orgName, err := p.queryProvider(url, digest, orgName, token, &policy)
103+
providerDigest, orgName, err := p.queryProvider(url, digest, policyOrgName, authOpts, &policy)
98104
if err != nil {
99105
return nil, nil, fmt.Errorf("failed to resolve policy: %w", err)
100106
}
@@ -170,8 +176,8 @@ func (p *PolicyProvider) ValidateAttachment(att *schemaapi.PolicyAttachment, tok
170176
}
171177

172178
// ResolveGroup calls remote provider for retrieving a policy group definition
173-
func (p *PolicyProvider) ResolveGroup(groupName, orgName, token string) (*schemaapi.PolicyGroup, *PolicyReference, error) {
174-
if groupName == "" || token == "" {
179+
func (p *PolicyProvider) ResolveGroup(groupName, groupOrgName string, authOpts ProviderAuthOpts) (*schemaapi.PolicyGroup, *PolicyReference, error) {
180+
if groupName == "" || authOpts.Token == "" {
175181
return nil, nil, fmt.Errorf("both policyname and token are mandatory")
176182
}
177183

@@ -189,7 +195,7 @@ func (p *PolicyProvider) ResolveGroup(groupName, orgName, token string) (*schema
189195
}
190196
// we want to override the orgName with the one in the response
191197
// since we might have resolved it implicitly
192-
providerDigest, orgName, err := p.queryProvider(url, digest, orgName, token, &group)
198+
providerDigest, orgName, err := p.queryProvider(url, digest, groupOrgName, authOpts, &group)
193199
if err != nil {
194200
return nil, nil, fmt.Errorf("failed to resolve group: %w", err)
195201
}
@@ -198,7 +204,7 @@ func (p *PolicyProvider) ResolveGroup(groupName, orgName, token string) (*schema
198204
}
199205

200206
// returns digest, orgname, error
201-
func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName, token string, out proto.Message) (string, string, error) {
207+
func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName string, authOpts ProviderAuthOpts, out proto.Message) (string, string, error) {
202208
query := url.Query()
203209
if digest != "" {
204210
query.Set(digestParam, digest)
@@ -215,7 +221,10 @@ func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName, token stri
215221
return "", "", fmt.Errorf("error creating policy request: %w", err)
216222
}
217223

218-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
224+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authOpts.Token))
225+
if authOpts.OrgName != "" {
226+
req.Header.Set(organizationHeader, authOpts.OrgName)
227+
}
219228

220229
// make the request
221230
resp, err := http.DefaultClient.Do(req)

0 commit comments

Comments
 (0)