Skip to content

Commit ccf9dfd

Browse files
authored
refactor(ssh): improve thread safety and shutdown handling in ssh client (#160)
* refactor(ssh): improve thread safety and shutdown handling in ssh client
  - Replace ReleaseTerminal with Kill in client.go for more forceful cleanup
  - Add mutex protection for shared resources and atomic flags for state management
  - Implement proper shutdown sequence with context support
  - Add nil checks for TUI operations to prevent panics
  - Improve reconnect logic with timeout and atomic operation protection
* feat(ssh): add immediate health status update on successful reconnect
1 parent ebfc1ac commit ccf9dfd

File tree

2 files changed

+120
-32
lines changed

2 files changed

+120
-32
lines changed

tunnel/internal/client/client/client.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,19 @@ func NewClient(config *config.Config, db *db.Db) *Client {
2626
go func() {
2727
defer func() {
2828
if r := recover(); r != nil {
29-
_ = p.ReleaseTerminal()
29+
p.Kill()
3030
fmt.Printf("Recovered from panic: %v\n", r)
3131
os.Exit(1)
3232
}
3333
}()
3434

3535
if _, err := p.Run(); err != nil {
36-
_ = p.ReleaseTerminal()
36+
p.Kill()
3737
fmt.Printf("Failed to run TUI: %v\n", err)
3838
os.Exit(1)
3939
}
4040

41-
_ = p.ReleaseTerminal()
41+
p.Kill()
4242
os.Exit(0)
4343
}()
4444

tunnel/internal/client/ssh/ssh.go

Lines changed: 117 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"net"
1111
"net/http"
1212
"os"
13+
"sync"
14+
"sync/atomic"
1315
"time"
1416

1517
"github.com/amalshaji/portr/internal/client/config"
@@ -32,11 +34,14 @@ var (
3234
)
3335

3436
type SshClient struct {
35-
config config.ClientConfig
36-
listener net.Listener
37-
db *db.Db
38-
client *ssh.Client
39-
tui *tea.Program
37+
config config.ClientConfig
38+
listener net.Listener
39+
db *db.Db
40+
client *ssh.Client
41+
tui *tea.Program
42+
mu sync.RWMutex
43+
reconnecting int32 // atomic flag to prevent concurrent reconnects
44+
shutdown int32 // atomic flag for shutdown state
4045
}
4146

4247
func New(config config.ClientConfig, db *db.Db, tui *tea.Program) *SshClient {
@@ -87,6 +92,11 @@ func (s *SshClient) createNewConnection() (string, error) {
8792
}
8893

8994
func (s *SshClient) startListenerForClient() error {
95+
// Check if we're shutting down
96+
if atomic.LoadInt32(&s.shutdown) == 1 {
97+
return fmt.Errorf("client is shutting down")
98+
}
99+
90100
var err error
91101
var connectionId string
92102

@@ -102,8 +112,11 @@ func (s *SshClient) startListenerForClient() error {
102112
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
103113
}
104114

115+
// Create new client with mutex protection
116+
s.mu.Lock()
105117
s.client, err = ssh.Dial("tcp", s.config.SshUrl, sshConfig)
106118
if err != nil {
119+
s.mu.Unlock()
107120
if s.config.Debug {
108121
s.logDebug("Failed to connect to ssh server", err)
109122
}
@@ -133,31 +146,48 @@ func (s *SshClient) startListenerForClient() error {
133146
}
134147

135148
if s.listener == nil {
149+
s.mu.Unlock()
136150
return fmt.Errorf("failed to listen on remote endpoint")
137151
}
138152

139153
s.config.Tunnel.RemotePort = remotePort
154+
s.mu.Unlock()
140155

141156
defer func() {
142-
// Safe closing of listener
157+
// Safe closing of listener with mutex protection
158+
s.mu.Lock()
143159
if s.listener != nil {
144160
s.listener.Close()
161+
s.listener = nil
145162
}
163+
s.mu.Unlock()
146164
}()
147165

148-
s.tui.Send(tui.AddTunnelMsg{
149-
Config: &s.config.Tunnel,
150-
ClientConfig: &s.config,
151-
Healthy: true,
152-
})
166+
// Safe TUI send with nil check
167+
if s.tui != nil {
168+
s.tui.Send(tui.AddTunnelMsg{
169+
Config: &s.config.Tunnel,
170+
ClientConfig: &s.config,
171+
Healthy: true,
172+
})
173+
}
153174

154175
for {
155-
// Accept incoming connections on the remote port
156-
if s.listener == nil {
176+
// Check shutdown state
177+
if atomic.LoadInt32(&s.shutdown) == 1 {
178+
return fmt.Errorf("client is shutting down")
179+
}
180+
181+
// Safe listener access with read lock
182+
s.mu.RLock()
183+
listener := s.listener
184+
s.mu.RUnlock()
185+
186+
if listener == nil {
157187
return fmt.Errorf("listener is nil, cannot accept connections")
158188
}
159189

160-
remoteConn, err := s.listener.Accept()
190+
remoteConn, err := listener.Accept()
161191
if err != nil {
162192
if s.config.Debug {
163193
log.Error("Failed to accept connection", "error", err)
@@ -387,16 +417,27 @@ func (s *SshClient) tcpTunnel(src, dst net.Conn) {
387417
}
388418

389419
func (s *SshClient) Shutdown(ctx context.Context) error {
390-
if s.listener == nil {
391-
return nil
420+
// Set shutdown flag
421+
atomic.StoreInt32(&s.shutdown, 1)
422+
423+
s.mu.Lock()
424+
defer s.mu.Unlock()
425+
426+
var err error
427+
if s.listener != nil {
428+
err = s.listener.Close()
429+
s.listener = nil
392430
}
393431

394-
err := s.listener.Close()
395-
if err != nil {
396-
return err
432+
if s.client != nil {
433+
if clientErr := s.client.Close(); clientErr != nil && err == nil {
434+
err = clientErr
435+
}
436+
s.client = nil
397437
}
438+
398439
log.Info("Stopped tunnel connection", "address", s.config.GetTunnelAddr())
399-
return nil
440+
return err
400441
}
401442

402443
func (s *SshClient) StartHealthCheck(ctx context.Context) {
@@ -408,6 +449,9 @@ func (s *SshClient) StartHealthCheck(ctx context.Context) {
408449
for range ticker {
409450
retryAttempts++
410451
if retryAttempts > s.config.HealthCheckMaxRetries {
452+
if s.tui != nil {
453+
s.tui.Kill()
454+
}
411455
fmt.Printf(color.Red("Failed to reconnect to tunnel after %d attempts\n"), retryAttempts)
412456
os.Exit(1)
413457
}
@@ -478,6 +522,19 @@ func (s *SshClient) Start(ctx context.Context) {
478522
}
479523

480524
func (s *SshClient) Reconnect() error {
525+
// Prevent concurrent reconnects using atomic CAS
526+
if !atomic.CompareAndSwapInt32(&s.reconnecting, 0, 1) {
527+
return fmt.Errorf("reconnect already in progress")
528+
}
529+
defer atomic.StoreInt32(&s.reconnecting, 0)
530+
531+
// Check if we're shutting down
532+
if atomic.LoadInt32(&s.shutdown) == 1 {
533+
return fmt.Errorf("client is shutting down")
534+
}
535+
536+
// Close existing connections with mutex protection
537+
s.mu.Lock()
481538
if s.client != nil {
482539
if err := s.client.Close(); err != nil {
483540
if s.config.Debug {
@@ -495,23 +552,51 @@ func (s *SshClient) Reconnect() error {
495552
}
496553
s.listener = nil
497554
}
555+
s.mu.Unlock()
556+
557+
// Create context with timeout for the reconnection attempt
558+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
559+
defer cancel()
498560

499561
// Channel to receive errors from the goroutine
500562
errChan := make(chan error, 1)
563+
done := make(chan struct{})
501564

502-
// Start the listener in a goroutine
565+
// Start the listener in a goroutine with context
503566
go func() {
567+
defer close(done)
504568
if err := s.startListenerForClient(); err != nil {
505-
errChan <- err
569+
select {
570+
case errChan <- err:
571+
case <-ctx.Done():
572+
}
506573
}
507574
}()
508575

509-
// Wait for either an error or successful connection
576+
// Wait for either an error, successful connection, or timeout
510577
select {
511578
case err := <-errChan:
512579
return err
580+
case <-done:
581+
// Connection successful, update health status
582+
if s.tui != nil {
583+
s.tui.Send(tui.UpdateHealthMsg{
584+
Port: fmt.Sprintf("%d", s.config.Tunnel.Port),
585+
Healthy: true,
586+
})
587+
}
588+
return nil
513589
case <-time.After(5 * time.Second):
590+
// Fallback timeout in case done channel doesn't signal
591+
if s.tui != nil {
592+
s.tui.Send(tui.UpdateHealthMsg{
593+
Port: fmt.Sprintf("%d", s.config.Tunnel.Port),
594+
Healthy: true,
595+
})
596+
}
514597
return nil
598+
case <-ctx.Done():
599+
return fmt.Errorf("reconnect timeout")
515600
}
516601
}
517602

@@ -558,10 +643,13 @@ func (s *SshClient) logDebug(message string, err error) {
558643
errStr = err.Error()
559644
}
560645

561-
s.tui.Send(tui.AddDebugLogMsg{
562-
Time: time.Now().Format("15:04:05"),
563-
Level: "DEBUG",
564-
Message: message,
565-
Error: errStr,
566-
})
646+
// Safe TUI send with nil check
647+
if s.tui != nil {
648+
s.tui.Send(tui.AddDebugLogMsg{
649+
Time: time.Now().Format("15:04:05"),
650+
Level: "DEBUG",
651+
Message: message,
652+
Error: errStr,
653+
})
654+
}
567655
}

0 commit comments

Comments
 (0)