Skip to content

Commit 40b0a4f

Browse files
rumpelseppFahrj
authored andcommitted
Remove manual process management in favor of CommandContext()
Golang provides the function exec.CommandContext(). Using this, spawned resources can be linked to the session context and are cleaned up automatically at session close.
1 parent 6ba9584 commit 40b0a4f

File tree

3 files changed

+23
-35
lines changed

3 files changed

+23
-35
lines changed

ssh_session.go

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ func createSSHSessionHandler(shell string) ssh.Handler {
3636
createPty(s, shell)
3737

3838
case len(s.Command()) > 0:
39-
log.Printf("No PTY requested, executing command: '%s'", s.RawCommand())
39+
log.Printf("Command execution requested: '%s'", s.RawCommand())
40+
41+
cmd := exec.CommandContext(s.Context(), s.Command()[0], s.Command()[1:]...)
4042

41-
cmd := exec.Command(s.Command()[0], s.Command()[1:]...)
4243
// We use StdinPipe to avoid blocking on missing input
4344
if stdin, err := cmd.StdinPipe(); err != nil {
4445
log.Println("Could not initialize stdinPipe", err)
@@ -49,11 +50,10 @@ func createSSHSessionHandler(shell string) ssh.Handler {
4950
if _, err := io.Copy(stdin, s); err != nil {
5051
log.Printf("Error while copying input from %s to stdin: %s", s.RemoteAddr().String(), err)
5152
}
52-
if err := stdin.Close(); err != nil {
53-
log.Println("Error while closing stdinPipe:", err)
54-
}
53+
s.Close()
5554
}()
5655
}
56+
5757
cmd.Stdout = s
5858
cmd.Stderr = s
5959

@@ -64,27 +64,24 @@ func createSSHSessionHandler(shell string) ssh.Handler {
6464
case err := <-done:
6565
if err != nil {
6666
log.Println("Command execution failed:", err)
67-
io.WriteString(s, "Command execution failed: "+err.Error())
67+
io.WriteString(s, "Command execution failed: "+err.Error()+"\n")
6868
} else {
6969
log.Println("Command execution successful")
7070
}
7171
s.Exit(cmd.ProcessState.ExitCode())
7272

7373
case <-s.Context().Done():
74-
log.Println("Session closed by remote, killing dangling process")
75-
if cmd.Process != nil && cmd.ProcessState == nil {
76-
if err := cmd.Process.Kill(); err != nil {
77-
log.Println("Failed to kill process:", err)
78-
}
79-
}
74+
log.Printf("Session terminated: %s", s.Context().Err())
75+
return
8076
}
8177

8278
default:
8379
log.Println("No PTY requested, no command supplied")
8480

81+
// Keep this open until the session exits, could e.g. be port forwarding
8582
select {
8683
case <-s.Context().Done():
87-
log.Println("Session closed")
84+
log.Printf("Session terminated: %s", s.Context().Err())
8885
}
8986
}
9087
}

ssh_session_unix.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ import (
3030
)
3131

3232
func createPty(s ssh.Session, shell string) {
33-
ptyReq, winCh, _ := s.Pty()
33+
var (
34+
ptyReq, winCh, _ = s.Pty()
35+
cmd = exec.CommandContext(s.Context(), shell)
36+
)
3437

35-
cmd := exec.Command(shell)
3638
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
3739
f, err := pty.Start(cmd)
3840
if err != nil {
@@ -61,11 +63,7 @@ func createPty(s ssh.Session, shell string) {
6163
s.Exit(cmd.ProcessState.ExitCode())
6264

6365
case <-s.Context().Done():
64-
log.Println("Session closed by remote, killing dangling process")
65-
if cmd.Process != nil && cmd.ProcessState == nil {
66-
if err := cmd.Process.Kill(); err != nil {
67-
log.Println("Failed to kill process:", err)
68-
}
69-
}
66+
log.Printf("Session terminated: %s", s.Context().Err())
67+
return
7068
}
7169
}

ssh_session_windows.go

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func createPty(s ssh.Session, shell string) {
5151
}
5252
log.Println("Launching shell with ssh-shellhost.exe")
5353

54-
cmd := exec.Command(shell)
54+
cmd := exec.CommandContext(s.Context(), shell)
5555
cmd.SysProcAttr = &syscall.SysProcAttr{
5656
HideWindow: true,
5757
CmdLine: " " + "---pty cmd", // Must leave a space to the beginning
@@ -67,9 +67,7 @@ func createPty(s ssh.Session, shell string) {
6767
if _, err := io.Copy(stdin, s); err != nil {
6868
log.Printf("Error while copying input from %s to stdin: %s", s.RemoteAddr().String(), err)
6969
}
70-
if err := stdin.Close(); err != nil {
71-
log.Println("Error while closing stdinPipe:", err)
72-
}
70+
s.Close()
7371
}()
7472
}
7573
cmd.Stdout = s
@@ -88,12 +86,8 @@ func createPty(s ssh.Session, shell string) {
8886
s.Exit(cmd.ProcessState.ExitCode())
8987

9088
case <-s.Context().Done():
91-
log.Println("Session closed by remote, killing dangling process")
92-
if cmd.Process != nil && cmd.ProcessState == nil {
93-
if err := cmd.Process.Kill(); err != nil {
94-
log.Println("Failed to kill process:", err)
95-
}
96-
}
89+
log.Printf("Session terminated: %s", s.Context().Err())
90+
return
9791
}
9892

9993
} else {
@@ -128,6 +122,7 @@ func createPty(s ssh.Session, shell string) {
128122
if err != nil {
129123
log.Fatalf("Failed to find process: %v", err)
130124
}
125+
defer process.Kill()
131126

132127
// Link data streams of ssh session and conpty
133128
go io.Copy(s, cpty.OutPipe())
@@ -156,10 +151,8 @@ func createPty(s ssh.Session, shell string) {
156151
s.Exit(result.ProcessState.ExitCode())
157152

158153
case <-s.Context().Done():
159-
log.Println("Session closed by remote, killing process")
160-
if err := process.Kill(); err != nil {
161-
log.Println("Failed to kill process:", err)
162-
}
154+
log.Printf("Session terminated: %s", s.Context().Err())
155+
return
163156
}
164157
}
165158
}

0 commit comments

Comments
 (0)