1+ import asyncio
12import heapq
23from collections import Counter
34from contextlib import asynccontextmanager
@@ -19,20 +20,23 @@ class SFTPSoftChannelPool(BaseSFTPChannelPool):
1920
2021 def __init__ (self , * args , ** kwargs ):
2122 self ._channels = Counter ()
23+ self ._channels_lock = asyncio .Lock ()
2224 super ().__init__ (* args , ** kwargs )
2325
2426 @asynccontextmanager
2527 async def get (self ):
26- [(least_used_channel , num_connections )] = (
27- heapq .nsmallest (1 , self ._channels .items (), lambda kv : kv [1 ])
28- or self ._NO_CHANNELS
29- )
30-
28+ least_used_channel , num_connections = self ._least_used ()
3129 if least_used_channel is None or num_connections >= self ._THRESHOLD :
32- channel = await self ._maybe_new_channel ()
33- if channel is not None :
34- least_used_channel = channel
35- num_connections = 0
30+ async with self ._channels_lock :
31+ channel = await self ._maybe_new_channel ()
32+ if channel is not None :
33+ least_used_channel = channel
34+ num_connections = 0
35+ self ._channels [least_used_channel ] = 0
36+
37+ if channel is None :
38+ # another coroutine may have opened a channel while we waited
39+ least_used_channel , num_connections = self ._least_used ()
3640
3741 if least_used_channel is None :
3842 raise ValueError ("Can't create any SFTP connections!" )
@@ -46,6 +50,13 @@ async def get(self):
4650 async def _cleanup (self ):
4751 self ._channels .clear ()
4852
53+ def _least_used (self ):
54+ [(least_used_channel , num_connections )] = (
55+ heapq .nsmallest (1 , self ._channels .items (), lambda kv : kv [1 ])
56+ or self ._NO_CHANNELS
57+ )
58+ return least_used_channel , num_connections
59+
4960 @property
5061 def active_channels (self ):
5162 return len (self ._channels )
0 commit comments