Skip to content

Commit 75b6954

Browse files
committed
Merge remote-tracking branch 'moul/dev/moul/configurable-handlers' into configurable-handlers
2 parents a9daacc + 570aa23 commit 75b6954

File tree

2 files changed

+31
-17
lines changed

2 files changed

+31
-17
lines changed

server.go

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ type Server struct {
3737
IdleTimeout time.Duration // connection timeout when no activity, none if empty
3838
MaxTimeout time.Duration // absolute connection timeout, none if empty
3939

40-
channelHandlers map[string]channelHandler
40+
channelHandlers map[string]ChannelHandler // fallback channel handlers
4141
requestHandlers map[string]RequestHandler
4242

4343
listenerWg sync.WaitGroup
@@ -51,8 +51,7 @@ type RequestHandler interface {
5151
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
5252
}
5353

54-
// internal for now
55-
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
54+
type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
5655

5756
func (srv *Server) ensureHostSigner() error {
5857
if len(srv.HostSigners) == 0 {
@@ -68,13 +67,17 @@ func (srv *Server) ensureHostSigner() error {
6867
func (srv *Server) ensureHandlers() {
6968
srv.mu.Lock()
7069
defer srv.mu.Unlock()
71-
srv.requestHandlers = map[string]RequestHandler{
72-
"tcpip-forward": forwardedTCPHandler{},
73-
"cancel-tcpip-forward": forwardedTCPHandler{},
70+
if srv.requestHandlers == nil {
71+
srv.requestHandlers = map[string]RequestHandler{
72+
"tcpip-forward": forwardedTCPHandler{},
73+
"cancel-tcpip-forward": forwardedTCPHandler{},
74+
}
7475
}
75-
srv.channelHandlers = map[string]channelHandler{
76-
"session": sessionHandler,
77-
"direct-tcpip": directTcpipHandler,
76+
if srv.channelHandlers == nil {
77+
srv.channelHandlers = map[string]ChannelHandler{
78+
"session": sessionHandler,
79+
"direct-tcpip": directTcpipHandler,
80+
}
7881
}
7982
}
8083

@@ -186,12 +189,6 @@ func (srv *Server) Serve(l net.Listener) error {
186189
if srv.Handler == nil {
187190
srv.Handler = DefaultHandler
188191
}
189-
if srv.channelHandlers == nil {
190-
srv.channelHandlers = map[string]channelHandler{
191-
"session": sessionHandler,
192-
"direct-tcpip": directTcpipHandler,
193-
}
194-
}
195192
var tempDelay time.Duration
196193

197194
srv.trackListener(l, true)
@@ -222,6 +219,18 @@ func (srv *Server) Serve(l net.Listener) error {
222219
}
223220
}
224221

222+
func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) {
223+
srv.ensureHandlers()
224+
srv.mu.Lock()
225+
defer srv.mu.Unlock()
226+
srv.channelHandlers[kind] = handler
227+
}
228+
229+
func (srv *Server) ChannelHandler(kind string) ChannelHandler {
230+
srv.ensureHandlers()
231+
return srv.channelHandlers[kind]
232+
}
233+
225234
func (srv *Server) handleConn(newConn net.Conn) {
226235
if srv.ConnCallback != nil {
227236
cbConn := srv.ConnCallback(newConn)
@@ -256,7 +265,12 @@ func (srv *Server) handleConn(newConn net.Conn) {
256265
go srv.handleRequests(ctx, reqs)
257266
for ch := range chans {
258267
handler, found := srv.channelHandlers[ch.ChannelType()]
259-
if !found {
268+
if !found || handler == nil {
269+
if defaultHandler, found := srv.channelHandlers["default"]; found {
270+
handler = defaultHandler
271+
}
272+
}
273+
if handler == nil {
260274
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
261275
continue
262276
}

session_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error {
1919
if e != nil {
2020
return e
2121
}
22-
srv.channelHandlers = map[string]channelHandler{
22+
srv.channelHandlers = map[string]ChannelHandler{
2323
"session": sessionHandler,
2424
"direct-tcpip": directTcpipHandler,
2525
}

0 commit comments

Comments
 (0)