Skip to content

Commit 625cefa

Browse files
authored
feat: CAS readiness probe support (#55)
Signed-off-by: Miguel Martinez Trivino <[email protected]>
1 parent 08cd26f commit 625cefa

File tree

16 files changed

+308
-51
lines changed

16 files changed

+308
-51
lines changed

app/artifact-cas/internal/server/grpc.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"context"
2020
"fmt"
2121
"os"
22+
"regexp"
2223

2324
v1 "github.com/chainloop-dev/chainloop/app/artifact-cas/api/cas/v1"
2425
"github.com/chainloop-dev/chainloop/app/artifact-cas/internal/conf"
@@ -27,6 +28,7 @@ import (
2728
"github.com/getsentry/sentry-go"
2829
"github.com/go-kratos/kratos/v2/errors"
2930
jwtMiddleware "github.com/go-kratos/kratos/v2/middleware/auth/jwt"
31+
"github.com/go-kratos/kratos/v2/middleware/selector"
3032
jwt "github.com/golang-jwt/jwt/v4"
3133
"google.golang.org/genproto/googleapis/bytestream"
3234

@@ -62,11 +64,14 @@ func NewGRPCServer(c *conf.Server, authConf *conf.Auth, byteService *service.Byt
6264
),
6365
logging.Server(logger),
6466
// NOTE: JWT middleware only works for unary requests
65-
// below you can see a reimplementation of the middleware as a stream interceptor
66-
jwtMiddleware.Server(
67-
loadPublicKey(rawKey),
68-
jwtMiddleware.WithSigningMethod(casJWT.SigningMethod),
69-
jwtMiddleware.WithClaims(func() jwt.Claims { return &casJWT.Claims{} })),
67+
// below you can see a re-implementation of the middleware as a stream interceptor
68+
// If we require a logged in user we
69+
selector.Server(
70+
jwtMiddleware.Server(
71+
loadPublicKey(rawKey),
72+
jwtMiddleware.WithSigningMethod(casJWT.SigningMethod),
73+
jwtMiddleware.WithClaims(func() jwt.Claims { return &casJWT.Claims{} })),
74+
).Match(requireAuthentication()).Build(),
7075
validate.Validator(),
7176
),
7277

@@ -92,13 +97,24 @@ func NewGRPCServer(c *conf.Server, authConf *conf.Auth, byteService *service.Byt
9297

9398
bytestream.RegisterByteStreamServer(srv.Server, byteService)
9499
v1.RegisterResourceServiceServer(srv.Server, rSvc)
100+
v1.RegisterStatusServiceServer(srv.Server, service.NewStatusService(Version))
95101

96102
// Register and set metrics to 0
97103
grpc_prometheus.Register(srv.Server)
98104

99105
return srv, nil
100106
}
101107

108+
func requireAuthentication() selector.MatchFunc {
109+
// Skip authentication on the status grpc service
110+
const skipRegexp = "(cas.v1.StatusService/.*)"
111+
112+
return func(ctx context.Context, operation string) bool {
113+
r := regexp.MustCompile(skipRegexp)
114+
return !r.MatchString(operation)
115+
}
116+
}
117+
102118
// load key for verification
103119
func loadPublicKey(rawKey []byte) jwt.Keyfunc {
104120
return func(token *jwt.Token) (interface{}, error) {

app/artifact-cas/internal/server/grpc_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,23 @@ func TestJWTAuthFunc(t *testing.T) {
138138
}
139139
}
140140

141+
func TestRequireAuthentication(t *testing.T) {
142+
testCases := []struct {
143+
operation string
144+
matches bool
145+
}{
146+
{"/cas.v1.Resource/List", true},
147+
{"/cas.v1.Bytestream/List", true},
148+
{"/cas.v1.StatusService/Infoz", false},
149+
{"/cas.v1.StatusService/Statusz", false},
150+
}
151+
152+
matchFunc := requireAuthentication()
153+
for _, op := range testCases {
154+
assert.Equal(t, matchFunc(context.Background(), op.operation), op.matches)
155+
}
156+
}
157+
141158
func loadTestPublicKey(path string) jwt.Keyfunc {
142159
rawKey, _ := os.ReadFile(path)
143160
return func(token *jwt.Token) (interface{}, error) {

app/controlplane/cmd/wire.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ func wireApp(*conf.Bootstrap, credentials.ReaderWriter, log.Logger) (*app, func(
4747
wire.Bind(new(biz.CASClient), new(*biz.CASClientUseCase)),
4848
oci.NewBackendProvider,
4949
serviceOpts,
50+
wire.Value([]biz.CASClientOpts{}),
5051
wire.FieldsOf(new(*conf.Bootstrap), "Server", "Auth", "Data", "CasServer"),
5152
newApp,
5253
),

app/controlplane/cmd/wire_gen.go

Lines changed: 26 additions & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

app/controlplane/internal/biz/casclient.go

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package biz
1717

1818
import (
1919
"context"
20+
"errors"
2021
"fmt"
2122
"io"
2223

@@ -29,9 +30,13 @@ import (
2930
)
3031

3132
type CASClientUseCase struct {
33+
// to generate temporary credentials
3234
credsProvider *CASCredentialsUseCase
35+
// configuration to generate the client
3336
casServerConf *conf.Bootstrap_CASServer
34-
logger *log.Helper
37+
// factory to generate the client
38+
casClientFactory CASClientFactory
39+
logger *log.Helper
3540
}
3641

3742
type CASUploader interface {
@@ -45,11 +50,40 @@ type CASDownloader interface {
4550
type CASClient interface {
4651
CASUploader
4752
CASDownloader
48-
Configured() bool
4953
}
5054

51-
func NewCASClientUseCase(credsProvider *CASCredentialsUseCase, config *conf.Bootstrap_CASServer, l log.Logger) *CASClientUseCase {
52-
return &CASClientUseCase{credsProvider, config, servicelogger.ScopedHelper(l, "biz/cas-client")}
55+
type CASClientFactory func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error)
56+
type CASClientOpts func(u *CASClientUseCase)
57+
58+
func WithClientFactory(f CASClientFactory) CASClientOpts {
59+
return func(c *CASClientUseCase) {
60+
c.casClientFactory = f
61+
}
62+
}
63+
64+
func NewCASClientUseCase(credsProvider *CASCredentialsUseCase, config *conf.Bootstrap_CASServer, l log.Logger, opts ...CASClientOpts) *CASClientUseCase {
65+
// generate a client from the given configuration
66+
defaultCasClientFactory := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error) {
67+
conn, err := grpcconn.New(conf.GetGrpc().GetAddr(), token, conf.GetInsecure())
68+
if err != nil {
69+
return nil, fmt.Errorf("failed to create grpc connection: %w", err)
70+
}
71+
72+
return casclient.New(conn), nil
73+
}
74+
75+
uc := &CASClientUseCase{
76+
credsProvider: credsProvider,
77+
casServerConf: config,
78+
logger: servicelogger.ScopedHelper(l, "biz/cas-client"),
79+
casClientFactory: defaultCasClientFactory,
80+
}
81+
82+
for _, opt := range opts {
83+
opt(uc)
84+
}
85+
86+
return uc
5387
}
5488

5589
// The secretID is embedded in the JWT token and is used to identify the secret by the CAS server
@@ -90,35 +124,31 @@ func (uc *CASClientUseCase) Download(ctx context.Context, secretID string, w io.
90124
}
91125

92126
// create a client with a temporary set of credentials for a specific operation
93-
func (uc *CASClientUseCase) casAPIClient(secretID string, role casJWT.Role) (*casclient.Client, error) {
127+
func (uc *CASClientUseCase) casAPIClient(secretID string, role casJWT.Role) (casclient.DownloaderUploader, error) {
94128
token, err := uc.credsProvider.GenerateTemporaryCredentials(secretID, role)
95129
if err != nil {
96130
return nil, fmt.Errorf("failed to generate temporary credentials: %w", err)
97131
}
98132

99133
// Initialize connection to CAS server
100-
return casClient(uc.casServerConf, token)
134+
return uc.casClientFactory(uc.casServerConf, token)
101135
}
102136

103-
func casClient(conf *conf.Bootstrap_CASServer, token string) (*casclient.Client, error) {
104-
conn, err := grpcconn.New(conf.GetGrpc().GetAddr(), token, conf.GetInsecure())
105-
if err != nil {
106-
return nil, fmt.Errorf("failed to create grpc connection: %w", err)
107-
}
108-
109-
return casclient.New(conn), nil
110-
}
111-
112-
// If the CAS client configuration is present and valid
113-
func (uc *CASClientUseCase) Configured() bool {
137+
// If the CAS server can be reached and reports readiness
138+
func (uc *CASClientUseCase) IsReady(ctx context.Context) (bool, error) {
114139
if uc.casServerConf == nil {
115-
return false
140+
return false, errors.New("missing CAS server configuration")
116141
}
117142

118143
err := uc.casServerConf.ValidateAll()
119144
if err != nil {
120-
uc.logger.Infow("msg", "Invalid CAS client configuration", "err", err.Error())
145+
return false, fmt.Errorf("invalid CAS client configuration: %w", err)
146+
}
147+
148+
c, err := uc.casClientFactory(uc.casServerConf, "")
149+
if err != nil {
150+
return false, fmt.Errorf("failed to create CAS client: %w", err)
121151
}
122152

123-
return err == nil
153+
return c.IsReady(ctx)
124154
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
//
2+
// Copyright 2023 The Chainloop Authors.
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
package biz_test
17+
18+
import (
19+
"context"
20+
"testing"
21+
22+
"github.com/chainloop-dev/chainloop/app/controlplane/internal/biz"
23+
"github.com/chainloop-dev/chainloop/app/controlplane/internal/conf"
24+
"github.com/chainloop-dev/chainloop/internal/casclient"
25+
"github.com/chainloop-dev/chainloop/internal/casclient/mocks"
26+
"github.com/stretchr/testify/assert"
27+
"github.com/stretchr/testify/mock"
28+
)
29+
30+
func TestIsReady(t *testing.T) {
31+
validConf := &conf.Bootstrap_CASServer{
32+
Grpc: &conf.Server_GRPC{Addr: "localhost:1111"},
33+
}
34+
35+
testCases := []struct {
36+
name string
37+
config *conf.Bootstrap_CASServer
38+
casReady bool
39+
want bool
40+
wantErr bool
41+
}{
42+
{
43+
name: "missing configuration",
44+
config: &conf.Bootstrap_CASServer{},
45+
wantErr: true,
46+
},
47+
{
48+
name: "invalid configuration",
49+
config: &conf.Bootstrap_CASServer{Grpc: &conf.Server_GRPC{}},
50+
wantErr: true,
51+
},
52+
{
53+
name: "not ready configuration",
54+
config: validConf,
55+
wantErr: false,
56+
},
57+
{
58+
name: "ready configuration",
59+
config: validConf,
60+
casReady: true,
61+
want: true,
62+
wantErr: false,
63+
},
64+
}
65+
66+
for _, tc := range testCases {
67+
t.Run(tc.name, func(t *testing.T) {
68+
clientProvider := func(conf *conf.Bootstrap_CASServer, token string) (casclient.DownloaderUploader, error) {
69+
c := mocks.NewDownloaderUploader(t)
70+
c.On("IsReady", mock.Anything).Return(tc.casReady, nil)
71+
return c, nil
72+
}
73+
uc := biz.NewCASClientUseCase(nil, tc.config, nil, biz.WithClientFactory(clientProvider))
74+
75+
got, err := uc.IsReady(context.Background())
76+
if tc.wantErr {
77+
assert.Error(t, err)
78+
} else {
79+
assert.NoError(t, err)
80+
}
81+
assert.Equal(t, tc.want, got)
82+
})
83+
}
84+
}

0 commit comments

Comments
 (0)