Skip to content

Commit 3d59dc3

Browse files
committed
Replace attachCmd.
1 parent 46e94bd commit 3d59dc3

File tree

1 file changed

+26
-62
lines changed

1 file changed

+26
-62
lines changed

execd.go

Lines changed: 26 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -50,43 +50,6 @@ func exitStatus(err error) (exitStatusMsg, error) {
5050
return exitStatusMsg{0}, nil
5151
}
5252

53-
func attachCmd(cmd *exec.Cmd, stdout io.Writer, stderr io.Writer, stdin io.Reader) (*sync.WaitGroup, error) {
54-
var wg sync.WaitGroup
55-
wg.Add(2)
56-
57-
if stdin != nil {
58-
stdinIn, err := cmd.StdinPipe()
59-
if err != nil {
60-
return nil, err
61-
}
62-
go func() {
63-
io.Copy(stdinIn, stdin)
64-
stdinIn.Close()
65-
// FIXME: Do we care that this is not part of the WaitGroup?
66-
}()
67-
}
68-
69-
stdoutOut, err := cmd.StdoutPipe()
70-
if err != nil {
71-
return nil, err
72-
}
73-
go func() {
74-
io.Copy(stdout, stdoutOut)
75-
wg.Done()
76-
}()
77-
78-
stderrOut, err := cmd.StderrPipe()
79-
if err != nil {
80-
return nil, err
81-
}
82-
go func() {
83-
io.Copy(stderr, stderrOut)
84-
wg.Done()
85-
}()
86-
87-
return &wg, nil
88-
}
89-
9053
func attachShell(cmd *exec.Cmd, stdout io.Writer, stdin io.Reader) (*os.File, *sync.WaitGroup, error) {
9154
var wg sync.WaitGroup
9255
wg.Add(2)
@@ -150,16 +113,13 @@ func parseKeys(conf *ssh.ServerConfig, pemData []byte) error {
150113
}
151114

152115
func handleAuth(handler []string, conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
116+
var output bytes.Buffer
117+
153118
keydata := string(bytes.TrimSpace(ssh.MarshalAuthorizedKey(key)))
154119
cmd := exec.Command(handler[0], append(handler[1:], conn.User(), keydata)...)
155-
var output bytes.Buffer
156-
done, err := attachCmd(cmd, &output, &output, nil)
157-
if err != nil {
158-
return nil, err
159-
}
160-
err = cmd.Run()
161-
done.Wait()
162-
status, err := exitStatus(err)
120+
cmd.Stdout = &output
121+
cmd.Stderr = &output
122+
status, err := exitStatus(cmd.Run())
163123
if err != nil {
164124
return nil, err
165125
}
@@ -294,8 +254,6 @@ func handleChannel(conn *ssh.ServerConn, newChan ssh.NewChannel, execHandler []s
294254
for req := range reqs {
295255
switch req.Type {
296256
case "exec":
297-
defer ch.Close()
298-
299257
if req.WantReply {
300258
req.Reply(true, nil)
301259
}
@@ -308,6 +266,7 @@ func handleChannel(conn *ssh.ServerConn, newChan ssh.NewChannel, execHandler []s
308266
} else {
309267
cmdargs, err := shlex.Split(cmdline)
310268
if assert("shlex.Split", err) {
269+
ch.Close()
311270
return
312271
}
313272
cmd = exec.Command(execHandler[0], append(execHandler[1:], cmdargs...)...)
@@ -325,21 +284,26 @@ func handleChannel(conn *ssh.ServerConn, newChan ssh.NewChannel, execHandler []s
325284
cmd.Env = append(cmd.Env, "USER="+conn.Permissions.Extensions["user"])
326285
}
327286
cmd.Env = append(cmd.Env, "SSH_ORIGINAL_COMMAND="+cmdline)
328-
done, err := attachCmd(cmd, stdout, stderr, ch)
329-
if assert("attachCmd", err) {
330-
return
331-
}
332-
if assert("cmd.Start", cmd.Start()) {
333-
return
334-
}
335-
done.Wait()
336-
status, err := exitStatus(cmd.Wait())
337-
if assert("exitStatus", err) {
287+
288+
// cmd.Wait closes the stdin when it's done, so we need to proxy it through a pipe
289+
stdinPipe, err := cmd.StdinPipe()
290+
if assert("cmd.StdinPipe", err) {
291+
ch.Close()
338292
return
339293
}
340-
_, err = ch.SendRequest("exit-status", false, ssh.Marshal(&status))
341-
assert("sendExit", err)
342-
return
294+
go io.Copy(stdinPipe, ch)
295+
296+
cmd.Stdout = stdout
297+
cmd.Stderr = stderr
298+
299+
go func() {
300+
status, err := exitStatus(cmd.Run())
301+
if !assert("exec run", err) {
302+
_, err := ch.SendRequest("exit-status", false, ssh.Marshal(&status))
303+
assert("exec exit", err)
304+
}
305+
ch.Close()
306+
}()
343307
case "pty-req":
344308
width, height, okSize := parsePtyRequest(req.Payload)
345309

@@ -373,9 +337,9 @@ func handleChannel(conn *ssh.ServerConn, newChan ssh.NewChannel, execHandler []s
373337

374338
go func() {
375339
status, err := exitStatus(cmd.Wait())
376-
if !assert("exitStatus", err) {
340+
if !assert("pty run", err) {
377341
_, err := ch.SendRequest("exit-status", false, ssh.Marshal(&status))
378-
assert("sendExit", err)
342+
assert("pty exit", err)
379343
}
380344
ch.Close()
381345
}()

0 commit comments

Comments
 (0)