Skip to content

Commit 53c62f1

Browse files
authored
Merge pull request #632 from pkg/fix-ssh-client-use
Fix SSH subsystemrequest usage
2 parents 36ce9cf + c7176b3 commit 53c62f1

File tree

7 files changed

+120
-34
lines changed

7 files changed

+120
-34
lines changed

client.go

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,17 @@ func UseFstat(value bool) ClientOption {
158158
}
159159
}
160160

161+
// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written.
162+
//
163+
// The writer passed in will not be automatically closed.
164+
// It is the responsibility of the caller to coordinate closure of any writers.
165+
func CopyStderrTo(wr io.Writer) ClientOption {
166+
return func(c *Client) error {
167+
c.stderrTo = wr
168+
return nil
169+
}
170+
}
171+
161172
// Client represents an SFTP session on a *ssh.ClientConn SSH connection.
162173
// Multiple Clients can be active on a single SSH connection, and a Client
163174
// may be called concurrently from multiple Goroutines.
@@ -166,6 +177,8 @@ func UseFstat(value bool) ClientOption {
166177
type Client struct {
167178
clientConn
168179

180+
stderrTo io.Writer
181+
169182
ext map[string]string // Extensions (name -> data).
170183

171184
maxPacket int // max packet size read or written.
@@ -186,9 +199,7 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
186199
if err != nil {
187200
return nil, err
188201
}
189-
if err := s.RequestSubsystem("sftp"); err != nil {
190-
return nil, err
191-
}
202+
192203
pw, err := s.StdinPipe()
193204
if err != nil {
194205
return nil, err
@@ -197,22 +208,35 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
197208
if err != nil {
198209
return nil, err
199210
}
211+
perr, err := s.StderrPipe()
212+
if err != nil {
213+
return nil, err
214+
}
200215

201-
return NewClientPipe(pr, pw, opts...)
216+
if err := s.RequestSubsystem("sftp"); err != nil {
217+
return nil, err
218+
}
219+
220+
return newClientPipe(pr, perr, pw, s.Wait, opts...)
202221
}
203222

204223
// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser.
205224
// This can be used for connecting to an SFTP server over TCP/TLS or by using
206225
// the system's ssh client program (e.g. via exec.Command).
207226
func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
208-
sftp := &Client{
227+
return newClientPipe(rd, nil, wr, nil, opts...)
228+
}
229+
230+
func newClientPipe(rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts ...ClientOption) (*Client, error) {
231+
c := &Client{
209232
clientConn: clientConn{
210233
conn: conn{
211234
Reader: rd,
212235
WriteCloser: wr,
213236
},
214237
inflight: make(map[uint32]chan<- result),
215238
closed: make(chan struct{}),
239+
wait: wait,
216240
},
217241

218242
ext: make(map[string]string),
@@ -222,32 +246,50 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
222246
}
223247

224248
for _, opt := range opts {
225-
if err := opt(sftp); err != nil {
249+
if err := opt(c); err != nil {
226250
wr.Close()
227251
return nil, err
228252
}
229253
}
230254

231-
if err := sftp.sendInit(); err != nil {
255+
if stderr != nil {
256+
wr := io.Discard
257+
if c.stderrTo != nil {
258+
wr = c.stderrTo
259+
}
260+
261+
go func() {
262+
// DO NOT close the writer!
263+
// Programs may pass in `os.Stderr` to write the remote stderr to,
264+
// and the program may continue after disconnect by reconnecting.
265+
// But if we've closed their stderr, then we just messed everything up.
266+
267+
if _, err := io.Copy(wr, stderr); err != nil {
268+
debug("error copying stderr: %v", err)
269+
}
270+
}()
271+
}
272+
273+
if err := c.sendInit(); err != nil {
232274
wr.Close()
233275
return nil, fmt.Errorf("error sending init packet to server: %w", err)
234276
}
235277

236-
if err := sftp.recvVersion(); err != nil {
278+
if err := c.recvVersion(); err != nil {
237279
wr.Close()
238280
return nil, fmt.Errorf("error receiving version packet from server: %w", err)
239281
}
240282

241-
sftp.clientConn.wg.Add(1)
283+
c.clientConn.wg.Add(1)
242284
go func() {
243-
defer sftp.clientConn.wg.Done()
285+
defer c.clientConn.wg.Done()
244286

245-
if err := sftp.clientConn.recv(); err != nil {
246-
sftp.clientConn.broadcastErr(err)
287+
if err := c.clientConn.recv(); err != nil {
288+
c.clientConn.broadcastErr(err)
247289
}
248290
}()
249291

250-
return sftp, nil
292+
return c, nil
251293
}
252294

253295
// Create creates the named file mode 0666 (before umask), truncating it if it

conn.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ type conn struct {
2222
// For the client mode just pass 0.
2323
// It returns io.EOF if the connection is closed and
2424
// there are no more packets to read.
25-
func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
25+
func (c *conn) recvPacket(orderID uint32) (fxp, []byte, error) {
2626
return recvPacket(c, c.alloc, orderID)
2727
}
2828

@@ -43,6 +43,8 @@ type clientConn struct {
4343
conn
4444
wg sync.WaitGroup
4545

46+
wait func() error // if non-nil, call this during Wait() to get a possible remote status error.
47+
4648
sync.Mutex // protects inflight
4749
inflight map[uint32]chan<- result // outstanding requests
4850

@@ -55,6 +57,27 @@ type clientConn struct {
5557
// goroutines.
5658
func (c *clientConn) Wait() error {
5759
<-c.closed
60+
61+
if c.wait == nil {
62+
// Only return this error if c.wait won't return something more useful.
63+
return c.err
64+
}
65+
66+
if err := c.wait(); err != nil {
67+
68+
// TODO: when https://github.com/golang/go/issues/35025 is fixed,
69+
// we can remove this if block entirely.
70+
// Right now, it’s always going to return this, so it is not useful.
71+
// But we have this code here so that as soon as the ssh library is updated,
72+
// we can return a possibly more useful error.
73+
if err.Error() == "ssh: session not started" {
74+
return c.err
75+
}
76+
77+
return err
78+
}
79+
80+
// c.wait returned no error; so, let's return something maybe more useful.
5881
return c.err
5982
}
6083

@@ -119,7 +142,7 @@ func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
119142

120143
// result captures the result of receiving the a packet from the server
121144
type result struct {
122-
typ byte
145+
typ fxp
123146
data []byte
124147
err error
125148
}
@@ -129,7 +152,7 @@ type idmarshaler interface {
129152
encoding.BinaryMarshaler
130153
}
131154

132-
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) {
155+
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (fxp, []byte, error) {
133156
if cap(ch) < 1 {
134157
ch = make(chan result, 1)
135158
}

packet.go

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -304,16 +304,22 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
304304
return nil
305305
}
306306

307-
func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) {
307+
func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, error) {
308308
var b []byte
309309
if alloc != nil {
310310
b = alloc.GetPage(orderID)
311311
} else {
312312
b = make([]byte, 4)
313313
}
314-
if _, err := io.ReadFull(r, b[:4]); err != nil {
315-
return 0, nil, err
314+
315+
if n, err := io.ReadFull(r, b[:4]); err != nil {
316+
if err == io.EOF {
317+
return 0, nil, err
318+
}
319+
320+
return 0, nil, fmt.Errorf("error reading packet length: %d of 4: %w", n, err)
316321
}
322+
317323
length, _ := unmarshalUint32(b)
318324
if length > maxMsgLength {
319325
debug("recv packet %d bytes too long", length)
@@ -323,24 +329,39 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e
323329
debug("recv packet of 0 bytes too short")
324330
return 0, nil, errShortPacket
325331
}
332+
326333
if alloc == nil {
327334
b = make([]byte, length)
328335
}
329-
if _, err := io.ReadFull(r, b[:length]); err != nil {
336+
337+
n, err := io.ReadFull(r, b[:length])
338+
b = b[:n]
339+
340+
if err != nil {
341+
debug("recv packet error: %d of %d bytes: %x", n, length, b)
342+
330343
// ReadFull only returns EOF if it has read no bytes.
331344
// In this case, that means a partial packet, and thus unexpected.
332345
if err == io.EOF {
333346
err = io.ErrUnexpectedEOF
334347
}
335-
debug("recv packet %d bytes: err %v", length, err)
336-
return 0, nil, err
348+
349+
if n == 0 {
350+
return 0, nil, fmt.Errorf("error reading packet body: %d of %d: %w", n, length, err)
351+
}
352+
353+
return 0, nil, fmt.Errorf("error reading packet body: %d of %d: (%s) %w", n, length, fxp(b[0]), err)
337354
}
355+
356+
typ, payload := fxp(b[0]), b[1:n]
357+
338358
if debugDumpRxPacketBytes {
339-
debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length])
359+
debug("recv packet: %s %d bytes %x", typ, length, payload)
340360
} else if debugDumpRxPacket {
341-
debug("recv packet: %s %d bytes", fxp(b[0]), length)
361+
debug("recv packet: %s %d bytes", typ, length)
342362
}
343-
return b[0], b[1:length], nil
363+
364+
return typ, payload, nil
344365
}
345366

346367
type extensionPair struct {

packet_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ func TestRecvPacket(t *testing.T) {
468468
var recvPacketTests = []struct {
469469
b []byte
470470

471-
want uint8
471+
want fxp
472472
body []byte
473473
wantErr error
474474
}{

request-server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
148148

149149
var err error
150150
var pkt requestPacket
151-
var pktType uint8
151+
var pktType fxp
152152
var pktBytes []byte
153153

154154
for {
@@ -158,7 +158,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
158158
return err
159159
}
160160

161-
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
161+
pkt, err = makePacket(rxPacket{pktType, pktBytes})
162162
if err != nil {
163163
switch {
164164
case errors.Is(err, errUnknownExtendedPacket):

server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ func (svr *Server) Serve() error {
390390

391391
var err error
392392
var pkt requestPacket
393-
var pktType uint8
393+
var pktType fxp
394394
var pktBytes []byte
395395
for {
396396
pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID())
@@ -403,7 +403,7 @@ func (svr *Server) Serve() error {
403403
break
404404
}
405405

406-
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
406+
pkt, err = makePacket(rxPacket{pktType, pktBytes})
407407
if err != nil {
408408
switch {
409409
case errors.Is(err, errUnknownExtendedPacket):

sftp.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,15 @@ func (f fx) String() string {
184184
}
185185

186186
type unexpectedPacketErr struct {
187-
want, got uint8
187+
want, got fxp
188188
}
189189

190190
func (u *unexpectedPacketErr) Error() string {
191-
return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got))
191+
return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", u.want, u.got)
192192
}
193193

194-
func unimplementedPacketErr(u uint8) error {
195-
return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u))
194+
func unimplementedPacketErr(u fxp) error {
195+
return fmt.Errorf("sftp: unimplemented packet type: got %v", u)
196196
}
197197

198198
type unexpectedIDErr struct{ want, got uint32 }

0 commit comments

Comments
 (0)