Skip to content

Commit c02760b

Browse files
committed
Add client test for namespace interceptor.
1 parent 49f1151 commit c02760b

File tree

2 files changed

+275
-44
lines changed

2 files changed

+275
-44
lines changed

pkg/client/client.go

Lines changed: 70 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ import (
99
"log/slog"
1010
"os"
1111

12-
"github.com/metal-stack/masterdata-api/pkg/auth"
1312
"google.golang.org/grpc"
1413
"google.golang.org/grpc/credentials"
14+
grpcinsecure "google.golang.org/grpc/credentials/insecure"
1515

1616
v1 "github.com/metal-stack/masterdata-api/api/v1"
17+
"github.com/metal-stack/masterdata-api/pkg/auth"
1718
)
1819

1920
// Client defines the client API
@@ -35,7 +36,6 @@ type GRPCClient struct {
3536

3637
// NewClient creates a new client for the services for the given address, with the certificate and hmac.
3738
func NewClient(ctx context.Context, hostname string, port int, certFile string, keyFile string, caFile string, hmacKey string, insecure bool, logger *slog.Logger, namespace string) (Client, error) {
38-
3939
address := fmt.Sprintf("%s:%d", hostname, port)
4040

4141
certPool, err := x509.SystemCertPool()
@@ -55,20 +55,51 @@ func NewClient(ctx context.Context, hostname string, port int, certFile string,
5555
}
5656
}
5757

58-
clientCertificate, err := tls.LoadX509KeyPair(certFile, keyFile)
59-
if err != nil {
60-
return nil, fmt.Errorf("could not load client key pair: %w", err)
58+
var (
59+
certificates []tls.Certificate
60+
opts []grpc.DialOption
61+
)
62+
63+
if certFile != "" && keyFile != "" {
64+
clientCertificate, err := tls.LoadX509KeyPair(certFile, keyFile)
65+
if err != nil {
66+
return nil, fmt.Errorf("could not load client key pair: %w", err)
67+
}
68+
69+
certificates = append(certificates, clientCertificate)
70+
71+
creds := credentials.NewTLS(&tls.Config{
72+
ServerName: hostname,
73+
Certificates: certificates,
74+
RootCAs: certPool,
75+
MinVersion: tls.VersionTLS12,
76+
InsecureSkipVerify: insecure, // nolint:gosec
77+
})
78+
79+
opts = append(opts,
80+
// oauth.NewOauthAccess requires the configuration of transport
81+
// credentials.
82+
grpc.WithTransportCredentials(creds),
83+
)
84+
} else {
85+
opts = append(opts,
86+
grpc.WithTransportCredentials(grpcinsecure.NewCredentials()),
87+
)
6188
}
6289

63-
creds := credentials.NewTLS(&tls.Config{
64-
ServerName: hostname,
65-
Certificates: []tls.Certificate{clientCertificate},
66-
RootCAs: certPool,
67-
MinVersion: tls.VersionTLS12,
68-
InsecureSkipVerify: insecure, // nolint:gosec
69-
})
90+
if hmacKey != "" {
91+
// Set up the credentials for the connection.
92+
perRPCHMACAuthenticator, err := auth.NewHMACAuther(hmacKey, auth.EditUser)
93+
if err != nil {
94+
return nil, fmt.Errorf("failed to create hmac-authenticator: %w", err)
95+
}
7096

71-
if hmacKey == "" {
97+
opts = append(opts,
98+
// In addition to the following grpc.DialOption, callers may also use
99+
// the grpc.CallOption grpc.PerRPCCredentials with the RPC invocation
100+
// itself.
101+
// See: https://godoc.org/google.golang.org/grpc#PerRPCCredentials
102+
grpc.WithPerRPCCredentials(perRPCHMACAuthenticator))
72103
return nil, errors.New("no hmac-key specified")
73104
}
74105

@@ -77,45 +108,40 @@ func NewClient(ctx context.Context, hostname string, port int, certFile string,
77108
hmacKey: hmacKey,
78109
}
79110

80-
// Set up the credentials for the connection.
81-
perRPCHMACAuthenticator, err := auth.NewHMACAuther(hmacKey, auth.EditUser)
82-
if err != nil {
83-
return nil, fmt.Errorf("failed to create hmac-authenticator: %w", err)
84-
}
85-
86-
namespaceInterceptor := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
87-
switch r := req.(type) {
88-
case *v1.TenantMemberCreateRequest:
89-
r.TenantMember.Namespace = namespace
90-
case *v1.TenantMemberFindRequest:
91-
r.Namespace = namespace
92-
case *v1.ProjectMemberCreateRequest:
93-
r.ProjectMember.Namespace = namespace
94-
case *v1.ProjectMemberFindRequest:
95-
r.Namespace = namespace
111+
if namespace != "" {
112+
namespaceInterceptor := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
113+
switch r := req.(type) {
114+
case *v1.TenantMemberCreateRequest:
115+
if r.TenantMember.Namespace == "" {
116+
r.TenantMember.Namespace = namespace
117+
}
118+
case *v1.TenantMemberFindRequest:
119+
if r.Namespace == "" {
120+
r.Namespace = namespace
121+
}
122+
case *v1.ProjectMemberCreateRequest:
123+
if r.ProjectMember.Namespace == "" {
124+
r.ProjectMember.Namespace = namespace
125+
}
126+
case *v1.ProjectMemberFindRequest:
127+
if r.Namespace == "" {
128+
r.Namespace = namespace
129+
}
130+
}
131+
return invoker(ctx, method, req, reply, cc, opts...)
96132
}
97-
return invoker(ctx, method, req, reply, cc, opts...)
98-
}
99133

100-
opts := []grpc.DialOption{
101-
// In addition to the following grpc.DialOption, callers may also use
102-
// the grpc.CallOption grpc.PerRPCCredentials with the RPC invocation
103-
// itself.
104-
// See: https://godoc.org/google.golang.org/grpc#PerRPCCredentials
105-
grpc.WithPerRPCCredentials(perRPCHMACAuthenticator),
106-
// oauth.NewOauthAccess requires the configuration of transport
107-
// credentials.
108-
grpc.WithTransportCredentials(creds),
109-
110-
grpc.WithChainUnaryInterceptor(namespaceInterceptor),
111-
112-
// grpc.WithInsecure(),
134+
opts = append(opts,
135+
grpc.WithChainUnaryInterceptor(namespaceInterceptor),
136+
)
113137
}
138+
114139
// Set up a connection to the server.
115140
conn, err := grpc.NewClient(address, opts...)
116141
if err != nil {
117142
return nil, err
118143
}
144+
119145
client.conn = conn
120146

121147
return client, nil

pkg/client/client_test.go

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
"net"
7+
"strconv"
8+
"testing"
9+
10+
v1 "github.com/metal-stack/masterdata-api/api/v1"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
"google.golang.org/grpc"
14+
"google.golang.org/grpc/reflection"
15+
)
16+
17+
func Test_Client(t *testing.T) {
18+
const (
19+
namespace = "a"
20+
)
21+
22+
var (
23+
log = slog.Default()
24+
grpcServer = grpc.NewServer()
25+
projectMemberServer = &projectMemberServer{}
26+
tenantMemberServer = &tenantMemberServer{}
27+
)
28+
29+
v1.RegisterProjectMemberServiceServer(grpcServer, projectMemberServer)
30+
v1.RegisterTenantMemberServiceServer(grpcServer, tenantMemberServer)
31+
32+
reflection.Register(grpcServer)
33+
34+
lis, err := net.Listen("tcp", "")
35+
require.NoError(t, err)
36+
37+
go func() {
38+
err = grpcServer.Serve(lis)
39+
require.NoError(t, err)
40+
}()
41+
defer func() {
42+
grpcServer.Stop()
43+
}()
44+
45+
_, portString, err := net.SplitHostPort(lis.Addr().String())
46+
require.NoError(t, err)
47+
48+
port, err := strconv.Atoi(portString)
49+
require.NoError(t, err)
50+
51+
client, err := NewClient(t.Context(), "localhost", port, "", "", "", "", true, log, namespace)
52+
require.NoError(t, err)
53+
54+
t.Run("check namespace interceptor sets missing namespace", func(t *testing.T) {
55+
t.Run("project member", func(t *testing.T) {
56+
projectMemberServer.create = func(ctx context.Context, pmcr *v1.ProjectMemberCreateRequest) (*v1.ProjectMemberResponse, error) {
57+
assert.Equal(t, "project-a", pmcr.ProjectMember.ProjectId)
58+
assert.Equal(t, "tenant-a", pmcr.ProjectMember.TenantId)
59+
assert.Equal(t, namespace, pmcr.ProjectMember.Namespace)
60+
return &v1.ProjectMemberResponse{}, nil
61+
}
62+
projectMemberServer.find = func(ctx context.Context, pmfr *v1.ProjectMemberFindRequest) (*v1.ProjectMemberListResponse, error) {
63+
assert.Equal(t, namespace, pmfr.Namespace)
64+
return &v1.ProjectMemberListResponse{}, nil
65+
}
66+
67+
_, err = client.ProjectMember().Create(t.Context(), &v1.ProjectMemberCreateRequest{
68+
ProjectMember: &v1.ProjectMember{
69+
ProjectId: "project-a",
70+
TenantId: "tenant-a",
71+
},
72+
})
73+
require.NoError(t, err)
74+
75+
_, err = client.ProjectMember().Find(t.Context(), &v1.ProjectMemberFindRequest{})
76+
require.NoError(t, err)
77+
})
78+
79+
t.Run("tenant member", func(t *testing.T) {
80+
tenantMemberServer.create = func(ctx context.Context, tmcr *v1.TenantMemberCreateRequest) (*v1.TenantMemberResponse, error) {
81+
assert.Equal(t, "tenant-a", tmcr.TenantMember.TenantId)
82+
assert.Equal(t, namespace, tmcr.TenantMember.Namespace)
83+
return &v1.TenantMemberResponse{}, nil
84+
}
85+
tenantMemberServer.find = func(ctx context.Context, tmfr *v1.TenantMemberFindRequest) (*v1.TenantMemberListResponse, error) {
86+
assert.Equal(t, namespace, tmfr.Namespace)
87+
return &v1.TenantMemberListResponse{}, nil
88+
}
89+
90+
_, err = client.TenantMember().Create(t.Context(), &v1.TenantMemberCreateRequest{
91+
TenantMember: &v1.TenantMember{
92+
TenantId: "tenant-a",
93+
},
94+
})
95+
require.NoError(t, err)
96+
97+
_, err = client.TenantMember().Find(t.Context(), &v1.TenantMemberFindRequest{})
98+
require.NoError(t, err)
99+
})
100+
})
101+
102+
t.Run("check explicit namespace can be set anyway", func(t *testing.T) {
103+
t.Run("project member", func(t *testing.T) {
104+
projectMemberServer.create = func(ctx context.Context, pmcr *v1.ProjectMemberCreateRequest) (*v1.ProjectMemberResponse, error) {
105+
assert.Equal(t, "project-a", pmcr.ProjectMember.ProjectId)
106+
assert.Equal(t, "tenant-a", pmcr.ProjectMember.TenantId)
107+
assert.Equal(t, "b", pmcr.ProjectMember.Namespace)
108+
return &v1.ProjectMemberResponse{}, nil
109+
}
110+
projectMemberServer.find = func(ctx context.Context, pmfr *v1.ProjectMemberFindRequest) (*v1.ProjectMemberListResponse, error) {
111+
assert.Equal(t, "b", pmfr.Namespace)
112+
return &v1.ProjectMemberListResponse{}, nil
113+
}
114+
115+
_, err = client.ProjectMember().Create(t.Context(), &v1.ProjectMemberCreateRequest{
116+
ProjectMember: &v1.ProjectMember{
117+
ProjectId: "project-a",
118+
TenantId: "tenant-a",
119+
Namespace: "b",
120+
},
121+
})
122+
require.NoError(t, err)
123+
124+
_, err = client.ProjectMember().Find(t.Context(), &v1.ProjectMemberFindRequest{
125+
Namespace: "b",
126+
})
127+
require.NoError(t, err)
128+
})
129+
130+
t.Run("tenant member", func(t *testing.T) {
131+
tenantMemberServer.create = func(ctx context.Context, tmcr *v1.TenantMemberCreateRequest) (*v1.TenantMemberResponse, error) {
132+
assert.Equal(t, "tenant-a", tmcr.TenantMember.TenantId)
133+
assert.Equal(t, "b", tmcr.TenantMember.Namespace)
134+
return &v1.TenantMemberResponse{}, nil
135+
}
136+
tenantMemberServer.find = func(ctx context.Context, tmfr *v1.TenantMemberFindRequest) (*v1.TenantMemberListResponse, error) {
137+
assert.Equal(t, "b", tmfr.Namespace)
138+
return &v1.TenantMemberListResponse{}, nil
139+
}
140+
141+
_, err = client.TenantMember().Create(t.Context(), &v1.TenantMemberCreateRequest{
142+
TenantMember: &v1.TenantMember{
143+
TenantId: "tenant-a",
144+
Namespace: "b",
145+
},
146+
})
147+
require.NoError(t, err)
148+
149+
_, err = client.TenantMember().Find(t.Context(), &v1.TenantMemberFindRequest{
150+
Namespace: "b",
151+
})
152+
require.NoError(t, err)
153+
})
154+
})
155+
}
156+
157+
type projectMemberServer struct {
158+
create func(context.Context, *v1.ProjectMemberCreateRequest) (*v1.ProjectMemberResponse, error)
159+
find func(context.Context, *v1.ProjectMemberFindRequest) (*v1.ProjectMemberListResponse, error)
160+
}
161+
162+
func (t *projectMemberServer) Create(ctx context.Context, r *v1.ProjectMemberCreateRequest) (*v1.ProjectMemberResponse, error) {
163+
return t.create(ctx, r)
164+
}
165+
166+
func (t *projectMemberServer) Delete(context.Context, *v1.ProjectMemberDeleteRequest) (*v1.ProjectMemberResponse, error) {
167+
panic("unimplemented")
168+
}
169+
170+
func (t *projectMemberServer) Find(ctx context.Context, r *v1.ProjectMemberFindRequest) (*v1.ProjectMemberListResponse, error) {
171+
return t.find(ctx, r)
172+
}
173+
174+
func (t *projectMemberServer) Get(context.Context, *v1.ProjectMemberGetRequest) (*v1.ProjectMemberResponse, error) {
175+
panic("unimplemented")
176+
}
177+
178+
func (t *projectMemberServer) Update(context.Context, *v1.ProjectMemberUpdateRequest) (*v1.ProjectMemberResponse, error) {
179+
panic("unimplemented")
180+
}
181+
182+
type tenantMemberServer struct {
183+
create func(context.Context, *v1.TenantMemberCreateRequest) (*v1.TenantMemberResponse, error)
184+
find func(context.Context, *v1.TenantMemberFindRequest) (*v1.TenantMemberListResponse, error)
185+
}
186+
187+
func (t *tenantMemberServer) Create(ctx context.Context, r *v1.TenantMemberCreateRequest) (*v1.TenantMemberResponse, error) {
188+
return t.create(ctx, r)
189+
}
190+
191+
func (t *tenantMemberServer) Delete(context.Context, *v1.TenantMemberDeleteRequest) (*v1.TenantMemberResponse, error) {
192+
panic("unimplemented")
193+
}
194+
195+
func (t *tenantMemberServer) Find(ctx context.Context, r *v1.TenantMemberFindRequest) (*v1.TenantMemberListResponse, error) {
196+
return t.find(ctx, r)
197+
}
198+
199+
func (t *tenantMemberServer) Get(context.Context, *v1.TenantMemberGetRequest) (*v1.TenantMemberResponse, error) {
200+
panic("unimplemented")
201+
}
202+
203+
func (t *tenantMemberServer) Update(context.Context, *v1.TenantMemberUpdateRequest) (*v1.TenantMemberResponse, error) {
204+
panic("unimplemented")
205+
}

0 commit comments

Comments
 (0)