Skip to content

Commit 77774fe

Browse files
authored
Merge pull request #525 from jkh52/proxy-strategy-tests
proxy-server: strengthen the Backend interface.
2 parents ba7cb3b + 5da5882 commit 77774fe

File tree

6 files changed

+322
-89
lines changed

6 files changed

+322
-89
lines changed

pkg/server/backend_manager.go

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ type Backend interface {
8080
Send(p *client.Packet) error
8181
Recv() (*client.Packet, error)
8282
Context() context.Context
83-
GetAgentIdentifiers() (header.Identifiers, error)
83+
GetAgentID() string
84+
GetAgentIdentifiers() header.Identifiers
8485
}
8586

8687
var _ Backend = &backend{}
@@ -89,6 +90,10 @@ type backend struct {
8990
sendLock sync.Mutex
9091
recvLock sync.Mutex
9192
conn agent.AgentService_ConnectServer
93+
94+
// cached from conn.Context()
95+
id string
96+
idents header.Identifiers
9297
}
9398

9499
func (b *backend) Send(p *client.Packet) error {
@@ -125,25 +130,53 @@ func (b *backend) Context() context.Context {
125130
return b.conn.Context()
126131
}
127132

128-
func (b *backend) GetAgentIdentifiers() (header.Identifiers, error) {
133+
func (b *backend) GetAgentID() string {
134+
return b.id
135+
}
136+
137+
func (b *backend) GetAgentIdentifiers() header.Identifiers {
138+
return b.idents
139+
}
140+
141+
func getAgentID(stream agent.AgentService_ConnectServer) (string, error) {
142+
md, ok := metadata.FromIncomingContext(stream.Context())
143+
if !ok {
144+
return "", fmt.Errorf("failed to get context")
145+
}
146+
agentIDs := md.Get(header.AgentID)
147+
if len(agentIDs) != 1 {
148+
return "", fmt.Errorf("expected one agent ID in the context, got %v", agentIDs)
149+
}
150+
return agentIDs[0], nil
151+
}
152+
153+
func getAgentIdentifiers(conn agent.AgentService_ConnectServer) (header.Identifiers, error) {
129154
var agentIdentifiers header.Identifiers
130-
md, ok := metadata.FromIncomingContext(b.Context())
155+
md, ok := metadata.FromIncomingContext(conn.Context())
131156
if !ok {
132157
return agentIdentifiers, fmt.Errorf("failed to get metadata from context")
133158
}
134-
agentIDs := md.Get(header.AgentIdentifiers)
135-
if len(agentIDs) > 1 {
136-
return agentIdentifiers, fmt.Errorf("expected at most one set of agent IDs in the context, got %v", agentIDs)
159+
agentIdent := md.Get(header.AgentIdentifiers)
160+
if len(agentIdent) > 1 {
161+
return agentIdentifiers, fmt.Errorf("expected at most one set of agent identifiers in the context, got %v", agentIdent)
137162
}
138-
if len(agentIDs) == 0 {
163+
if len(agentIdent) == 0 {
139164
return agentIdentifiers, nil
140165
}
141166

142-
return header.GenAgentIdentifiers(agentIDs[0])
167+
return header.GenAgentIdentifiers(agentIdent[0])
143168
}
144169

145-
func NewBackend(conn agent.AgentService_ConnectServer) Backend {
146-
return &backend{conn: conn}
170+
func NewBackend(conn agent.AgentService_ConnectServer) (Backend, error) {
171+
agentID, err := getAgentID(conn)
172+
if err != nil {
173+
return nil, err
174+
}
175+
agentIdentifiers, err := getAgentIdentifiers(conn)
176+
if err != nil {
177+
return nil, err
178+
}
179+
return &backend{conn: conn, id: agentID, idents: agentIdentifiers}, nil
147180
}
148181

149182
// BackendStorage is an interface to manage the storage of the backend

pkg/server/backend_manager_test.go

Lines changed: 115 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,121 @@ limitations under the License.
1717
package server
1818

1919
import (
20+
"context"
2021
"reflect"
2122
"testing"
2223

23-
"sigs.k8s.io/apiserver-network-proxy/proto/agent"
24+
"github.com/golang/mock/gomock"
25+
"google.golang.org/grpc/metadata"
26+
27+
agentmock "sigs.k8s.io/apiserver-network-proxy/proto/agent/mocks"
2428
"sigs.k8s.io/apiserver-network-proxy/proto/header"
2529
)
2630

27-
type fakeAgentServiceConnectServer struct {
28-
agent.AgentService_ConnectServer
31+
func mockAgentConn(ctrl *gomock.Controller, agentID string, agentIdentifiers []string) *agentmock.MockAgentService_ConnectServer {
32+
agentConn := agentmock.NewMockAgentService_ConnectServer(ctrl)
33+
agentConnMD := metadata.MD{
34+
":authority": []string{"127.0.0.1:8091"},
35+
"agentid": []string{agentID},
36+
"agentidentifiers": agentIdentifiers,
37+
"content-type": []string{"application/grpc"},
38+
"user-agent": []string{"grpc-go/1.42.0"},
39+
}
40+
agentConnCtx := metadata.NewIncomingContext(context.Background(), agentConnMD)
41+
agentConn.EXPECT().Context().Return(agentConnCtx).AnyTimes()
42+
return agentConn
2943
}
3044

31-
func TestAddRemoveBackends(t *testing.T) {
32-
backend1 := NewBackend(new(fakeAgentServiceConnectServer))
33-
backend12 := NewBackend(new(fakeAgentServiceConnectServer))
34-
backend2 := NewBackend(new(fakeAgentServiceConnectServer))
35-
backend22 := NewBackend(new(fakeAgentServiceConnectServer))
36-
backend3 := NewBackend(new(fakeAgentServiceConnectServer))
45+
func TestNewBackend(t *testing.T) {
46+
ctrl := gomock.NewController(t)
47+
defer ctrl.Finish()
48+
49+
testCases := []struct {
50+
desc string
51+
ids []string
52+
idents []string
53+
wantErr bool
54+
}{
55+
{
56+
desc: "no agentID",
57+
wantErr: true,
58+
},
59+
{
60+
desc: "multiple agentID",
61+
ids: []string{"agent-id", "agent-id"},
62+
wantErr: true,
63+
},
64+
{
65+
desc: "multiple identifiers",
66+
ids: []string{"agent-id"},
67+
idents: []string{"host=localhost", "host=localhost"},
68+
wantErr: true,
69+
},
70+
{
71+
desc: "invalid identifiers",
72+
ids: []string{"agent-id"},
73+
idents: []string{";"},
74+
wantErr: true,
75+
},
76+
{
77+
desc: "success",
78+
ids: []string{"agent-id"},
79+
},
80+
{
81+
desc: "success with identifiers",
82+
ids: []string{"agent-id"},
83+
idents: []string{"host=localhost&host=node1.mydomain.com&cidr=127.0.0.1/16&ipv4=1.2.3.4&ipv4=5.6.7.8&ipv6=:::::&default-route=true"},
84+
},
85+
}
86+
87+
for _, tc := range testCases {
88+
t.Run(tc.desc, func(t *testing.T) {
89+
90+
agentConn := agentmock.NewMockAgentService_ConnectServer(ctrl)
91+
agentConnMD := metadata.MD{
92+
":authority": []string{"127.0.0.1:8091"},
93+
"agentid": tc.ids,
94+
"agentidentifiers": tc.idents,
95+
"content-type": []string{"application/grpc"},
96+
"user-agent": []string{"grpc-go/1.42.0"},
97+
}
98+
agentConnCtx := metadata.NewIncomingContext(context.Background(), agentConnMD)
99+
agentConn.EXPECT().Context().Return(agentConnCtx).AnyTimes()
100+
101+
_, err := NewBackend(agentConn)
102+
if gotErr := (err != nil); gotErr != tc.wantErr {
103+
t.Errorf("NewBackend got err %q; wantErr = %t", err, tc.wantErr)
104+
}
105+
})
106+
}
107+
}
108+
109+
func TestAddRemoveBackendsWithDefaultStrategy(t *testing.T) {
110+
ctrl := gomock.NewController(t)
111+
defer ctrl.Finish()
112+
113+
backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{}))
114+
backend12, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{}))
115+
backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{}))
116+
backend22, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{}))
117+
backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{}))
37118

38119
p := NewDefaultBackendManager()
39120

40121
p.AddBackend("agent1", header.UID, backend1)
41122
p.RemoveBackend("agent1", header.UID, backend1)
42123
expectedBackends := make(map[string][]Backend)
43124
expectedAgentIDs := []string{}
125+
expectedDefaultRouteAgentIDs := []string(nil)
44126
if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
45127
t.Errorf("expected %v, got %v", e, a)
46128
}
47129
if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) {
48130
t.Errorf("expected %v, got %v", e, a)
49131
}
132+
if e, a := expectedDefaultRouteAgentIDs, p.defaultRouteAgentIDs; !reflect.DeepEqual(e, a) {
133+
t.Errorf("expected %v, got %v", e, a)
134+
}
50135

51136
p = NewDefaultBackendManager()
52137
p.AddBackend("agent1", header.UID, backend1)
@@ -66,34 +151,42 @@ func TestAddRemoveBackends(t *testing.T) {
66151
"agent3": {backend3},
67152
}
68153
expectedAgentIDs = []string{"agent1", "agent3"}
154+
expectedDefaultRouteAgentIDs = []string(nil)
69155
if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
70156
t.Errorf("expected %v, got %v", e, a)
71157
}
72158
if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) {
73159
t.Errorf("expected %v, got %v", e, a)
74160
}
161+
if e, a := expectedDefaultRouteAgentIDs, p.defaultRouteAgentIDs; !reflect.DeepEqual(e, a) {
162+
t.Errorf("expected %v, got %v", e, a)
163+
}
75164
}
76165

77-
func TestAddRemoveBackendsWithDefaultRoute(t *testing.T) {
78-
backend1 := NewBackend(new(fakeAgentServiceConnectServer))
79-
backend12 := NewBackend(new(fakeAgentServiceConnectServer))
80-
backend2 := NewBackend(new(fakeAgentServiceConnectServer))
81-
backend22 := NewBackend(new(fakeAgentServiceConnectServer))
82-
backend3 := NewBackend(new(fakeAgentServiceConnectServer))
166+
func TestAddRemoveBackendsWithDefaultRouteStrategy(t *testing.T) {
167+
ctrl := gomock.NewController(t)
168+
defer ctrl.Finish()
169+
170+
backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"default-route"}))
171+
backend12, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"default-route"}))
172+
backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"default-route"}))
173+
backend22, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"default-route"}))
174+
backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"default-route"}))
83175

84176
p := NewDefaultRouteBackendManager()
85177

86178
p.AddBackend("agent1", header.DefaultRoute, backend1)
87179
p.RemoveBackend("agent1", header.DefaultRoute, backend1)
88180
expectedBackends := make(map[string][]Backend)
89181
expectedAgentIDs := []string{}
182+
expectedDefaultRouteAgentIDs := []string{}
90183
if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
91184
t.Errorf("expected %v, got %v", e, a)
92185
}
93186
if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) {
94187
t.Errorf("expected %v, got %v", e, a)
95188
}
96-
if e, a := expectedAgentIDs, p.defaultRouteAgentIDs; !reflect.DeepEqual(e, a) {
189+
if e, a := expectedDefaultRouteAgentIDs, p.defaultRouteAgentIDs; !reflect.DeepEqual(e, a) {
97190
t.Errorf("expected %v, got %v", e, a)
98191
}
99192

@@ -108,18 +201,22 @@ func TestAddRemoveBackendsWithDefaultRoute(t *testing.T) {
108201
p.RemoveBackend("agent2", header.DefaultRoute, backend22)
109202
p.RemoveBackend("agent2", header.DefaultRoute, backend2)
110203
p.RemoveBackend("agent1", header.DefaultRoute, backend1)
111-
// This is invalid. agent1 doesn't have conn3. This should be a no-op.
204+
// This is invalid. agent1 doesn't have backend3. This should be a no-op.
112205
p.RemoveBackend("agent1", header.DefaultRoute, backend3)
113206

114207
expectedBackends = map[string][]Backend{
115208
"agent1": {backend12},
116209
"agent3": {backend3},
117210
}
118-
expectedDefaultRouteAgentIDs := []string{"agent1", "agent3"}
211+
expectedAgentIDs = []string{"agent1", "agent3"}
212+
expectedDefaultRouteAgentIDs = []string{"agent1", "agent3"}
119213

120214
if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
121215
t.Errorf("expected %v, got %v", e, a)
122216
}
217+
if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) {
218+
t.Errorf("expected %v, got %v", e, a)
219+
}
123220
if e, a := expectedDefaultRouteAgentIDs, p.defaultRouteAgentIDs; !reflect.DeepEqual(e, a) {
124221
t.Errorf("expected %v, got %v", e, a)
125222
}

0 commit comments

Comments
 (0)