diff --git a/share/tunnel/tunnel.go b/share/tunnel/tunnel.go index c22b737a..9cc1216d 100644 --- a/share/tunnel/tunnel.go +++ b/share/tunnel/tunnel.go @@ -173,18 +173,44 @@ func (t *Tunnel) BindRemotes(ctx context.Context, remotes []*settings.Remote) er } func (t *Tunnel) keepAliveLoop(sshConn ssh.Conn) { - //ping forever - for { - time.Sleep(t.Config.KeepAlive) - _, b, err := sshConn.SendRequest("ping", true, nil) - if err != nil { - break - } - if len(b) > 0 && !bytes.Equal(b, []byte("pong")) { - t.Debugf("strange ping response") - break - } - } - //close ssh connection on abnormal ping - sshConn.Close() + // ping forever with a timeout + PingCheckOLoop: + for { + time.Sleep(t.Config.KeepAlive) + + ctx, cancel := context.WithTimeout(context.Background(), t.Config.KeepAlive) + defer cancel() + + responseCh := make(chan []byte, 1) + errCh := make(chan error, 1) + + // Asynchronously send a 'ping' request via SSH + go func() { + _, b, err := sshConn.SendRequest("ping", true, nil) + if err != nil { + errCh <- err + return + } + responseCh <- b + }() + + // Wait for a response, error, or timeout from the asynchronous 'ping' request + select { + case response := <-responseCh: + if len(response) > 0 && !bytes.Equal(response, []byte("pong")) { + t.Debugf("Unexpected ping response: %s", response) + break PingCheckOLoop + } + case err := <-errCh: + if err != nil { + t.Debugf("Failed to send ping: %s", err) + break PingCheckOLoop + } + case <-ctx.Done(): + t.Debugf("Ping timed out") + break PingCheckOLoop + } + } + // Close the SSH connection on abnormal ping + sshConn.Close() }