Skip to content

Commit 62448dd

Browse files
committed
feat: safe stderr
1 parent 0f80af4 commit 62448dd

File tree

4 files changed

+98
-14
lines changed

4 files changed

+98
-14
lines changed

pty.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package ssh
2+
3+
import (
4+
"bytes"
5+
"io"
6+
)
7+
8+
// NewPtyWriter creates a writer that handles when the session has a active
9+
// PTY, replacing the \n with \r\n.
10+
func NewPtyWriter(w io.Writer) io.Writer {
11+
return ptyWriter{
12+
w: w,
13+
}
14+
}
15+
16+
type ptyWriter struct {
17+
w io.Writer
18+
}
19+
20+
func (w ptyWriter) Write(p []byte) (int, error) {
21+
m := len(p)
22+
// normalize \n to \r\n when pty is accepted.
23+
// this is a hardcoded shortcut since we don't support terminal modes.
24+
p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1)
25+
p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1)
26+
n, err := w.w.Write(p)
27+
if n > m {
28+
n = m
29+
}
30+
return n, err
31+
}

pty_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package ssh_test
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
7+
"github.com/gliderlabs/ssh"
8+
)
9+
10+
func TestNewPtyWriter(t *testing.T) {
11+
in := "\nfoo\r\nbar\nmore text\rmore\r\r\r\nfoo\n\n"
12+
out := "\r\nfoo\r\nbar\r\nmore text\rmore\r\r\r\nfoo\r\n\r\n"
13+
var b bytes.Buffer
14+
n, err := ssh.NewPtyWriter(&b).Write([]byte(in))
15+
if err != nil {
16+
t.Error("did not expect an error", err)
17+
}
18+
if out != b.String() {
19+
t.Errorf("outputs do not match, expected %q got %q", out, b.String())
20+
}
21+
if n != len(in) {
22+
t.Errorf("expected to write %d bytes, wrote %d", len(in), n)
23+
}
24+
}

session.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package ssh
22

33
import (
4-
"bytes"
54
"errors"
65
"fmt"
6+
"io"
77
"net"
88
"sync"
99

@@ -82,6 +82,10 @@ type Session interface {
8282
// the request handling loop. Registering nil will unregister the channel.
8383
// During the time that no channel is registered, breaks are ignored.
8484
Break(c chan<- bool)
85+
86+
// SafeStderr returns the Stderr io.Writer that handles replacing \n with
87+
// \r\n when there's an active Pty.
88+
SafeStderr() io.Writer
8589
}
8690

8791
// maxSigBufSize is how many signals will be buffered
@@ -127,18 +131,16 @@ type session struct {
127131
breakCh chan<- bool
128132
}
129133

130-
func (sess *session) Write(p []byte) (n int, err error) {
134+
func (sess *session) SafeStderr() io.Writer {
131135
if sess.pty != nil {
132-
m := len(p)
133-
// normalize \n to \r\n when pty is accepted.
134-
// this is a hardcoded shortcut since we don't support terminal modes.
135-
p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1)
136-
p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1)
137-
n, err = sess.Channel.Write(p)
138-
if n > m {
139-
n = m
140-
}
141-
return
136+
return NewPtyWriter(sess.Stderr())
137+
}
138+
return sess.Stderr()
139+
}
140+
141+
func (sess *session) Write(p []byte) (int, error) {
142+
if sess.pty != nil {
143+
return NewPtyWriter(sess.Channel).Write(p)
142144
}
143145
return sess.Channel.Write(p)
144146
}
@@ -242,7 +244,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
242244
continue
243245
}
244246

245-
var payload = struct{ Value string }{}
247+
payload := struct{ Value string }{}
246248
gossh.Unmarshal(req.Payload, &payload)
247249
sess.rawCmd = payload.Value
248250

@@ -267,7 +269,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
267269
continue
268270
}
269271

270-
var payload = struct{ Value string }{}
272+
payload := struct{ Value string }{}
271273
gossh.Unmarshal(req.Payload, &payload)
272274
sess.subsystem = payload.Value
273275

session_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,33 @@ func TestPty(t *testing.T) {
228228
<-done
229229
}
230230

231+
func TestPtyWriter(t *testing.T) {
232+
t.Parallel()
233+
term := "xterm"
234+
winWidth := 40
235+
winHeight := 80
236+
session, _, cleanup := newTestSession(t, &Server{
237+
Handler: func(s Session) {
238+
_, _ = fmt.Fprintln(s, "foo\nbar")
239+
_, _ = fmt.Fprintln(s.SafeStderr(), "many\nerrors")
240+
_ = s.Exit(0)
241+
},
242+
}, nil)
243+
defer cleanup()
244+
if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil {
245+
t.Fatalf("expected nil but got %v", err)
246+
}
247+
bts, err := session.CombinedOutput("")
248+
if err != nil {
249+
t.Fatalf("expected nil but got %v", err)
250+
}
251+
252+
expected := "foo\r\nbar\r\nmany\r\nerrors\r\n"
253+
if expected != string(bts) {
254+
t.Fatalf("expected output to be %q, got %q", expected, string(bts))
255+
}
256+
}
257+
231258
func TestPtyResize(t *testing.T) {
232259
t.Parallel()
233260
winch0 := Window{40, 80}

0 commit comments

Comments
 (0)