Skip to content

Commit 10a34a7

Browse files
committed
Add client test for namespace interceptor.
1 parent 49f1151 commit 10a34a7

File tree

2 files changed

+275
-46
lines changed

2 files changed

+275
-46
lines changed

pkg/client/client.go

Lines changed: 70 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@ import (
44
"context"
55
"crypto/tls"
66
"crypto/x509"
7-
"errors"
87
"fmt"
98
"log/slog"
109
"os"
1110

12-
"github.com/metal-stack/masterdata-api/pkg/auth"
1311
"google.golang.org/grpc"
1412
"google.golang.org/grpc/credentials"
13+
grpcinsecure "google.golang.org/grpc/credentials/insecure"
1514

1615
v1 "github.com/metal-stack/masterdata-api/api/v1"
16+
"github.com/metal-stack/masterdata-api/pkg/auth"
1717
)
1818

1919
// Client defines the client API
@@ -35,7 +35,6 @@ type GRPCClient struct {
3535

3636
// NewClient creates a new client for the services for the given address, with the certificate and hmac.
3737
func NewClient(ctx context.Context, hostname string, port int, certFile string, keyFile string, caFile string, hmacKey string, insecure bool, logger *slog.Logger, namespace string) (Client, error) {
38-
3938
address := fmt.Sprintf("%s:%d", hostname, port)
4039

4140
certPool, err := x509.SystemCertPool()
@@ -55,67 +54,92 @@ func NewClient(ctx context.Context, hostname string, port int, certFile string,
5554
}
5655
}
5756

58-
clientCertificate, err := tls.LoadX509KeyPair(certFile, keyFile)
59-
if err != nil {
60-
return nil, fmt.Errorf("could not load client key pair: %w", err)
57+
var (
58+
certificates []tls.Certificate
59+
opts []grpc.DialOption
60+
)
61+
62+
if certFile != "" && keyFile != "" {
63+
clientCertificate, err := tls.LoadX509KeyPair(certFile, keyFile)
64+
if err != nil {
65+
return nil, fmt.Errorf("could not load client key pair: %w", err)
66+
}
67+
68+
certificates = append(certificates, clientCertificate)
69+
70+
creds := credentials.NewTLS(&tls.Config{
71+
ServerName: hostname,
72+
Certificates: certificates,
73+
RootCAs: certPool,
74+
MinVersion: tls.VersionTLS12,
75+
InsecureSkipVerify: insecure, // nolint:gosec
76+
})
77+
78+
opts = append(opts,
79+
// oauth.NewOauthAccess requires the configuration of transport
80+
// credentials.
81+
grpc.WithTransportCredentials(creds),
82+
)
83+
} else {
84+
opts = append(opts,
85+
grpc.WithTransportCredentials(grpcinsecure.NewCredentials()),
86+
)
6187
}
6288

63-
creds := credentials.NewTLS(&tls.Config{
64-
ServerName: hostname,
65-
Certificates: []tls.Certificate{clientCertificate},
66-
RootCAs: certPool,
67-
MinVersion: tls.VersionTLS12,
68-
InsecureSkipVerify: insecure, // nolint:gosec
69-
})
89+
if hmacKey != "" {
90+
// Set up the credentials for the connection.
91+
perRPCHMACAuthenticator, err := auth.NewHMACAuther(hmacKey, auth.EditUser)
92+
if err != nil {
93+
return nil, fmt.Errorf("failed to create hmac-authenticator: %w", err)
94+
}
7095

71-
if hmacKey == "" {
72-
return nil, errors.New("no hmac-key specified")
96+
opts = append(opts,
97+
// In addition to the following grpc.DialOption, callers may also use
98+
// the grpc.CallOption grpc.PerRPCCredentials with the RPC invocation
99+
// itself.
100+
// See: https://godoc.org/google.golang.org/grpc#PerRPCCredentials
101+
grpc.WithPerRPCCredentials(perRPCHMACAuthenticator))
73102
}
74103

75104
client := GRPCClient{
76105
log: logger,
77106
hmacKey: hmacKey,
78107
}
79108

80-
// Set up the credentials for the connection.
81-
perRPCHMACAuthenticator, err := auth.NewHMACAuther(hmacKey, auth.EditUser)
82-
if err != nil {
83-
return nil, fmt.Errorf("failed to create hmac-authenticator: %w", err)
84-
}
85-
86-
namespaceInterceptor := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
87-
switch r := req.(type) {
88-
case *v1.TenantMemberCreateRequest:
89-
r.TenantMember.Namespace = namespace
90-
case *v1.TenantMemberFindRequest:
91-
r.Namespace = namespace
92-
case *v1.ProjectMemberCreateRequest:
93-
r.ProjectMember.Namespace = namespace
94-
case *v1.ProjectMemberFindRequest:
95-
r.Namespace = namespace
109+
if namespace != "" {
110+
namespaceInterceptor := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
111+
switch r := req.(type) {
112+
case *v1.TenantMemberCreateRequest:
113+
if r.TenantMember.Namespace == "" {
114+
r.TenantMember.Namespace = namespace
115+
}
116+
case *v1.TenantMemberFindRequest:
117+
if r.Namespace == "" {
118+
r.Namespace = namespace
119+
}
120+
case *v1.ProjectMemberCreateRequest:
121+
if r.ProjectMember.Namespace == "" {
122+
r.ProjectMember.Namespace = namespace
123+
}
124+
case *v1.ProjectMemberFindRequest:
125+
if r.Namespace == "" {
126+
r.Namespace = namespace
127+
}
128+
}
129+
return invoker(ctx, method, req, reply, cc, opts...)
96130
}
97-
return invoker(ctx, method, req, reply, cc, opts...)
98-
}
99131

100-
opts := []grpc.DialOption{
101-
// In addition to the following grpc.DialOption, callers may also use
102-
// the grpc.CallOption grpc.PerRPCCredentials with the RPC invocation
103-
// itself.
104-
// See: https://godoc.org/google.golang.org/grpc#PerRPCCredentials
105-
grpc.WithPerRPCCredentials(perRPCHMACAuthenticator),
106-
// oauth.NewOauthAccess requires the configuration of transport
107-
// credentials.
108-
grpc.WithTransportCredentials(creds),
109-
110-
grpc.WithChainUnaryInterceptor(namespaceInterceptor),
111-
112-
// grpc.WithInsecure(),
132+
opts = append(opts,
133+
grpc.WithChainUnaryInterceptor(namespaceInterceptor),
134+
)
113135
}
136+
114137
// Set up a connection to the server.
115138
conn, err := grpc.NewClient(address, opts...)
116139
if err != nil {
117140
return nil, err
118141
}
142+
119143
client.conn = conn
120144

121145
return client, nil

pkg/client/client_test.go

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
"net"
7+
"strconv"
8+
"testing"
9+
10+
v1 "github.com/metal-stack/masterdata-api/api/v1"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
"google.golang.org/grpc"
14+
"google.golang.org/grpc/reflection"
15+
)
16+
17+
func Test_Client(t *testing.T) {
18+
const (
19+
namespace = "a"
20+
)
21+
22+
var (
23+
log = slog.Default()
24+
grpcServer = grpc.NewServer()
25+
projectMemberServer = &projectMemberServer{}
26+
tenantMemberServer = &tenantMemberServer{}
27+
)
28+
29+
v1.RegisterProjectMemberServiceServer(grpcServer, projectMemberServer)
30+
v1.RegisterTenantMemberServiceServer(grpcServer, tenantMemberServer)
31+
32+
reflection.Register(grpcServer)
33+
34+
lis, err := net.Listen("tcp", "")
35+
require.NoError(t, err)
36+
37+
go func() {
38+
err = grpcServer.Serve(lis)
39+
require.NoError(t, err)
40+
}()
41+
defer func() {
42+
grpcServer.Stop()
43+
}()
44+
45+
_, portString, err := net.SplitHostPort(lis.Addr().String())
46+
require.NoError(t, err)
47+
48+
port, err := strconv.Atoi(portString)
49+
require.NoError(t, err)
50+
51+
client, err := NewClient(t.Context(), "localhost", port, "", "", "", "", true, log, namespace)
52+
require.NoError(t, err)
53+
54+
t.Run("check namespace interceptor sets missing namespace", func(t *testing.T) {
55+
t.Run("project member", func(t *testing.T) {
56+
projectMemberServer.create = func(ctx context.Context, pmcr *v1.ProjectMemberCreateRequest) (*v1.ProjectMemberResponse, error) {
57+
assert.Equal(t, "project-a", pmcr.ProjectMember.ProjectId)
58+
assert.Equal(t, "tenant-a", pmcr.ProjectMember.TenantId)
59+
assert.Equal(t, namespace, pmcr.ProjectMember.Namespace)
60+
return &v1.ProjectMemberResponse{}, nil
61+
}
62+
projectMemberServer.find = func(ctx context.Context, pmfr *v1.ProjectMemberFindRequest) (*v1.ProjectMemberListResponse, error) {
63+
assert.Equal(t, namespace, pmfr.Namespace)
64+
return &v1.ProjectMemberListResponse{}, nil
65+
}
66+
67+
_, err = client.ProjectMember().Create(t.Context(), &v1.ProjectMemberCreateRequest{
68+
ProjectMember: &v1.ProjectMember{
69+
ProjectId: "project-a",
70+
TenantId: "tenant-a",
71+
},
72+
})
73+
require.NoError(t, err)
74+
75+
_, err = client.ProjectMember().Find(t.Context(), &v1.ProjectMemberFindRequest{})
76+
require.NoError(t, err)
77+
})
78+
79+
t.Run("tenant member", func(t *testing.T) {
80+
tenantMemberServer.create = func(ctx context.Context, tmcr *v1.TenantMemberCreateRequest) (*v1.TenantMemberResponse, error) {
81+
assert.Equal(t, "tenant-a", tmcr.TenantMember.TenantId)
82+
assert.Equal(t, namespace, tmcr.TenantMember.Namespace)
83+
return &v1.TenantMemberResponse{}, nil
84+
}
85+
tenantMemberServer.find = func(ctx context.Context, tmfr *v1.TenantMemberFindRequest) (*v1.TenantMemberListResponse, error) {
86+
assert.Equal(t, namespace, tmfr.Namespace)
87+
return &v1.TenantMemberListResponse{}, nil
88+
}
89+
90+
_, err = client.TenantMember().Create(t.Context(), &v1.TenantMemberCreateRequest{
91+
TenantMember: &v1.TenantMember{
92+
TenantId: "tenant-a",
93+
},
94+
})
95+
require.NoError(t, err)
96+
97+
_, err = client.TenantMember().Find(t.Context(), &v1.TenantMemberFindRequest{})
98+
require.NoError(t, err)
99+
})
100+
})
101+
102+
t.Run("check explicit namespace can be set anyway", func(t *testing.T) {
103+
t.Run("project member", func(t *testing.T) {
104+
projectMemberServer.create = func(ctx context.Context, pmcr *v1.ProjectMemberCreateRequest) (*v1.ProjectMemberResponse, error) {
105+
assert.Equal(t, "project-a", pmcr.ProjectMember.ProjectId)
106+
assert.Equal(t, "tenant-a", pmcr.ProjectMember.TenantId)
107+
assert.Equal(t, "b", pmcr.ProjectMember.Namespace)
108+
return &v1.ProjectMemberResponse{}, nil
109+
}
110+
projectMemberServer.find = func(ctx context.Context, pmfr *v1.ProjectMemberFindRequest) (*v1.ProjectMemberListResponse, error) {
111+
assert.Equal(t, "b", pmfr.Namespace)
112+
return &v1.ProjectMemberListResponse{}, nil
113+
}
114+
115+
_, err = client.ProjectMember().Create(t.Context(), &v1.ProjectMemberCreateRequest{
116+
ProjectMember: &v1.ProjectMember{
117+
ProjectId: "project-a",
118+
TenantId: "tenant-a",
119+
Namespace: "b",
120+
},
121+
})
122+
require.NoError(t, err)
123+
124+
_, err = client.ProjectMember().Find(t.Context(), &v1.ProjectMemberFindRequest{
125+
Namespace: "b",
126+
})
127+
require.NoError(t, err)
128+
})
129+
130+
t.Run("tenant member", func(t *testing.T) {
131+
tenantMemberServer.create = func(ctx context.Context, tmcr *v1.TenantMemberCreateRequest) (*v1.TenantMemberResponse, error) {
132+
assert.Equal(t, "tenant-a", tmcr.TenantMember.TenantId)
133+
assert.Equal(t, "b", tmcr.TenantMember.Namespace)
134+
return &v1.TenantMemberResponse{}, nil
135+
}
136+
tenantMemberServer.find = func(ctx context.Context, tmfr *v1.TenantMemberFindRequest) (*v1.TenantMemberListResponse, error) {
137+
assert.Equal(t, "b", tmfr.Namespace)
138+
return &v1.TenantMemberListResponse{}, nil
139+
}
140+
141+
_, err = client.TenantMember().Create(t.Context(), &v1.TenantMemberCreateRequest{
142+
TenantMember: &v1.TenantMember{
143+
TenantId: "tenant-a",
144+
Namespace: "b",
145+
},
146+
})
147+
require.NoError(t, err)
148+
149+
_, err = client.TenantMember().Find(t.Context(), &v1.TenantMemberFindRequest{
150+
Namespace: "b",
151+
})
152+
require.NoError(t, err)
153+
})
154+
})
155+
}
156+
157+
type projectMemberServer struct {
158+
create func(context.Context, *v1.ProjectMemberCreateRequest) (*v1.ProjectMemberResponse, error)
159+
find func(context.Context, *v1.ProjectMemberFindRequest) (*v1.ProjectMemberListResponse, error)
160+
}
161+
162+
func (t *projectMemberServer) Create(ctx context.Context, r *v1.ProjectMemberCreateRequest) (*v1.ProjectMemberResponse, error) {
163+
return t.create(ctx, r)
164+
}
165+
166+
func (t *projectMemberServer) Delete(context.Context, *v1.ProjectMemberDeleteRequest) (*v1.ProjectMemberResponse, error) {
167+
panic("unimplemented")
168+
}
169+
170+
func (t *projectMemberServer) Find(ctx context.Context, r *v1.ProjectMemberFindRequest) (*v1.ProjectMemberListResponse, error) {
171+
return t.find(ctx, r)
172+
}
173+
174+
func (t *projectMemberServer) Get(context.Context, *v1.ProjectMemberGetRequest) (*v1.ProjectMemberResponse, error) {
175+
panic("unimplemented")
176+
}
177+
178+
func (t *projectMemberServer) Update(context.Context, *v1.ProjectMemberUpdateRequest) (*v1.ProjectMemberResponse, error) {
179+
panic("unimplemented")
180+
}
181+
182+
type tenantMemberServer struct {
183+
create func(context.Context, *v1.TenantMemberCreateRequest) (*v1.TenantMemberResponse, error)
184+
find func(context.Context, *v1.TenantMemberFindRequest) (*v1.TenantMemberListResponse, error)
185+
}
186+
187+
func (t *tenantMemberServer) Create(ctx context.Context, r *v1.TenantMemberCreateRequest) (*v1.TenantMemberResponse, error) {
188+
return t.create(ctx, r)
189+
}
190+
191+
func (t *tenantMemberServer) Delete(context.Context, *v1.TenantMemberDeleteRequest) (*v1.TenantMemberResponse, error) {
192+
panic("unimplemented")
193+
}
194+
195+
func (t *tenantMemberServer) Find(ctx context.Context, r *v1.TenantMemberFindRequest) (*v1.TenantMemberListResponse, error) {
196+
return t.find(ctx, r)
197+
}
198+
199+
func (t *tenantMemberServer) Get(context.Context, *v1.TenantMemberGetRequest) (*v1.TenantMemberResponse, error) {
200+
panic("unimplemented")
201+
}
202+
203+
func (t *tenantMemberServer) Update(context.Context, *v1.TenantMemberUpdateRequest) (*v1.TenantMemberResponse, error) {
204+
panic("unimplemented")
205+
}

0 commit comments

Comments
 (0)