@@ -464,26 +464,35 @@ func isAzureMachinePoolWindows(amp *infrav1exp.AzureMachinePool) bool {
464464
465465// getProxiedSSHClient creates a SSH client object that connects to a target node
466466// proxied through a control plane node.
467- func getProxiedSSHClient (controlPlaneEndpoint , hostname , port string ) (* ssh.Client , error ) {
467+ func getProxiedSSHClient (controlPlaneEndpoint , hostname , port string , ioTimeout time. Duration ) (* ssh.Client , error ) {
468468 config , err := newSSHConfig ()
469469 if err != nil {
470470 return nil , err
471471 }
472472
473473 // Init a client connection to a control plane node via the public load balancer
474- lbClient , err := ssh . Dial ("tcp" , fmt .Sprintf ("%s:%s" , controlPlaneEndpoint , port ), config )
474+ c , err := net . DialTimeout ("tcp" , fmt .Sprintf ("%s:%s" , controlPlaneEndpoint , port ), config . Timeout )
475475 if err != nil {
476476 return nil , errors .Wrapf (err , "dialing public load balancer at %s" , controlPlaneEndpoint )
477477 }
478+ err = c .SetDeadline (time .Now ().Add (ioTimeout ))
479+ if err != nil {
480+ return nil , errors .Wrapf (err , "setting timeout for connection to public load balancer at %s" , controlPlaneEndpoint )
481+ }
482+ conn , chans , reqs , err := ssh .NewClientConn (c , fmt .Sprintf ("%s:%s" , controlPlaneEndpoint , port ), config )
483+ if err != nil {
484+ return nil , errors .Wrapf (err , "connecting to public load balancer at %s" , controlPlaneEndpoint )
485+ }
486+ lbClient := ssh .NewClient (conn , chans , reqs )
478487
479488 // Init a connection from the control plane to the target node
480- c , err : = lbClient .Dial ("tcp" , fmt .Sprintf ("%s:%s" , hostname , port ))
489+ c , err = lbClient .Dial ("tcp" , fmt .Sprintf ("%s:%s" , hostname , port ))
481490 if err != nil {
482491 return nil , errors .Wrapf (err , "dialing from control plane to target node at %s" , hostname )
483492 }
484493
485494 // Establish an authenticated SSH conn over the client -> control plane -> target transport
486- conn , chans , reqs , err : = ssh .NewClientConn (c , hostname , config )
495+ conn , chans , reqs , err = ssh .NewClientConn (c , hostname , config )
487496 if err != nil {
488497 return nil , errors .Wrap (err , "getting a new SSH client connection" )
489498 }
@@ -493,9 +502,9 @@ func getProxiedSSHClient(controlPlaneEndpoint, hostname, port string) (*ssh.Clie
493502
494503// execOnHost runs the specified command directly on a node's host, using a SSH connection
495504// proxied through a control plane host and copies the output to a file.
496- func execOnHost (controlPlaneEndpoint , hostname , port string , f io.StringWriter , command string ,
505+ func execOnHost (controlPlaneEndpoint , hostname , port string , ioTimeout time. Duration , f io.StringWriter , command string ,
497506 args ... string ) error {
498- client , err := getProxiedSSHClient (controlPlaneEndpoint , hostname , port )
507+ client , err := getProxiedSSHClient (controlPlaneEndpoint , hostname , port , ioTimeout )
499508 if err != nil {
500509 return err
501510 }
@@ -524,10 +533,10 @@ func execOnHost(controlPlaneEndpoint, hostname, port string, f io.StringWriter,
524533
525534// sftpCopyFile copies a file from a node to the specified destination, using a SSH connection
526535// proxied through a control plane node.
527- func sftpCopyFile (controlPlaneEndpoint , hostname , port , sourcePath , destPath string ) error {
536+ func sftpCopyFile (controlPlaneEndpoint , hostname , port string , ioTimeout time. Duration , sourcePath , destPath string ) error {
528537 Logf ("Attempting to copy file %s on node %s to %s" , sourcePath , hostname , destPath )
529538
530- client , err := getProxiedSSHClient (controlPlaneEndpoint , hostname , port )
539+ client , err := getProxiedSSHClient (controlPlaneEndpoint , hostname , port , ioTimeout )
531540 if err != nil {
532541 return err
533542 }
0 commit comments