@@ -37,8 +37,15 @@ type Server struct {
37
37
IdleTimeout time.Duration // connection timeout when no activity, none if empty
38
38
MaxTimeout time.Duration // absolute connection timeout, none if empty
39
39
40
- channelHandlers map [string ]ChannelHandler // fallback channel handlers
41
- requestHandlers map [string ]RequestHandler
40
+ // ChannelHandlers allow overriding the built-in session handlers or provide
41
+ // extensions to the protocol, such as tcpip forwarding. By default only the
42
+ // "session" handler is enabled.
43
+ ChannelHandlers map [string ]ChannelHandler
44
+
45
+ // RequestHandlers allow overriding the server-level request handlers or
46
+ // provide extensions to the protocol, such as tcpip forwarding. By default
47
+ // no handlers are enabled.
48
+ RequestHandlers map [string ]RequestHandler
42
49
43
50
listenerWg sync.WaitGroup
44
51
mu sync.Mutex
@@ -82,14 +89,14 @@ func (srv *Server) ensureHostSigner() error {
82
89
func (srv * Server ) ensureHandlers () {
83
90
srv .mu .Lock ()
84
91
defer srv .mu .Unlock ()
85
- if srv .requestHandlers == nil {
86
- srv .requestHandlers = map [string ]RequestHandler {
92
+ if srv .RequestHandlers == nil {
93
+ srv .RequestHandlers = map [string ]RequestHandler {
87
94
"tcpip-forward" : forwardedTCPHandler {},
88
95
"cancel-tcpip-forward" : forwardedTCPHandler {},
89
96
}
90
97
}
91
- if srv .channelHandlers == nil {
92
- srv .channelHandlers = map [string ]ChannelHandler {
98
+ if srv .ChannelHandlers == nil {
99
+ srv .ChannelHandlers = map [string ]ChannelHandler {
93
100
"session" : ChannelHandlerFunc (sessionHandler ),
94
101
"direct-tcpip" : ChannelHandlerFunc (directTcpipHandler ),
95
102
}
@@ -234,18 +241,6 @@ func (srv *Server) Serve(l net.Listener) error {
234
241
}
235
242
}
236
243
237
- func (srv * Server ) SetChannelHandler (kind string , handler ChannelHandler ) {
238
- srv .ensureHandlers ()
239
- srv .mu .Lock ()
240
- defer srv .mu .Unlock ()
241
- srv .channelHandlers [kind ] = handler
242
- }
243
-
244
- func (srv * Server ) ChannelHandler (kind string ) ChannelHandler {
245
- srv .ensureHandlers ()
246
- return srv .channelHandlers [kind ]
247
- }
248
-
249
244
func (srv * Server ) handleConn (newConn net.Conn ) {
250
245
if srv .ConnCallback != nil {
251
246
cbConn := srv .ConnCallback (newConn )
@@ -279,11 +274,9 @@ func (srv *Server) handleConn(newConn net.Conn) {
279
274
//go gossh.DiscardRequests(reqs)
280
275
go srv .handleRequests (ctx , reqs )
281
276
for ch := range chans {
282
- handler , found := srv .channelHandlers [ch .ChannelType ()]
283
- if ! found || handler == nil {
284
- if defaultHandler , found := srv .channelHandlers ["default" ]; found {
285
- handler = defaultHandler
286
- }
277
+ handler := srv .ChannelHandlers [ch .ChannelType ()]
278
+ if handler == nil {
279
+ handler = srv .ChannelHandlers ["default" ]
287
280
}
288
281
if handler == nil {
289
282
ch .Reject (gossh .UnknownChannelType , "unsupported channel type" )
@@ -295,8 +288,11 @@ func (srv *Server) handleConn(newConn net.Conn) {
295
288
296
289
func (srv * Server ) handleRequests (ctx Context , in <- chan * gossh.Request ) {
297
290
for req := range in {
298
- handler , found := srv .requestHandlers [req .Type ]
299
- if ! found {
291
+ handler := srv .RequestHandlers [req .Type ]
292
+ if handler == nil {
293
+ handler = srv .RequestHandlers ["default" ]
294
+ }
295
+ if handler == nil {
300
296
if req .WantReply {
301
297
req .Reply (false , nil )
302
298
}
0 commit comments