@@ -34,7 +34,7 @@ type Server struct {
34
34
IdleTimeout time.Duration // connection timeout when no activity, none if empty
35
35
MaxTimeout time.Duration // absolute connection timeout, none if empty
36
36
37
- channelHandlers map [string ]channelHandler
37
+ channelHandlers map [string ]ChannelHandler // fallback channel handlers
38
38
requestHandlers map [string ]RequestHandler
39
39
40
40
listenerWg sync.WaitGroup
@@ -48,8 +48,7 @@ type RequestHandler interface {
48
48
HandleRequest (ctx Context , srv * Server , req * gossh.Request ) (ok bool , payload []byte )
49
49
}
50
50
51
- // internal for now
52
- type channelHandler func (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context )
51
+ type ChannelHandler func (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context )
53
52
54
53
func (srv * Server ) ensureHostSigner () error {
55
54
if len (srv .HostSigners ) == 0 {
@@ -65,13 +64,17 @@ func (srv *Server) ensureHostSigner() error {
65
64
func (srv * Server ) ensureHandlers () {
66
65
srv .mu .Lock ()
67
66
defer srv .mu .Unlock ()
68
- srv .requestHandlers = map [string ]RequestHandler {
69
- "tcpip-forward" : forwardedTCPHandler {},
70
- "cancel-tcpip-forward" : forwardedTCPHandler {},
67
+ if srv .requestHandlers == nil {
68
+ srv .requestHandlers = map [string ]RequestHandler {
69
+ "tcpip-forward" : forwardedTCPHandler {},
70
+ "cancel-tcpip-forward" : forwardedTCPHandler {},
71
+ }
71
72
}
72
- srv .channelHandlers = map [string ]channelHandler {
73
- "session" : sessionHandler ,
74
- "direct-tcpip" : directTcpipHandler ,
73
+ if srv .channelHandlers == nil {
74
+ srv .channelHandlers = map [string ]ChannelHandler {
75
+ "session" : sessionHandler ,
76
+ "direct-tcpip" : directTcpipHandler ,
77
+ }
75
78
}
76
79
}
77
80
@@ -170,12 +173,6 @@ func (srv *Server) Serve(l net.Listener) error {
170
173
if srv .Handler == nil {
171
174
srv .Handler = DefaultHandler
172
175
}
173
- if srv .channelHandlers == nil {
174
- srv .channelHandlers = map [string ]channelHandler {
175
- "session" : sessionHandler ,
176
- "direct-tcpip" : directTcpipHandler ,
177
- }
178
- }
179
176
var tempDelay time.Duration
180
177
181
178
srv .trackListener (l , true )
@@ -206,6 +203,18 @@ func (srv *Server) Serve(l net.Listener) error {
206
203
}
207
204
}
208
205
206
+ func (srv * Server ) SetChannelHandler (kind string , handler ChannelHandler ) {
207
+ srv .ensureHandlers ()
208
+ srv .mu .Lock ()
209
+ defer srv .mu .Unlock ()
210
+ srv .channelHandlers [kind ] = handler
211
+ }
212
+
213
+ func (srv * Server ) GetChannelHandler (kind string ) ChannelHandler {
214
+ srv .ensureHandlers ()
215
+ return srv .channelHandlers [kind ]
216
+ }
217
+
209
218
func (srv * Server ) handleConn (newConn net.Conn ) {
210
219
if srv .ConnCallback != nil {
211
220
cbConn := srv .ConnCallback (newConn )
@@ -240,7 +249,12 @@ func (srv *Server) handleConn(newConn net.Conn) {
240
249
go srv .handleRequests (ctx , reqs )
241
250
for ch := range chans {
242
251
handler , found := srv .channelHandlers [ch .ChannelType ()]
243
- if ! found {
252
+ if ! found || handler == nil {
253
+ if defaultHandler , found := srv .channelHandlers ["default" ]; found {
254
+ handler = defaultHandler
255
+ }
256
+ }
257
+ if handler == nil {
244
258
ch .Reject (gossh .UnknownChannelType , "unsupported channel type" )
245
259
continue
246
260
}
0 commit comments