diff --git a/share/cio/pipe.go b/share/cio/pipe.go index e32fc44e..f0bcf4b5 100644 --- a/share/cio/pipe.go +++ b/share/cio/pipe.go @@ -6,23 +6,31 @@ import ( "sync" ) +type ReadWriteWriterCloser interface { + io.ReadWriteCloser + CloseWrite() error +} + func Pipe(src io.ReadWriteCloser, dst io.ReadWriteCloser) (int64, int64) { var sent, received int64 var wg sync.WaitGroup - var o sync.Once - close := func() { - src.Close() - dst.Close() - } wg.Add(2) go func() { - received, _ = io.Copy(src, dst) - o.Do(close) + received, _ = io.Copy(dst, src) + if dst2, ok := dst.(ReadWriteWriterCloser); ok { + dst2.CloseWrite() + } else { + dst.Close() + } wg.Done() }() go func() { - sent, _ = io.Copy(dst, src) - o.Do(close) + sent, _ = io.Copy(src, dst) + if src2, ok := src.(ReadWriteWriterCloser); ok { + src2.CloseWrite() + } else { + src.Close() + } wg.Done() }() wg.Wait() diff --git a/share/tunnel/tunnel_in_proxy.go b/share/tunnel/tunnel_in_proxy.go index 007fb0c7..e49febc3 100644 --- a/share/tunnel/tunnel_in_proxy.go +++ b/share/tunnel/tunnel_in_proxy.go @@ -143,6 +143,8 @@ func (p *Proxy) pipeRemote(ctx context.Context, src io.ReadWriteCloser) { l.Infof("Stream error: %s", err) return } + // No need to do it in Pipe() when CloseWrite() is used + defer dst.Close() go ssh.DiscardRequests(reqs) //then pipe s, r := cio.Pipe(src, dst) diff --git a/share/tunnel/tunnel_out_ssh.go b/share/tunnel/tunnel_out_ssh.go index b07b98ed..1424aaf3 100644 --- a/share/tunnel/tunnel_out_ssh.go +++ b/share/tunnel/tunnel_out_ssh.go @@ -83,6 +83,8 @@ func (t *Tunnel) handleTCP(l *cio.Logger, src io.ReadWriteCloser, hostPort strin if err != nil { return err } + // No need to do it in Pipe() when CloseWrite() is used + defer dst.Close() s, r := cio.Pipe(src, dst) l.Debugf("sent %s received %s", sizestr.ToString(s), sizestr.ToString(r)) return nil