Skip to content

Commit 8b3cdd4

Browse files
committed
feat: configurable server handlers
1 parent cbabf54 commit 8b3cdd4

File tree

2 files changed

+31
-17
lines changed

2 files changed

+31
-17
lines changed

server.go

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ type Server struct {
3434
IdleTimeout time.Duration // connection timeout when no activity, none if empty
3535
MaxTimeout time.Duration // absolute connection timeout, none if empty
3636

37-
channelHandlers map[string]channelHandler
37+
channelHandlers map[string]ChannelHandler // fallback channel handlers
3838
requestHandlers map[string]RequestHandler
3939

4040
listenerWg sync.WaitGroup
@@ -48,8 +48,7 @@ type RequestHandler interface {
4848
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
4949
}
5050

51-
// internal for now
52-
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
51+
type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
5352

5453
func (srv *Server) ensureHostSigner() error {
5554
if len(srv.HostSigners) == 0 {
@@ -65,13 +64,17 @@ func (srv *Server) ensureHostSigner() error {
6564
func (srv *Server) ensureHandlers() {
6665
srv.mu.Lock()
6766
defer srv.mu.Unlock()
68-
srv.requestHandlers = map[string]RequestHandler{
69-
"tcpip-forward": forwardedTCPHandler{},
70-
"cancel-tcpip-forward": forwardedTCPHandler{},
67+
if srv.requestHandlers == nil {
68+
srv.requestHandlers = map[string]RequestHandler{
69+
"tcpip-forward": forwardedTCPHandler{},
70+
"cancel-tcpip-forward": forwardedTCPHandler{},
71+
}
7172
}
72-
srv.channelHandlers = map[string]channelHandler{
73-
"session": sessionHandler,
74-
"direct-tcpip": directTcpipHandler,
73+
if srv.channelHandlers == nil {
74+
srv.channelHandlers = map[string]ChannelHandler{
75+
"session": sessionHandler,
76+
"direct-tcpip": directTcpipHandler,
77+
}
7578
}
7679
}
7780

@@ -170,12 +173,6 @@ func (srv *Server) Serve(l net.Listener) error {
170173
if srv.Handler == nil {
171174
srv.Handler = DefaultHandler
172175
}
173-
if srv.channelHandlers == nil {
174-
srv.channelHandlers = map[string]channelHandler{
175-
"session": sessionHandler,
176-
"direct-tcpip": directTcpipHandler,
177-
}
178-
}
179176
var tempDelay time.Duration
180177

181178
srv.trackListener(l, true)
@@ -206,6 +203,18 @@ func (srv *Server) Serve(l net.Listener) error {
206203
}
207204
}
208205

206+
func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) {
207+
srv.ensureHandlers()
208+
srv.mu.Lock()
209+
defer srv.mu.Unlock()
210+
srv.channelHandlers[kind] = handler
211+
}
212+
213+
func (srv *Server) GetChannelHandler(kind string) ChannelHandler {
214+
srv.ensureHandlers()
215+
return srv.channelHandlers[kind]
216+
}
217+
209218
func (srv *Server) handleConn(newConn net.Conn) {
210219
if srv.ConnCallback != nil {
211220
cbConn := srv.ConnCallback(newConn)
@@ -240,7 +249,12 @@ func (srv *Server) handleConn(newConn net.Conn) {
240249
go srv.handleRequests(ctx, reqs)
241250
for ch := range chans {
242251
handler, found := srv.channelHandlers[ch.ChannelType()]
243-
if !found {
252+
if !found || handler == nil {
253+
if defaultHandler, found := srv.channelHandlers["default"]; found {
254+
handler = defaultHandler
255+
}
256+
}
257+
if handler == nil {
244258
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
245259
continue
246260
}

session_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error {
1919
if e != nil {
2020
return e
2121
}
22-
srv.channelHandlers = map[string]channelHandler{
22+
srv.channelHandlers = map[string]ChannelHandler{
2323
"session": sessionHandler,
2424
"direct-tcpip": directTcpipHandler,
2525
}

0 commit comments

Comments
 (0)