diff --git a/share/tunnel/tunnel_in_proxy.go b/share/tunnel/tunnel_in_proxy.go index 007fb0c7..1037e52a 100644 --- a/share/tunnel/tunnel_in_proxy.go +++ b/share/tunnel/tunnel_in_proxy.go @@ -10,6 +10,7 @@ import ( "github.com/jpillora/chisel/share/settings" "github.com/jpillora/sizestr" "golang.org/x/crypto/ssh" + "errors" ) //sshTunnel exposes a subset of Tunnel to subtypes @@ -118,13 +119,39 @@ func (p *Proxy) runTCP(ctx context.Context) error { close(done) return err } - go p.pipeRemote(ctx, src) + + dst, l, err := p.openSshChannel(ctx) + if err != nil { + // SSH channel failed - likely because remote end not connectable + l.Infof("Reset") + resetTCP(src.(*net.TCPConn), err.Error()) + continue + } + + go p.pipeRemoteSshAlreadyOpened(src, dst, l) } } func (p *Proxy) pipeRemote(ctx context.Context, src io.ReadWriteCloser) { defer src.Close() + dst, l, err := p.openSshChannel(ctx) + if err != nil { + return + } + + //then pipe + s, r := cio.Pipe(src, dst) + l.Debugf("Close (sent %s received %s)", sizestr.ToString(s), sizestr.ToString(r)) +} + +func (p *Proxy) pipeRemoteSshAlreadyOpened(src, dst io.ReadWriteCloser, l *cio.Logger) { + defer src.Close() + s, r := cio.Pipe(src, dst) + l.Debugf("Close (sent %s received %s)", sizestr.ToString(s), sizestr.ToString(r)) +} + +func (p *Proxy) openSshChannel(ctx context.Context) (io.ReadWriteCloser, *cio.Logger, error) { p.mu.Lock() p.count++ cid := p.count @@ -132,19 +159,32 @@ func (p *Proxy) pipeRemote(ctx context.Context, src io.ReadWriteCloser) { l := p.Fork("conn#%d", cid) l.Debugf("Open") + sshConn := p.sshTun.getSSH(ctx) if sshConn == nil { l.Debugf("No remote connection") - return + return nil, l, errors.New("No remote SSH connection") } //ssh request for tcp connection for this proxy's remote dst, reqs, err := sshConn.OpenChannel("chisel", []byte(p.remote.Remote())) if err != nil { l.Infof("Stream error: %s", err) - return + return nil, l, err } go ssh.DiscardRequests(reqs) - //then pipe - s, r := cio.Pipe(src, dst) - l.Debugf("Close (sent %s received %s)", sizestr.ToString(s), sizestr.ToString(r)) + return dst, l, nil } + +func resetTCP(c *net.TCPConn, msg string) { + data := append([]byte(msg),0x0a) + + c.SetLinger(0) // TCP reset if close and all data not yet sent/acked + c.SetNoDelay(false) // Make TCP slower to send... + // maximise changes to have Write() and Close() executed without gorouting switch + func() { + c.Write(data) // ignore errors + c.Close() + }() + + return +} \ No newline at end of file diff --git a/share/tunnel/tunnel_out_ssh.go b/share/tunnel/tunnel_out_ssh.go index b07b98ed..aa3df2a9 100644 --- a/share/tunnel/tunnel_out_ssh.go +++ b/share/tunnel/tunnel_out_ssh.go @@ -31,6 +31,9 @@ func (t *Tunnel) handleSSHChannels(chans <-chan ssh.NewChannel) { } func (t *Tunnel) handleSSHChannel(ch ssh.NewChannel) { + var dst net.Conn + var err error + if !t.Config.Outbound { t.Debugf("Denied outbound connection") ch.Reject(ssh.Prohibited, "Denied outbound connection") @@ -41,11 +44,21 @@ func (t *Tunnel) handleSSHChannel(ch ssh.NewChannel) { hostPort, proto := settings.L4Proto(remote) udp := proto == "udp" socks := hostPort == "socks" + tcp := !(udp || socks) if socks && t.socksServer == nil { t.Debugf("Denied socks request, please enable socks") ch.Reject(ssh.Prohibited, "SOCKS5 is not enabled") return } + if tcp { + dst, err = net.Dial("tcp", hostPort) + if err != nil { + msg := fmt.Sprintf("Failed to connect to %s: %s", hostPort, err) + t.Debugf(msg) + ch.Reject(ssh.ConnectionFailed, msg) + return + } + } sshChan, reqs, err := ch.Accept() if err != nil { t.Debugf("Failed to accept stream: %s", err) @@ -64,7 +77,7 @@ func (t *Tunnel) handleSSHChannel(ch ssh.NewChannel) { } else if udp { err = t.handleUDP(l, stream, hostPort) } else { - err = t.handleTCP(l, stream, hostPort) + err = t.handleTCP(l, stream, dst) } t.connStats.Close() errmsg := "" @@ -78,11 +91,9 @@ func (t *Tunnel) handleSocks(src io.ReadWriteCloser) error { return t.socksServer.ServeConn(cnet.NewRWCConn(src)) } -func (t *Tunnel) handleTCP(l *cio.Logger, src io.ReadWriteCloser, hostPort string) error { - dst, err := net.Dial("tcp", hostPort) - if err != nil { - return err - } +func (t *Tunnel) handleTCP(l *cio.Logger, src, dst io.ReadWriteCloser) error { + // No need to do it in Pipe() when CloseWrite() is used + defer dst.Close() s, r := cio.Pipe(src, dst) l.Debugf("sent %s received %s", sizestr.ToString(s), sizestr.ToString(r)) return nil