Skip to content

Commit b73c588

Browse files
chungthuangsssilver
authored andcommitted
TUN-5422: Define RPC to unregister session
1 parent 7e47667 commit b73c588

File tree

6 files changed

+417
-227
lines changed

6 files changed

+417
-227
lines changed

connection/quic.go

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ type QUICConnection struct {
3838
logger *zerolog.Logger
3939
httpProxy OriginProxy
4040
sessionManager datagramsession.Manager
41-
localIP net.IP
4241
}
4342

4443
// NewQUICConnection returns a new instance of QUICConnection.
@@ -75,17 +74,11 @@ func NewQUICConnection(
7574

7675
sessionManager := datagramsession.NewManager(datagramMuxer, observer.log)
7776

78-
localIP, err := getLocalIP()
79-
if err != nil {
80-
return nil, err
81-
}
82-
8377
return &QUICConnection{
8478
session: session,
8579
httpProxy: httpProxy,
8680
logger: observer.log,
8781
sessionManager: sessionManager,
88-
localIP: localIP,
8982
}, nil
9083
}
9184

@@ -197,7 +190,10 @@ func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.
197190
return nil
198191
}
199192

200-
// TODO: TUN-5422 Implement UnregisterUdpSession RPC
193+
func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID) error {
194+
q.sessionManager.UnregisterSession(ctx, sessionID)
195+
return nil
196+
}
201197

202198
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
203199
// the client.
@@ -292,26 +288,3 @@ func isTransferEncodingChunked(req *http.Request) bool {
292288
// separated value as well.
293289
return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
294290
}
295-
296-
// TODO: TUN-5303: Find the local IP once in ingress package
297-
// TODO: TUN-5421 allow user to specify which IP to bind to
298-
func getLocalIP() (net.IP, error) {
299-
addrs, err := net.InterfaceAddrs()
300-
if err != nil {
301-
return nil, err
302-
}
303-
for _, addr := range addrs {
304-
// Find the IP that is not loop back
305-
var ip net.IP
306-
switch v := addr.(type) {
307-
case *net.IPNet:
308-
ip = v.IP
309-
case *net.IPAddr:
310-
ip = v.IP
311-
}
312-
if !ip.IsLoopback() {
313-
return ip, nil
314-
}
315-
}
316-
return nil, fmt.Errorf("cannot determine IP to bind to")
317-
}

quic/quic_protocol.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uu
247247
return resp.Err
248248
}
249249

250+
func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID) error {
251+
return rcs.client.UnregisterUdpSession(ctx, sessionID)
252+
}
253+
250254
func (rcs *RPCClientStream) Close() {
251255
_ = rcs.client.Close()
252256
_ = rcs.transport.Close()

quic/quic_protocol_test.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,15 @@ func TestRegisterUdpSession(t *testing.T) {
131131
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
132132
assert.NoError(t, err)
133133

134-
err = rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort)
135-
assert.NoError(t, err)
134+
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort))
136135

137136
// Different sessionID, the RPC server should reject the registraion
138-
err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort)
139-
assert.Error(t, err)
137+
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort))
138+
139+
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), rpcServer.sessionID))
140+
141+
// Different sessionID, the RPC server should reject the unregistraion
142+
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New()))
140143

141144
rpcClientStream.Close()
142145
<-sessionRegisteredChan
@@ -161,6 +164,13 @@ func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UU
161164
return nil
162165
}
163166

167+
func (s mockRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID) error {
168+
if s.sessionID != sessionID {
169+
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
170+
}
171+
return nil
172+
}
173+
164174
type mockRPCStream struct {
165175
io.ReadCloser
166176
io.WriteCloser

tunnelrpc/pogs/sessionrpc.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
type SessionManager interface {
1616
RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error
17+
UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID) error
1718
}
1819

1920
type SessionManager_PogsImpl struct {
@@ -60,6 +61,21 @@ func (i SessionManager_PogsImpl) RegisterUdpSession(p tunnelrpc.SessionManager_r
6061
return resp.Marshal(result)
6162
}
6263

64+
func (i SessionManager_PogsImpl) UnregisterUdpSession(p tunnelrpc.SessionManager_unregisterUdpSession) error {
65+
server.Ack(p.Options)
66+
67+
sessionIDRaw, err := p.Params.SessionId()
68+
if err != nil {
69+
return err
70+
}
71+
sessionID, err := uuid.FromBytes(sessionIDRaw)
72+
if err != nil {
73+
return err
74+
}
75+
76+
return i.impl.UnregisterUdpSession(p.Ctx, sessionID)
77+
}
78+
6379
type RegisterUdpSessionResponse struct {
6480
Err error
6581
}
@@ -116,3 +132,15 @@ func (c SessionManager_PogsClient) RegisterUdpSession(ctx context.Context, sessi
116132
}
117133
return response, nil
118134
}
135+
136+
func (c SessionManager_PogsClient) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID) error {
137+
client := tunnelrpc.SessionManager{Client: c.Client}
138+
promise := client.UnregisterUdpSession(ctx, func(p tunnelrpc.SessionManager_unregisterUdpSession_Params) error {
139+
if err := p.SetSessionId(sessionID[:]); err != nil {
140+
return err
141+
}
142+
return nil
143+
})
144+
_, err := promise.Struct()
145+
return err
146+
}

tunnelrpc/tunnelrpc.capnp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,5 @@ struct RegisterUdpSessionResponse {
149149

150150
interface SessionManager {
151151
registerUdpSession @0 (sessionId :Data, dstIp :Data, dstPort: UInt16) -> (result :RegisterUdpSessionResponse);
152+
unregisterUdpSession @1 (sessionId :Data) -> ();
152153
}

0 commit comments

Comments
 (0)