@@ -101,19 +101,26 @@ type sshSmartSubtransport struct {
101101 stdin io.WriteCloser
102102 stdout io.Reader
103103 currentStream * sshSmartSubtransportStream
104+ ckey string
105+ addr string
104106}
105107
106108// aMux is the read-write mutex to control access to sshClients.
107109var aMux sync.RWMutex
108110
111+ type cachedClient struct {
112+ * ssh.Client
113+ activeSessions uint16
114+ }
115+
109116// sshClients stores active ssh clients/connections to be reused.
110117//
111118// Once opened, connections will be kept cached until an error occurs
112119// during SSH commands, by which point it will be discarded, leading to
113120// a follow-up cache miss.
114121//
115122// The key must be based on cacheKey, refer to that function's comments.
116- var sshClients map [string ]* ssh. Client = make (map [string ]* ssh. Client )
123+ var sshClients map [string ]* cachedClient = make (map [string ]* cachedClient )
117124
118125func (t * sshSmartSubtransport ) Action (urlString string , action git2go.SmartServiceAction ) (git2go.SmartSubtransportStream , error ) {
119126 runtime .LockOSThread ()
@@ -124,23 +131,29 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
124131 return nil , err
125132 }
126133
134+ if len (u .Path ) > PathMaxLength {
135+ return nil , fmt .Errorf ("path exceeds the max length (%d)" , PathMaxLength )
136+ }
137+
138+ // decode URI's path
139+ uPath , err := url .PathUnescape (u .Path )
140+ if err != nil {
141+ return nil , err
142+ }
143+
127144 // Escape \ and '.
128- uPath : = strings .Replace (u . Path , `\` , `\\` , - 1 )
145+ uPath = strings .Replace (uPath , `\` , `\\` , - 1 )
129146 uPath = strings .Replace (uPath , `'` , `\'` , - 1 )
130147
131- // TODO: Add percentage decode similar to libgit2.
132- // Refer: https://github.com/libgit2/libgit2/blob/358a60e1b46000ea99ef10b4dd709e92f75ff74b/src/str.c#L455-L481
133-
134148 var cmd string
135149 switch action {
136150 case git2go .SmartServiceActionUploadpackLs , git2go .SmartServiceActionUploadpack :
137151 if t .currentStream != nil {
138152 if t .lastAction == git2go .SmartServiceActionUploadpackLs {
139153 return t .currentStream , nil
140154 }
141- if err := t .Close (); err != nil {
142- traceLog .Error (err , "[ssh]: error cleaning up previous stream" )
143- }
155+ // Disregard errors from previous stream, futher details inside Close().
156+ _ = t .Close ()
144157 }
145158 cmd = fmt .Sprintf ("git-upload-pack '%s'" , uPath )
146159
@@ -149,17 +162,16 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
149162 if t .lastAction == git2go .SmartServiceActionReceivepackLs {
150163 return t .currentStream , nil
151164 }
152- if err := t .Close (); err != nil {
153- traceLog .Error (err , "[ssh]: error cleaning up previous stream" )
154- }
165+ // Disregard errors from previous stream, futher details inside Close().
166+ _ = t .Close ()
155167 }
156168 cmd = fmt .Sprintf ("git-receive-pack '%s'" , uPath )
157169
158170 default :
159171 return nil , fmt .Errorf ("unexpected action: %v" , action )
160172 }
161173
162- cred , err := t .transport .SmartCredentials ("" , git2go .CredentialTypeSSHKey | git2go . CredentialTypeSSHMemory )
174+ cred , err := t .transport .SmartCredentials ("" , git2go .CredentialTypeSSHMemory )
163175 if err != nil {
164176 return nil , err
165177 }
@@ -171,11 +183,14 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
171183 port = u .Port ()
172184 }
173185 addr = fmt .Sprintf ("%s:%s" , u .Hostname (), port )
186+ t .addr = addr
174187
175188 ckey , sshConfig , err := cacheKeyAndConfig (addr , cred )
176189 if err != nil {
177190 return nil , err
178191 }
192+ t .ckey = ckey
193+
179194 sshConfig .HostKeyCallback = func (hostname string , remote net.Addr , key ssh.PublicKey ) error {
180195 marshaledKey := key .Marshal ()
181196 cert := & git2go.Certificate {
@@ -193,51 +208,47 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
193208 return t .transport .SmartCertificateCheck (cert , true , hostname )
194209 }
195210
196- aMux .RLock ()
211+ var cacheHit bool
212+ aMux .Lock ()
197213 if c , ok := sshClients [ckey ]; ok {
198214 traceLog .Info ("[ssh]: cache hit" , "remoteAddress" , addr )
199- t .client = c
215+ t .client = c .Client
216+ cacheHit = true
217+ c .activeSessions ++
200218 }
201- aMux .RUnlock ()
219+ aMux .Unlock ()
202220
203221 if t .client == nil {
222+ cacheHit = false
204223 traceLog .Info ("[ssh]: cache miss" , "remoteAddress" , addr )
205-
206- aMux .Lock ()
207- defer aMux .Unlock ()
208-
209- // In some scenarios the ssh handshake can hang indefinitely at
210- // golang.org/x/crypto/ssh.(*handshakeTransport).kexLoop.
211- //
212- // xref: https://github.com/golang/go/issues/51926
213- done := make (chan error , 1 )
214- go func () {
215- t .client , err = ssh .Dial ("tcp" , addr , sshConfig )
216- done <- err
217- }()
218-
219- dialTimeout := sshConfig .Timeout + (30 * time .Second )
220-
221- select {
222- case doneErr := <- done :
223- if doneErr != nil {
224- err = fmt .Errorf ("ssh.Dial: %w" , doneErr )
225- }
226- case <- time .After (dialTimeout ):
227- err = fmt .Errorf ("timed out waiting for ssh.Dial after %s" , dialTimeout )
228- }
229-
224+ err := t .createConn (ckey , addr , sshConfig )
230225 if err != nil {
231226 return nil , err
232227 }
233-
234- sshClients [ckey ] = t .client
235228 }
236229
237230 traceLog .Info ("[ssh]: creating new ssh session" )
238231 if t .session , err = t .client .NewSession (); err != nil {
239232 discardCachedSshClient (ckey )
240- return nil , err
233+
234+ // if the current connection was cached, we can try again
235+ // as this may be a stale connection.
236+ if ! cacheHit {
237+ return nil , err
238+ }
239+
240+ traceLog .Info ("[ssh]: cached connection was stale, retrying..." )
241+ err = t .createConn (ckey , addr , sshConfig )
242+ if err != nil {
243+ return nil , err
244+ }
245+
246+ traceLog .Info ("[ssh]: creating new ssh session with new connection" )
247+ t .session , err = t .client .NewSession ()
248+ if err != nil {
249+ discardCachedSshClient (ckey )
250+ return nil , err
251+ }
241252 }
242253
243254 if t .stdin , err = t .session .StdinPipe (); err != nil {
@@ -264,28 +275,83 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
264275 return t .currentStream , nil
265276}
266277
267- func (t * sshSmartSubtransport ) Close () error {
268- var returnErr error
278+ func (t * sshSmartSubtransport ) createConn (ckey , addr string , sshConfig * ssh.ClientConfig ) error {
279+ // In some scenarios the ssh handshake can hang indefinitely at
280+ // golang.org/x/crypto/ssh.(*handshakeTransport).kexLoop.
281+ //
282+ // xref: https://github.com/golang/go/issues/51926
283+ done := make (chan error , 1 )
284+ var err error
285+
286+ var c * ssh.Client
287+ go func () {
288+ c , err = ssh .Dial ("tcp" , addr , sshConfig )
289+ done <- err
290+ }()
291+
292+ dialTimeout := sshConfig .Timeout + (30 * time .Second )
293+
294+ select {
295+ case doneErr := <- done :
296+ if doneErr != nil {
297+ err = fmt .Errorf ("ssh.Dial: %w" , doneErr )
298+ }
299+ case <- time .After (dialTimeout ):
300+ err = fmt .Errorf ("timed out waiting for ssh.Dial after %s" , dialTimeout )
301+ }
302+
303+ if err != nil {
304+ return err
305+ }
306+
307+ t .client = c
308+
309+ // Mutex is set here to avoid the network latency being
310+ // absorbed by all competing goroutines.
311+ aMux .Lock ()
312+ defer aMux .Unlock ()
313+
314+ // A different goroutine won the race, dispose the connection
315+ // and carry on.
316+ if _ , ok := sshClients [ckey ]; ok {
317+ go func () {
318+ _ = c .Close ()
319+ }()
320+ return nil
321+ }
322+
323+ sshClients [ckey ] = & cachedClient {
324+ Client : c ,
325+ activeSessions : 1 ,
326+ }
327+
328+ return nil
329+ }
269330
270- traceLog .Info ("[ssh]: sshSmartSubtransport.Close()" )
331+ // Close closes the smart subtransport.
332+ //
333+ // This is called internally ahead of a new action, and also
334+ // upstream by the transport handler:
335+ // https://github.com/libgit2/git2go/blob/0e8009f00a65034d196c67b1cdd82af6f12c34d3/transport.go#L409
336+ //
337+ // Avoid returning errors, but focus on releasing anything that
338+ // may impair the transport to have successful actions on a new
339+ // SmartSubTransport (i.e. unreleased resources, staled connections).
340+ func (t * sshSmartSubtransport ) Close () error {
341+ traceLog .Info ("[ssh]: sshSmartSubtransport.Close()" , "server" , t .addr )
271342 t .currentStream = nil
272343 if t .client != nil && t .stdin != nil {
273- if err := t .stdin .Close (); err != nil {
274- returnErr = fmt .Errorf ("cannot close stdin: %w" , err )
275- }
344+ _ = t .stdin .Close ()
276345 }
277346 t .client = nil
278347
279348 if t .session != nil {
280- traceLog .Info ("[ssh]: skipping session.wait" )
281- traceLog .Info ("[ssh]: session.Close()" )
282- if err := t .session .Close (); err != nil {
283- returnErr = fmt .Errorf ("cannot close session: %w" , err )
284- }
349+ traceLog .Info ("[ssh]: session.Close()" , "server" , t .addr )
350+ _ = t .session .Close ()
285351 }
286352 t .session = nil
287353
288- return returnErr
354+ return nil
289355}
290356
291357func (t * sshSmartSubtransport ) Free () {
@@ -306,6 +372,13 @@ func (stream *sshSmartSubtransportStream) Write(buf []byte) (int, error) {
306372
307373func (stream * sshSmartSubtransportStream ) Free () {
308374 traceLog .Info ("[ssh]: sshSmartSubtransportStream.Free()" )
375+ if stream .owner == nil {
376+ return
377+ }
378+
379+ if stream .owner .ckey != "" {
380+ decrementActiveSessionIfFound (stream .owner .ckey )
381+ }
309382}
310383
311384func cacheKeyAndConfig (remoteAddress string , cred * git2go.Credential ) (string , * ssh.ClientConfig , error ) {
@@ -376,8 +449,41 @@ func discardCachedSshClient(key string) {
376449 aMux .Lock ()
377450 defer aMux .Unlock ()
378451
379- if _ , found := sshClients [key ]; found {
380- traceLog .Info ("[ssh]: discard cached ssh client" )
452+ if v , found := sshClients [key ]; found {
453+ traceLog .Info ("[ssh]: discard cached ssh client" , "activeSessions" , v .activeSessions )
454+ closeConn := func () {
455+ // run as async goroutine to minimise mutex time in immediate closures.
456+ go func () {
457+ if v .Client != nil {
458+ _ = v .Client .Close ()
459+ }
460+ }()
461+ }
462+
463+ // if no active sessions for this connection, close it right-away.
464+ // otherwise, it may be used by other processes, so remove from cache,
465+ // and schedule a delayed closure.
466+ if v .activeSessions == 0 {
467+ traceLog .Info ("[ssh]: closing connection" )
468+ closeConn ()
469+ } else {
470+ go func () {
471+ // the delay must account for in-flight operations
472+ // that depends on this connection.
473+ time .Sleep (120 * time .Second )
474+ traceLog .Info ("[ssh]: closing connection after delay" )
475+ closeConn ()
476+ }()
477+ }
381478 delete (sshClients , key )
382479 }
383480}
481+
482+ func decrementActiveSessionIfFound (key string ) {
483+ aMux .Lock ()
484+ defer aMux .Unlock ()
485+
486+ if v , found := sshClients [key ]; found {
487+ v .activeSessions --
488+ }
489+ }
0 commit comments