@@ -50,43 +50,6 @@ func exitStatus(err error) (exitStatusMsg, error) {
50
50
return exitStatusMsg {0 }, nil
51
51
}
52
52
53
- func attachCmd (cmd * exec.Cmd , stdout io.Writer , stderr io.Writer , stdin io.Reader ) (* sync.WaitGroup , error ) {
54
- var wg sync.WaitGroup
55
- wg .Add (2 )
56
-
57
- if stdin != nil {
58
- stdinIn , err := cmd .StdinPipe ()
59
- if err != nil {
60
- return nil , err
61
- }
62
- go func () {
63
- io .Copy (stdinIn , stdin )
64
- stdinIn .Close ()
65
- // FIXME: Do we care that this is not part of the WaitGroup?
66
- }()
67
- }
68
-
69
- stdoutOut , err := cmd .StdoutPipe ()
70
- if err != nil {
71
- return nil , err
72
- }
73
- go func () {
74
- io .Copy (stdout , stdoutOut )
75
- wg .Done ()
76
- }()
77
-
78
- stderrOut , err := cmd .StderrPipe ()
79
- if err != nil {
80
- return nil , err
81
- }
82
- go func () {
83
- io .Copy (stderr , stderrOut )
84
- wg .Done ()
85
- }()
86
-
87
- return & wg , nil
88
- }
89
-
90
53
func attachShell (cmd * exec.Cmd , stdout io.Writer , stdin io.Reader ) (* os.File , * sync.WaitGroup , error ) {
91
54
var wg sync.WaitGroup
92
55
wg .Add (2 )
@@ -150,16 +113,13 @@ func parseKeys(conf *ssh.ServerConfig, pemData []byte) error {
150
113
}
151
114
152
115
func handleAuth (handler []string , conn ssh.ConnMetadata , key ssh.PublicKey ) (* ssh.Permissions , error ) {
116
+ var output bytes.Buffer
117
+
153
118
keydata := string (bytes .TrimSpace (ssh .MarshalAuthorizedKey (key )))
154
119
cmd := exec .Command (handler [0 ], append (handler [1 :], conn .User (), keydata )... )
155
- var output bytes.Buffer
156
- done , err := attachCmd (cmd , & output , & output , nil )
157
- if err != nil {
158
- return nil , err
159
- }
160
- err = cmd .Run ()
161
- done .Wait ()
162
- status , err := exitStatus (err )
120
+ cmd .Stdout = & output
121
+ cmd .Stderr = & output
122
+ status , err := exitStatus (cmd .Run ())
163
123
if err != nil {
164
124
return nil , err
165
125
}
@@ -294,8 +254,6 @@ func handleChannel(conn *ssh.ServerConn, newChan ssh.NewChannel, execHandler []s
294
254
for req := range reqs {
295
255
switch req .Type {
296
256
case "exec" :
297
- defer ch .Close ()
298
-
299
257
if req .WantReply {
300
258
req .Reply (true , nil )
301
259
}
@@ -308,6 +266,7 @@ func handleChannel(conn *ssh.ServerConn, newChan ssh.NewChannel, execHandler []s
308
266
} else {
309
267
cmdargs , err := shlex .Split (cmdline )
310
268
if assert ("shlex.Split" , err ) {
269
+ ch .Close ()
311
270
return
312
271
}
313
272
cmd = exec .Command (execHandler [0 ], append (execHandler [1 :], cmdargs ... )... )
@@ -325,21 +284,26 @@ func handleChannel(conn *ssh.ServerConn, newChan ssh.NewChannel, execHandler []s
325
284
cmd .Env = append (cmd .Env , "USER=" + conn .Permissions .Extensions ["user" ])
326
285
}
327
286
cmd .Env = append (cmd .Env , "SSH_ORIGINAL_COMMAND=" + cmdline )
328
- done , err := attachCmd (cmd , stdout , stderr , ch )
329
- if assert ("attachCmd" , err ) {
330
- return
331
- }
332
- if assert ("cmd.Start" , cmd .Start ()) {
333
- return
334
- }
335
- done .Wait ()
336
- status , err := exitStatus (cmd .Wait ())
337
- if assert ("exitStatus" , err ) {
287
+
288
+ // cmd.Wait closes the stdin when it's done, so we need to proxy it through a pipe
289
+ stdinPipe , err := cmd .StdinPipe ()
290
+ if assert ("cmd.StdinPipe" , err ) {
291
+ ch .Close ()
338
292
return
339
293
}
340
- _ , err = ch .SendRequest ("exit-status" , false , ssh .Marshal (& status ))
341
- assert ("sendExit" , err )
342
- return
294
+ go io .Copy (stdinPipe , ch )
295
+
296
+ cmd .Stdout = stdout
297
+ cmd .Stderr = stderr
298
+
299
+ go func () {
300
+ status , err := exitStatus (cmd .Run ())
301
+ if ! assert ("exec run" , err ) {
302
+ _ , err := ch .SendRequest ("exit-status" , false , ssh .Marshal (& status ))
303
+ assert ("exec exit" , err )
304
+ }
305
+ ch .Close ()
306
+ }()
343
307
case "pty-req" :
344
308
width , height , okSize := parsePtyRequest (req .Payload )
345
309
@@ -373,9 +337,9 @@ func handleChannel(conn *ssh.ServerConn, newChan ssh.NewChannel, execHandler []s
373
337
374
338
go func () {
375
339
status , err := exitStatus (cmd .Wait ())
376
- if ! assert ("exitStatus " , err ) {
340
+ if ! assert ("pty run " , err ) {
377
341
_ , err := ch .SendRequest ("exit-status" , false , ssh .Marshal (& status ))
378
- assert ("sendExit " , err )
342
+ assert ("pty exit " , err )
379
343
}
380
344
ch .Close ()
381
345
}()
0 commit comments