77	"fmt" 
88	"io" 
99	"reflect" 
10+ 	"sync" 
1011
1112	sftpClient "github.com/pkg/sftp" 
1213	"golang.org/x/crypto/ssh" 
@@ -25,9 +26,30 @@ const (
2526
2627// Sftp is a binding for file operations on sftp server. 
2728type  Sftp  struct  {
28- 	metadata    * sftpMetadata 
29- 	logger      logger.Logger 
30- 	sftpClient  * sftpClient.Client 
29+ 	metadata      * sftpMetadata 
30+ 	logger        logger.Logger 
31+ 	sftpClient    * sftpClient.Client 
32+ 	sshClient     * ssh.Client 
33+ 	clientConfig  * ssh.ClientConfig 
34+ 	lock          sync.RWMutex 
35+ }
36+ 
37+ func  (sftp  * Sftp ) Client () (* sftpClient.Client , error ) {
38+ 	sftp .lock .RLock ()
39+ 	current  :=  sftp .sftpClient 
40+ 	sftp .lock .RUnlock ()
41+ 
42+ 	if  current  !=  nil  {
43+ 		if  _ , err  :=  current .Getwd (); err  ==  nil  {
44+ 			return  current , nil 
45+ 		}
46+ 	}
47+ 
48+ 	err  :=  sftp .handleReconnection ()
49+ 	if  err  !=  nil  {
50+ 		return  nil , err 
51+ 	}
52+ 	return  sftp .sftpClient , nil 
3153}
3254
3355// sftpMetadata defines the sftp metadata. 
@@ -122,9 +144,11 @@ func (sftp *Sftp) Init(_ context.Context, metadata bindings.Metadata) error {
122144
123145	newSftpClient , err  :=  sftpClient .NewClient (sshClient )
124146	if  err  !=  nil  {
147+ 		_  =  sshClient .Close ()
125148		return  fmt .Errorf ("sftp binding error: error create sftp client: %w" , err )
126149	}
127150
151+ 	sftp .clientConfig  =  config 
128152	sftp .metadata  =  m 
129153	sftp .sftpClient  =  newSftpClient 
130154
@@ -163,12 +187,17 @@ func (sftp *Sftp) create(_ context.Context, req *bindings.InvokeRequest) (*bindi
163187
164188	dir , fileName  :=  sftpClient .Split (path )
165189
166- 	err  =  sftp .sftpClient .MkdirAll (dir )
190+ 	c , err  :=  sftp .Client ()
191+ 	if  err  !=  nil  {
192+ 		return  nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
193+ 	}
194+ 
195+ 	err  =  c .MkdirAll (dir )
167196	if  err  !=  nil  {
168197		return  nil , fmt .Errorf ("sftp binding error: error create dir %s: %w" , dir , err )
169198	}
170199
171- 	file , err  :=  sftp . sftpClient .Create (path )
200+ 	file , err  :=  c .Create (path )
172201	if  err  !=  nil  {
173202		return  nil , fmt .Errorf ("sftp binding error: error create file %s: %w" , path , err )
174203	}
@@ -211,7 +240,12 @@ func (sftp *Sftp) list(_ context.Context, req *bindings.InvokeRequest) (*binding
211240		return  nil , fmt .Errorf ("sftp binding error: %w" , err )
212241	}
213242
214- 	files , err  :=  sftp .sftpClient .ReadDir (path )
243+ 	c , err  :=  sftp .Client ()
244+ 	if  err  !=  nil  {
245+ 		return  nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
246+ 	}
247+ 
248+ 	files , err  :=  c .ReadDir (path )
215249	if  err  !=  nil  {
216250		return  nil , fmt .Errorf ("sftp binding error: error read dir %s: %w" , path , err )
217251	}
@@ -246,7 +280,12 @@ func (sftp *Sftp) get(_ context.Context, req *bindings.InvokeRequest) (*bindings
246280		return  nil , fmt .Errorf ("sftp binding error: %w" , err )
247281	}
248282
249- 	file , err  :=  sftp .sftpClient .Open (path )
283+ 	c , err  :=  sftp .Client ()
284+ 	if  err  !=  nil  {
285+ 		return  nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
286+ 	}
287+ 
288+ 	file , err  :=  c .Open (path )
250289	if  err  !=  nil  {
251290		return  nil , fmt .Errorf ("sftp binding error: error open file %s: %w" , path , err )
252291	}
@@ -272,7 +311,11 @@ func (sftp *Sftp) delete(_ context.Context, req *bindings.InvokeRequest) (*bindi
272311		return  nil , fmt .Errorf ("sftp binding error: %w" , err )
273312	}
274313
275- 	err  =  sftp .sftpClient .Remove (path )
314+ 	c , err  :=  sftp .Client ()
315+ 	if  err  !=  nil  {
316+ 		return  nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
317+ 	}
318+ 	err  =  c .Remove (path )
276319	if  err  !=  nil  {
277320		return  nil , fmt .Errorf ("sftp binding error: error remove file %s: %w" , path , err )
278321	}
@@ -296,6 +339,8 @@ func (sftp *Sftp) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
296339}
297340
298341func  (sftp  * Sftp ) Close () error  {
342+ 	sftp .lock .Lock ()
343+ 	defer  sftp .lock .Unlock ()
299344	return  sftp .sftpClient .Close ()
300345}
301346
@@ -330,3 +375,41 @@ func (sftp *Sftp) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
330375	metadata .GetMetadataInfoFromStructType (reflect .TypeOf (metadataStruct ), & metadataInfo , metadata .BindingType )
331376	return 
332377}
378+ 
379+ func  (sftp  * Sftp ) handleReconnection () error  {
380+ 	sftp .lock .Lock ()
381+ 	defer  sftp .lock .Unlock ()
382+ 
383+ 	// Re-check after acquiring the write lock 
384+ 	if  sftp .sftpClient  !=  nil  {
385+ 		if  _ , err  :=  sftp .sftpClient .Getwd (); err  ==  nil  {
386+ 			return  nil 
387+ 		}
388+ 		_  =  sftp .sftpClient .Close ()
389+ 		sftp .sftpClient  =  nil 
390+ 	}
391+ 	if  sftp .sshClient  !=  nil  {
392+ 		_  =  sftp .sshClient .Close ()
393+ 		sftp .sshClient  =  nil 
394+ 	}
395+ 
396+ 	if  sftp .metadata  ==  nil  ||  sftp .clientConfig  ==  nil  {
397+ 		return  fmt .Errorf ("sftp binding error: client not initialized" )
398+ 	}
399+ 
400+ 	sshClient , err  :=  ssh .Dial ("tcp" , sftp .metadata .Address , sftp .clientConfig )
401+ 	if  err  !=  nil  {
402+ 		return  fmt .Errorf ("sftp binding error: error create ssh client: %w" , err )
403+ 	}
404+ 
405+ 	newSftpClient , err  :=  sftpClient .NewClient (sshClient )
406+ 	if  err  !=  nil  {
407+ 		_  =  sshClient .Close ()
408+ 		return  fmt .Errorf ("sftp binding error: error create sftp client: %w" , err )
409+ 	}
410+ 
411+ 	sftp .sshClient  =  sshClient 
412+ 	sftp .sftpClient  =  newSftpClient 
413+ 
414+ 	return  nil 
415+ }
0 commit comments