Skip to content

Commit 993008d

Browse files
committed
refactor: keeping the same Stderr method
1 parent 62448dd commit 993008d

File tree

3 files changed

+37
-7
lines changed

3 files changed

+37
-7
lines changed

pty.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ func NewPtyWriter(w io.Writer) io.Writer {
1313
}
1414
}
1515

16+
var _ io.Writer = ptyWriter{}
17+
1618
type ptyWriter struct {
1719
w io.Writer
1820
}
@@ -29,3 +31,27 @@ func (w ptyWriter) Write(p []byte) (int, error) {
2931
}
3032
return n, err
3133
}
34+
35+
// NewPtyReadWriter return an io.ReadWriter that delegates the read to the
36+
// given io.ReadWriter, and the writes to a ptyWriter.
37+
func NewPtyReadWriter(rw io.ReadWriter) io.ReadWriter {
38+
return readWriterDelegate{
39+
w: NewPtyWriter(rw),
40+
r: rw,
41+
}
42+
}
43+
44+
var _ io.ReadWriter = readWriterDelegate{}
45+
46+
type readWriterDelegate struct {
47+
w io.Writer
48+
r io.Reader
49+
}
50+
51+
func (rw readWriterDelegate) Read(p []byte) (n int, err error) {
52+
return rw.r.Read(p)
53+
}
54+
55+
func (rw readWriterDelegate) Write(p []byte) (n int, err error) {
56+
return rw.w.Write(p)
57+
}

session.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ type Session interface {
8383
// During the time that no channel is registered, breaks are ignored.
8484
Break(c chan<- bool)
8585

86-
// SafeStderr returns the Stderr io.Writer that handles replacing \n with
87-
// \r\n when there's an active Pty.
88-
SafeStderr() io.Writer
86+
// Stderr returns an io.ReadWriter that writes to this channel
87+
// with the extended data type set to stderr. Stderr may
88+
// safely be read and written from a different goroutine than
89+
// Read and Write respectively.
90+
Stderr() io.ReadWriter
8991
}
9092

9193
// maxSigBufSize is how many signals will be buffered
@@ -131,11 +133,11 @@ type session struct {
131133
breakCh chan<- bool
132134
}
133135

134-
func (sess *session) SafeStderr() io.Writer {
136+
func (sess *session) Stderr() io.ReadWriter {
135137
if sess.pty != nil {
136-
return NewPtyWriter(sess.Stderr())
138+
return NewPtyReadWriter(sess.Channel.Stderr())
137139
}
138-
return sess.Stderr()
140+
return sess.Channel.Stderr()
139141
}
140142

141143
func (sess *session) Write(p []byte) (int, error) {

session_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"io"
77
"net"
88
"testing"
9+
"time"
910

1011
gossh "golang.org/x/crypto/ssh"
1112
)
@@ -236,7 +237,8 @@ func TestPtyWriter(t *testing.T) {
236237
session, _, cleanup := newTestSession(t, &Server{
237238
Handler: func(s Session) {
238239
_, _ = fmt.Fprintln(s, "foo\nbar")
239-
_, _ = fmt.Fprintln(s.SafeStderr(), "many\nerrors")
240+
time.Sleep(10 * time.Millisecond)
241+
_, _ = fmt.Fprintln(s.Stderr(), "many\nerrors")
240242
_ = s.Exit(0)
241243
},
242244
}, nil)

0 commit comments

Comments
 (0)