Skip to content

Commit 3eeacb7

Browse files
authored
session: adding signal handling support (#44)
1 parent 4a4de39 commit 3eeacb7

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

session.go

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"errors"
77
"fmt"
88
"net"
9+
"sync"
910

1011
"github.com/anmitsu/go-shlex"
1112
gossh "golang.org/x/crypto/ssh"
@@ -63,9 +64,19 @@ type Session interface {
6364
// of whether or not a PTY was accepted for this session.
6465
Pty() (Pty, <-chan Window, bool)
6566

66-
// TODO: Signals(c chan<- Signal)
67+
// Signals registers a channel to receive signals sent from the client. The
68+
// channel must handle signal sends or it will block the SSH request loop.
69+
// Registering nil will unregister the channel from signal sends. During the
70+
// time no channel is registered signals are buffered up to a reasonable amount.
71+
// If there are buffered signals when a channel is registered, they will be
72+
// sent in order on the channel immediately after registering.
73+
Signals(c chan<- Signal)
6774
}
6875

76+
// maxSigBufSize is how many signals will be buffered
77+
// when there is no signal channel specified
78+
const maxSigBufSize = 128
79+
6980
func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
7081
ch, reqs, err := newChan.Accept()
7182
if err != nil {
@@ -83,6 +94,7 @@ func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChanne
8394
}
8495

8596
type session struct {
97+
sync.Mutex
8698
gossh.Channel
8799
conn *gossh.ServerConn
88100
handler Handler
@@ -94,6 +106,8 @@ type session struct {
94106
ptyCb PtyCallback
95107
cmd []string
96108
ctx *sshContext
109+
sigCh chan<- Signal
110+
sigBuf []Signal
97111
}
98112

99113
func (sess *session) Write(p []byte) (n int, err error) {
@@ -132,6 +146,8 @@ func (sess *session) Context() context.Context {
132146
}
133147

134148
func (sess *session) Exit(code int) error {
149+
sess.Lock()
150+
defer sess.Unlock()
135151
if sess.exited {
136152
return errors.New("Session.Exit called multiple times")
137153
}
@@ -172,6 +188,19 @@ func (sess *session) Pty() (Pty, <-chan Window, bool) {
172188
return Pty{}, sess.winch, false
173189
}
174190

191+
func (sess *session) Signals(c chan<- Signal) {
192+
sess.Lock()
193+
defer sess.Unlock()
194+
sess.sigCh = c
195+
if len(sess.sigBuf) > 0 {
196+
go func() {
197+
for _, sig := range sess.sigBuf {
198+
sess.sigCh <- sig
199+
}
200+
}()
201+
}
202+
}
203+
175204
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
176205
for req := range reqs {
177206
switch req.Type {
@@ -195,10 +224,22 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
195224
req.Reply(false, nil)
196225
continue
197226
}
198-
var kv = struct{ Key, Value string }{}
227+
var kv struct{ Key, Value string }
199228
gossh.Unmarshal(req.Payload, &kv)
200229
sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
201230
req.Reply(true, nil)
231+
case "signal":
232+
var payload struct{ Signal string }
233+
gossh.Unmarshal(req.Payload, &payload)
234+
sess.Lock()
235+
if sess.sigCh != nil {
236+
sess.sigCh <- Signal(payload.Signal)
237+
} else {
238+
if len(sess.sigBuf) < maxSigBufSize {
239+
sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal))
240+
}
241+
}
242+
sess.Unlock()
202243
case "pty-req":
203244
if sess.handled || sess.pty != nil {
204245
req.Reply(false, nil)

session_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,35 @@ func TestPtyResize(t *testing.T) {
280280
session.Close()
281281
<-done
282282
}
283+
284+
func TestSignals(t *testing.T) {
285+
t.Parallel()
286+
287+
session, _, cleanup := newTestSession(t, &Server{
288+
Handler: func(s Session) {
289+
signals := make(chan Signal)
290+
s.Signals(signals)
291+
if sig := <-signals; sig != SIGINT {
292+
t.Fatalf("expected signal %v but got %v", SIGINT, sig)
293+
}
294+
exiter := make(chan bool)
295+
go func() {
296+
if sig := <-signals; sig == SIGKILL {
297+
close(exiter)
298+
}
299+
}()
300+
<-exiter
301+
},
302+
}, nil)
303+
defer cleanup()
304+
305+
go func() {
306+
session.Signal(gossh.SIGINT)
307+
session.Signal(gossh.SIGKILL)
308+
}()
309+
310+
err := session.Run("")
311+
if err != nil {
312+
t.Fatalf("expected nil but got %v", err)
313+
}
314+
}

0 commit comments

Comments
 (0)