@@ -37,7 +37,7 @@ type Server struct {
37
37
IdleTimeout time.Duration // connection timeout when no activity, none if empty
38
38
MaxTimeout time.Duration // absolute connection timeout, none if empty
39
39
40
- channelHandlers map [string ]channelHandler
40
+ channelHandlers map [string ]ChannelHandler // fallback channel handlers
41
41
requestHandlers map [string ]RequestHandler
42
42
43
43
listenerWg sync.WaitGroup
@@ -51,8 +51,7 @@ type RequestHandler interface {
51
51
HandleRequest (ctx Context , srv * Server , req * gossh.Request ) (ok bool , payload []byte )
52
52
}
53
53
54
- // internal for now
55
- type channelHandler func (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context )
54
+ type ChannelHandler func (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context )
56
55
57
56
func (srv * Server ) ensureHostSigner () error {
58
57
if len (srv .HostSigners ) == 0 {
@@ -68,13 +67,17 @@ func (srv *Server) ensureHostSigner() error {
68
67
func (srv * Server ) ensureHandlers () {
69
68
srv .mu .Lock ()
70
69
defer srv .mu .Unlock ()
71
- srv .requestHandlers = map [string ]RequestHandler {
72
- "tcpip-forward" : forwardedTCPHandler {},
73
- "cancel-tcpip-forward" : forwardedTCPHandler {},
70
+ if srv .requestHandlers == nil {
71
+ srv .requestHandlers = map [string ]RequestHandler {
72
+ "tcpip-forward" : forwardedTCPHandler {},
73
+ "cancel-tcpip-forward" : forwardedTCPHandler {},
74
+ }
74
75
}
75
- srv .channelHandlers = map [string ]channelHandler {
76
- "session" : sessionHandler ,
77
- "direct-tcpip" : directTcpipHandler ,
76
+ if srv .channelHandlers == nil {
77
+ srv .channelHandlers = map [string ]ChannelHandler {
78
+ "session" : sessionHandler ,
79
+ "direct-tcpip" : directTcpipHandler ,
80
+ }
78
81
}
79
82
}
80
83
@@ -186,12 +189,6 @@ func (srv *Server) Serve(l net.Listener) error {
186
189
if srv .Handler == nil {
187
190
srv .Handler = DefaultHandler
188
191
}
189
- if srv .channelHandlers == nil {
190
- srv .channelHandlers = map [string ]channelHandler {
191
- "session" : sessionHandler ,
192
- "direct-tcpip" : directTcpipHandler ,
193
- }
194
- }
195
192
var tempDelay time.Duration
196
193
197
194
srv .trackListener (l , true )
@@ -222,6 +219,18 @@ func (srv *Server) Serve(l net.Listener) error {
222
219
}
223
220
}
224
221
222
+ func (srv * Server ) SetChannelHandler (kind string , handler ChannelHandler ) {
223
+ srv .ensureHandlers ()
224
+ srv .mu .Lock ()
225
+ defer srv .mu .Unlock ()
226
+ srv .channelHandlers [kind ] = handler
227
+ }
228
+
229
+ func (srv * Server ) ChannelHandler (kind string ) ChannelHandler {
230
+ srv .ensureHandlers ()
231
+ return srv .channelHandlers [kind ]
232
+ }
233
+
225
234
func (srv * Server ) handleConn (newConn net.Conn ) {
226
235
if srv .ConnCallback != nil {
227
236
cbConn := srv .ConnCallback (newConn )
@@ -256,7 +265,12 @@ func (srv *Server) handleConn(newConn net.Conn) {
256
265
go srv .handleRequests (ctx , reqs )
257
266
for ch := range chans {
258
267
handler , found := srv .channelHandlers [ch .ChannelType ()]
259
- if ! found {
268
+ if ! found || handler == nil {
269
+ if defaultHandler , found := srv .channelHandlers ["default" ]; found {
270
+ handler = defaultHandler
271
+ }
272
+ }
273
+ if handler == nil {
260
274
ch .Reject (gossh .UnknownChannelType , "unsupported channel type" )
261
275
continue
262
276
}
0 commit comments