@@ -10,6 +10,8 @@ import (
1010 "net"
1111 "net/http"
1212 "os"
13+ "sync"
14+ "sync/atomic"
1315 "time"
1416
1517 "github.com/amalshaji/portr/internal/client/config"
@@ -32,11 +34,14 @@ var (
3234)
3335
3436type SshClient struct {
35- config config.ClientConfig
36- listener net.Listener
37- db * db.Db
38- client * ssh.Client
39- tui * tea.Program
37+ config config.ClientConfig
38+ listener net.Listener
39+ db * db.Db
40+ client * ssh.Client
41+ tui * tea.Program
42+ mu sync.RWMutex
43+ reconnecting int32 // atomic flag to prevent concurrent reconnects
44+ shutdown int32 // atomic flag for shutdown state
4045}
4146
4247func New (config config.ClientConfig , db * db.Db , tui * tea.Program ) * SshClient {
@@ -87,6 +92,11 @@ func (s *SshClient) createNewConnection() (string, error) {
8792}
8893
8994func (s * SshClient ) startListenerForClient () error {
95+ // Check if we're shutting down
96+ if atomic .LoadInt32 (& s .shutdown ) == 1 {
97+ return fmt .Errorf ("client is shutting down" )
98+ }
99+
90100 var err error
91101 var connectionId string
92102
@@ -102,8 +112,11 @@ func (s *SshClient) startListenerForClient() error {
102112 HostKeyCallback : ssh .InsecureIgnoreHostKey (),
103113 }
104114
115+ // Create new client with mutex protection
116+ s .mu .Lock ()
105117 s .client , err = ssh .Dial ("tcp" , s .config .SshUrl , sshConfig )
106118 if err != nil {
119+ s .mu .Unlock ()
107120 if s .config .Debug {
108121 s .logDebug ("Failed to connect to ssh server" , err )
109122 }
@@ -133,31 +146,48 @@ func (s *SshClient) startListenerForClient() error {
133146 }
134147
135148 if s .listener == nil {
149+ s .mu .Unlock ()
136150 return fmt .Errorf ("failed to listen on remote endpoint" )
137151 }
138152
139153 s .config .Tunnel .RemotePort = remotePort
154+ s .mu .Unlock ()
140155
141156 defer func () {
142- // Safe closing of listener
157+ // Safe closing of listener with mutex protection
158+ s .mu .Lock ()
143159 if s .listener != nil {
144160 s .listener .Close ()
161+ s .listener = nil
145162 }
163+ s .mu .Unlock ()
146164 }()
147165
148- s .tui .Send (tui.AddTunnelMsg {
149- Config : & s .config .Tunnel ,
150- ClientConfig : & s .config ,
151- Healthy : true ,
152- })
166+ // Safe TUI send with nil check
167+ if s .tui != nil {
168+ s .tui .Send (tui.AddTunnelMsg {
169+ Config : & s .config .Tunnel ,
170+ ClientConfig : & s .config ,
171+ Healthy : true ,
172+ })
173+ }
153174
154175 for {
155- // Accept incoming connections on the remote port
156- if s .listener == nil {
176+ // Check shutdown state
177+ if atomic .LoadInt32 (& s .shutdown ) == 1 {
178+ return fmt .Errorf ("client is shutting down" )
179+ }
180+
181+ // Safe listener access with read lock
182+ s .mu .RLock ()
183+ listener := s .listener
184+ s .mu .RUnlock ()
185+
186+ if listener == nil {
157187 return fmt .Errorf ("listener is nil, cannot accept connections" )
158188 }
159189
160- remoteConn , err := s . listener .Accept ()
190+ remoteConn , err := listener .Accept ()
161191 if err != nil {
162192 if s .config .Debug {
163193 log .Error ("Failed to accept connection" , "error" , err )
@@ -387,16 +417,27 @@ func (s *SshClient) tcpTunnel(src, dst net.Conn) {
387417}
388418
389419func (s * SshClient ) Shutdown (ctx context.Context ) error {
390- if s .listener == nil {
391- return nil
420+ // Set shutdown flag
421+ atomic .StoreInt32 (& s .shutdown , 1 )
422+
423+ s .mu .Lock ()
424+ defer s .mu .Unlock ()
425+
426+ var err error
427+ if s .listener != nil {
428+ err = s .listener .Close ()
429+ s .listener = nil
392430 }
393431
394- err := s .listener .Close ()
395- if err != nil {
396- return err
432+ if s .client != nil {
433+ if clientErr := s .client .Close (); clientErr != nil && err == nil {
434+ err = clientErr
435+ }
436+ s .client = nil
397437 }
438+
398439 log .Info ("Stopped tunnel connection" , "address" , s .config .GetTunnelAddr ())
399- return nil
440+ return err
400441}
401442
402443func (s * SshClient ) StartHealthCheck (ctx context.Context ) {
@@ -408,6 +449,9 @@ func (s *SshClient) StartHealthCheck(ctx context.Context) {
408449 for range ticker {
409450 retryAttempts ++
410451 if retryAttempts > s .config .HealthCheckMaxRetries {
452+ if s .tui != nil {
453+ s .tui .Kill ()
454+ }
411455 fmt .Printf (color .Red ("Failed to reconnect to tunnel after %d attempts\n " ), retryAttempts )
412456 os .Exit (1 )
413457 }
@@ -478,6 +522,19 @@ func (s *SshClient) Start(ctx context.Context) {
478522}
479523
480524func (s * SshClient ) Reconnect () error {
525+ // Prevent concurrent reconnects using atomic CAS
526+ if ! atomic .CompareAndSwapInt32 (& s .reconnecting , 0 , 1 ) {
527+ return fmt .Errorf ("reconnect already in progress" )
528+ }
529+ defer atomic .StoreInt32 (& s .reconnecting , 0 )
530+
531+ // Check if we're shutting down
532+ if atomic .LoadInt32 (& s .shutdown ) == 1 {
533+ return fmt .Errorf ("client is shutting down" )
534+ }
535+
536+ // Close existing connections with mutex protection
537+ s .mu .Lock ()
481538 if s .client != nil {
482539 if err := s .client .Close (); err != nil {
483540 if s .config .Debug {
@@ -495,23 +552,51 @@ func (s *SshClient) Reconnect() error {
495552 }
496553 s .listener = nil
497554 }
555+ s .mu .Unlock ()
556+
557+ // Create context with timeout for the reconnection attempt
558+ ctx , cancel := context .WithTimeout (context .Background (), 10 * time .Second )
559+ defer cancel ()
498560
499561 // Channel to receive errors from the goroutine
500562 errChan := make (chan error , 1 )
563+ done := make (chan struct {})
501564
502- // Start the listener in a goroutine
565+ // Start the listener in a goroutine with context
503566 go func () {
567+ defer close (done )
504568 if err := s .startListenerForClient (); err != nil {
505- errChan <- err
569+ select {
570+ case errChan <- err :
571+ case <- ctx .Done ():
572+ }
506573 }
507574 }()
508575
509- // Wait for either an error or successful connection
576+ // Wait for either an error, successful connection, or timeout
510577 select {
511578 case err := <- errChan :
512579 return err
580+ case <- done :
581+ // Connection successful, update health status
582+ if s .tui != nil {
583+ s .tui .Send (tui.UpdateHealthMsg {
584+ Port : fmt .Sprintf ("%d" , s .config .Tunnel .Port ),
585+ Healthy : true ,
586+ })
587+ }
588+ return nil
513589 case <- time .After (5 * time .Second ):
590+ // Fallback timeout in case done channel doesn't signal
591+ if s .tui != nil {
592+ s .tui .Send (tui.UpdateHealthMsg {
593+ Port : fmt .Sprintf ("%d" , s .config .Tunnel .Port ),
594+ Healthy : true ,
595+ })
596+ }
514597 return nil
598+ case <- ctx .Done ():
599+ return fmt .Errorf ("reconnect timeout" )
515600 }
516601}
517602
@@ -558,10 +643,13 @@ func (s *SshClient) logDebug(message string, err error) {
558643 errStr = err .Error ()
559644 }
560645
561- s .tui .Send (tui.AddDebugLogMsg {
562- Time : time .Now ().Format ("15:04:05" ),
563- Level : "DEBUG" ,
564- Message : message ,
565- Error : errStr ,
566- })
646+ // Safe TUI send with nil check
647+ if s .tui != nil {
648+ s .tui .Send (tui.AddDebugLogMsg {
649+ Time : time .Now ().Format ("15:04:05" ),
650+ Level : "DEBUG" ,
651+ Message : message ,
652+ Error : errStr ,
653+ })
654+ }
567655}
0 commit comments