Skip to content

Commit 02ed0e0

Browse files
authored
fix(policies): Fix policy reference (#1449)
Signed-off-by: Jose I. Paris <[email protected]>
1 parent 0748db4 commit 02ed0e0

File tree

8 files changed

+136
-125
lines changed

8 files changed

+136
-125
lines changed

app/controlplane/pkg/policies/policyprovider.go

Lines changed: 33 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,16 @@ func (p *PolicyProvider) Resolve(policyName, orgName, token string) (*schemaapi.
6868
if err != nil {
6969
return nil, nil, fmt.Errorf("failed to resolve policy: %w", err)
7070
}
71-
ref, err := p.queryProvider(endpoint, digest, orgName, token, &policy)
71+
url, err := url.Parse(endpoint)
72+
if err != nil {
73+
return nil, nil, fmt.Errorf("error parsing policy provider URL: %w", err)
74+
}
75+
providerDigest, err := p.queryProvider(url, digest, orgName, token, &policy)
7276
if err != nil {
7377
return nil, nil, fmt.Errorf("failed to resolve policy: %w", err)
7478
}
7579

76-
return &policy, ref, nil
80+
return &policy, createRef(url, policyName, providerDigest, orgName), nil
7781
}
7882

7983
// ResolveGroup calls remote provider for retrieving a policy group definition
@@ -83,29 +87,27 @@ func (p *PolicyProvider) ResolveGroup(groupName, orgName, token string) (*schema
8387
}
8488

8589
// the policy name might include a digest in the form of <name>@sha256:<digest>
86-
policyName, digest := policies.ExtractDigest(groupName)
90+
groupName, digest := policies.ExtractDigest(groupName)
8791

8892
var group schemaapi.PolicyGroup
89-
endpoint, err := url.JoinPath(p.url, groupsEndpoint, policyName)
93+
endpoint, err := url.JoinPath(p.url, groupsEndpoint, groupName)
9094
if err != nil {
9195
return nil, nil, fmt.Errorf("failed to resolve group: %w", err)
9296
}
93-
ref, err := p.queryProvider(endpoint, digest, orgName, token, &group)
97+
url, err := url.Parse(endpoint)
98+
if err != nil {
99+
return nil, nil, fmt.Errorf("error parsing policy provider URL: %w", err)
100+
}
101+
providerDigest, err := p.queryProvider(url, digest, orgName, token, &group)
94102
if err != nil {
95103
return nil, nil, fmt.Errorf("failed to resolve group: %w", err)
96104
}
97105

98-
return &group, ref, nil
106+
return &group, createRef(url, groupName, providerDigest, orgName), nil
99107
}
100108

101-
func (p *PolicyProvider) queryProvider(path, digest, orgName, token string, out proto.Message) (*PolicyReference, error) {
102-
// craft the URL
103-
uri, err := url.Parse(path)
104-
if err != nil {
105-
return nil, fmt.Errorf("error parsing policy provider URL: %w", err)
106-
}
107-
108-
query := uri.Query()
109+
func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName, token string, out proto.Message) (string, error) {
110+
query := url.Query()
109111
if digest != "" {
110112
query.Set(digestParam, digest)
111113
}
@@ -114,75 +116,60 @@ func (p *PolicyProvider) queryProvider(path, digest, orgName, token string, out
114116
query.Set(orgNameParam, orgName)
115117
}
116118

117-
uri.RawQuery = query.Encode()
119+
url.RawQuery = query.Encode()
118120

119-
req, err := http.NewRequest("GET", uri.String(), nil)
121+
req, err := http.NewRequest("GET", url.String(), nil)
120122
if err != nil {
121-
return nil, fmt.Errorf("error creating policy request: %w", err)
123+
return "", fmt.Errorf("error creating policy request: %w", err)
122124
}
123125

124126
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
125127

126128
// make the request
127129
resp, err := http.DefaultClient.Do(req)
128130
if err != nil {
129-
return nil, fmt.Errorf("error executing policy request: %w", err)
131+
return "", fmt.Errorf("error executing policy request: %w", err)
130132
}
131133

132134
if resp.StatusCode != http.StatusOK {
133135
if resp.StatusCode == http.StatusNotFound {
134-
return nil, ErrNotFound
136+
return "", ErrNotFound
135137
}
136138

137-
return nil, fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
139+
return "", fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
138140
}
139141

140142
resBytes, err := io.ReadAll(resp.Body)
141143
if err != nil {
142-
return nil, fmt.Errorf("error reading policy response: %w", err)
144+
return "", fmt.Errorf("error reading policy response: %w", err)
143145
}
144146

145147
// unmarshall response
146148
var response ProviderResponse
147149
if err := json.Unmarshal(resBytes, &response); err != nil {
148-
return nil, fmt.Errorf("error unmarshalling policy response: %w", err)
149-
}
150-
151-
ref, err := p.resolveRef(path, response.Digest)
152-
if err != nil {
153-
return nil, fmt.Errorf("error resolving policy reference: %w", err)
150+
return "", fmt.Errorf("error unmarshalling policy response: %w", err)
154151
}
155152

156153
// extract the policy payload from the query response
157154
jsonPolicy, err := json.Marshal(response.Data)
158155
if err != nil {
159-
return nil, fmt.Errorf("error marshalling policy response: %w", err)
156+
return "", fmt.Errorf("error marshalling policy response: %w", err)
160157
}
161158

162159
if err := protojson.Unmarshal(jsonPolicy, out); err != nil {
163-
return nil, fmt.Errorf("error unmarshalling policy response: %w", err)
160+
return "", fmt.Errorf("error unmarshalling policy response: %w", err)
164161
}
165162

166-
return ref, nil
163+
return response.Digest, nil
167164
}
168165

169-
func (p *PolicyProvider) resolveRef(path, digest string) (*PolicyReference, error) {
170-
// Extract hostname from the policy provider URL
171-
uri, err := url.Parse(p.url)
172-
if err != nil {
173-
return nil, fmt.Errorf("error parsing policy provider URL: %w", err)
174-
}
175-
176-
if uri.Host == "" {
177-
return nil, fmt.Errorf("invalid policy provider URL")
178-
}
179-
180-
if path == "" || digest == "" {
181-
return nil, fmt.Errorf("both path and digest are mandatory")
166+
func createRef(policyURL *url.URL, name, digest, orgName string) *PolicyReference {
167+
refURL := fmt.Sprintf("chainloop://%s/%s", policyURL.Host, name)
168+
if orgName != "" {
169+
refURL = fmt.Sprintf("%s?org=%s", refURL, orgName)
182170
}
183-
184171
return &PolicyReference{
185-
URL: fmt.Sprintf("chainloop://%s/%s", uri.Host, path),
172+
URL: refURL,
186173
Digest: digest,
187-
}, nil
174+
}
188175
}

app/controlplane/pkg/policies/policyprovider_test.go

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,53 +16,45 @@
1616
package policies
1717

1818
import (
19+
"net/url"
1920
"testing"
2021

2122
"github.com/stretchr/testify/assert"
2223
"github.com/stretchr/testify/require"
2324
)
2425

25-
func TestResolveRef(t *testing.T) {
26+
func TestCreateRef(t *testing.T) {
2627
testCases := []struct {
27-
name string
28-
providerURL string
29-
policyName string
30-
digest string
31-
want *PolicyReference
32-
wantErr bool
28+
name string
29+
policyURL string
30+
policyName string
31+
digest string
32+
orgName string
33+
want *PolicyReference
3334
}{
3435
{
35-
name: "valid",
36-
providerURL: "https://p1host.com/foo",
37-
policyName: "my-policy",
38-
digest: "my-digest",
39-
want: &PolicyReference{URL: "chainloop://p1host.com/my-policy", Digest: "my-digest"},
36+
name: "base",
37+
policyURL: "https://p1host.com/foo",
38+
policyName: "my-policy",
39+
digest: "my-digest",
40+
want: &PolicyReference{URL: "chainloop://p1host.com/my-policy", Digest: "my-digest"},
4041
},
4142
{
42-
name: "missing digest",
43-
providerURL: "https://p1host.com/foo",
44-
policyName: "my-policy",
45-
wantErr: true,
46-
},
47-
{
48-
name: "missing schema",
49-
providerURL: "p1host.com/foo",
50-
policyName: "my-policy",
51-
wantErr: true,
43+
name: "with org",
44+
policyURL: "https://p1host.com/foo",
45+
policyName: "my-policy",
46+
digest: "my-digest",
47+
orgName: "my-org",
48+
want: &PolicyReference{URL: "chainloop://p1host.com/my-policy?org=my-org", Digest: "my-digest"},
5249
},
5350
}
5451

5552
for _, tc := range testCases {
56-
t.Run(tc.providerURL, func(t *testing.T) {
57-
provider := &PolicyProvider{url: tc.providerURL}
58-
59-
got, err := provider.resolveRef(tc.policyName, tc.digest)
60-
if tc.wantErr {
61-
require.Error(t, err)
62-
return
63-
}
64-
53+
t.Run(tc.name, func(t *testing.T) {
54+
policyURL, err := url.Parse(tc.policyURL)
6555
require.NoError(t, err)
56+
got := createRef(policyURL, tc.policyName, tc.digest, tc.orgName)
57+
6658
assert.Equal(t, tc.want, got)
6759
})
6860
}

pkg/policies/group_loader.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,18 @@ import (
2626

2727
pb "github.com/chainloop-dev/chainloop/app/controlplane/api/controlplane/v1"
2828
v1 "github.com/chainloop-dev/chainloop/app/controlplane/api/workflowcontract/v1"
29-
v12 "github.com/chainloop-dev/chainloop/pkg/attestation/crafter/api/attestation/v1"
3029
crv1 "github.com/google/go-containerregistry/pkg/v1"
3130
)
3231

3332
// GroupLoader defines the interface for policy loaders from contract attachments
3433
type GroupLoader interface {
35-
Load(context.Context, *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *v12.ResourceDescriptor, error)
34+
Load(context.Context, *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *PolicyDescriptor, error)
3635
}
3736

3837
// FileGroupLoader loader loads policies from filesystem and HTTPS references using Cosign's blob package
3938
type FileGroupLoader struct{}
4039

41-
func (l *FileGroupLoader) Load(_ context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *v12.ResourceDescriptor, error) {
40+
func (l *FileGroupLoader) Load(_ context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *PolicyDescriptor, error) {
4241
var (
4342
raw []byte
4443
err error
@@ -68,7 +67,7 @@ func (l *FileGroupLoader) Load(_ context.Context, attachment *v1.PolicyGroupAtta
6867
// HTTPSGroupLoader loader loads policies from HTTP or HTTPS references
6968
type HTTPSGroupLoader struct{}
7069

71-
func (l *HTTPSGroupLoader) Load(_ context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *v12.ResourceDescriptor, error) {
70+
func (l *HTTPSGroupLoader) Load(_ context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *PolicyDescriptor, error) {
7271
ref, wantDigest := ExtractDigest(attachment.GetRef())
7372

7473
// and do not remove the scheme since we need http(s):// to make the request
@@ -105,7 +104,7 @@ type ChainloopGroupLoader struct {
105104

106105
type groupWithReference struct {
107106
group *v1.PolicyGroup
108-
reference *v12.ResourceDescriptor
107+
reference *PolicyDescriptor
109108
}
110109

111110
var remoteGroupCache = make(map[string]*groupWithReference)
@@ -114,7 +113,7 @@ func NewChainloopGroupLoader(client pb.AttestationServiceClient) *ChainloopGroup
114113
return &ChainloopGroupLoader{Client: client}
115114
}
116115

117-
func (c *ChainloopGroupLoader) Load(ctx context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *v12.ResourceDescriptor, error) {
116+
func (c *ChainloopGroupLoader) Load(ctx context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *PolicyDescriptor, error) {
118117
ref := attachment.GetRef()
119118

120119
c.cacheMutex.Lock()
@@ -144,7 +143,7 @@ func (c *ChainloopGroupLoader) Load(ctx context.Context, attachment *v1.PolicyGr
144143
return nil, nil, fmt.Errorf("parsing digest: %w", err)
145144
}
146145

147-
reference := policyReferenceResourceDescriptor(resp.Reference.GetUrl(), h)
146+
reference := policyReferenceResourceDescriptor(providerRef.Name, resp.Reference.GetUrl(), providerRef.OrgName, h)
148147
// cache result
149148
remoteGroupCache[ref] = &groupWithReference{group: resp.GetGroup(), reference: reference}
150149
return resp.GetGroup(), reference, nil

0 commit comments

Comments
 (0)