@@ -109,9 +109,14 @@ async def initialize(self, **kwargs) -> TNode:
109109 self ._current_exception = None
110110 self .api_key = None
111111 self ._stdin = self ._stdout = self ._long_proc = None
112- self ._retry = True # set to False if authentication fails
112+ self ._max_retries_on_auth_fail = (kwargs .get ('retries_on_auth_fail' )
113+ or 0 ) + 1
114+ self ._retry = self ._max_retries_on_auth_fail
113115 self ._discovery_lock = asyncio .Lock ()
114116 self ._cmd_sem = kwargs .get ('cmd_sem' , None )
117+ self ._cmd_mutex = kwargs .get ('cmd_mutex' , None )
118+ self ._cmd_pacer_sleep = kwargs .get ('cmd_pacer_sleep' , None )
119+ self .per_cmd_auth = kwargs .get ('per_cmd_auth' , True )
115120
116121 self .address = kwargs ["address" ]
117122 self .hostname = kwargs ["address" ] # default till we get hostname
@@ -242,31 +247,33 @@ def is_connected(self):
242247 return self ._conn is not None
243248
244249 @asynccontextmanager
245- async def limit_pipeline (self ):
250+ async def cmd_pacer (self , use_sem : bool = True ):
246251 '''Context Manager to implement throttling of commands.
247252
248253 In many networks, backend authentication servers such as TACACS which
249254 handle authentication of logins and even command execution, cannot
250255 large volumes of authentication requests. Thanks to our use of
251256 asyncio, we can easily sends hundreds of connection requests to such
252257 servers, which effectively turns into authentication failures. To
253- handle this, we add a user-specified maximum of simultaneous
254- commands/logins at any given time. This code implements that locking
255- context.
258+ handle this, we add a user-specified maximum of rate of cmds/sec
259+ that the authentication can handle, and we pace it out. This code
260+ implements that pacer.
261+
262+ Some networks communicate with a backend authentication server only
263+ on login while others contact it for authorization of a command as
264+ well. Its to handle this difference that we pass use_sem. Users set
265+ the per_cmd_auth to True if authorization is used. The caller of this
266+ function sets the use_sem apppropriately depending on when the context
267+ is invoked.
268+
269+ Args:
270+ use_sem(bool): True if you want to use the pacer
256271 '''
257- if self ._cmd_sem :
258- self .logger .debug (
259- f'{ self .transport } ://{ self .hostname } :{ self .port } : Get lock' )
260- await self ._cmd_sem .acquire ()
261- self .logger .debug (
262- f'{ self .transport } ://{ self .hostname } :{ self .port } : Got lock' )
263- try :
272+ if self ._cmd_sem and use_sem :
273+ async with self ._cmd_sem :
274+ async with self ._cmd_mutex :
275+ await asyncio .sleep (self ._cmd_pacer_sleep )
264276 yield
265- finally :
266- self .logger .debug (
267- f'{ self .transport } ://{ self .hostname } :{ self .port } : '
268- 'Free lock' )
269- self ._cmd_sem .release ()
270277 else :
271278 yield
272279
@@ -643,7 +650,7 @@ async def _init_ssh(self, init_dev_data=True, use_lock=True) -> None:
643650 self .ssh_ready .release ()
644651 return
645652
646- async with self .limit_pipeline ():
653+ async with self .cmd_pacer ():
647654 try :
648655 if self ._tunnel :
649656 self ._conn = await self ._tunnel .connect_ssh (
@@ -659,6 +666,8 @@ async def _init_ssh(self, init_dev_data=True, use_lock=True) -> None:
659666 self .logger .info (
660667 f"Connected to { self .address } :{ self .port } at "
661668 f"{ time .time ()} " )
669+ # Reset authentication fail attempt on success
670+ self ._retry = self ._max_retries_on_auth_fail
662671 except Exception as e : # pylint: disable=broad-except
663672 if isinstance (e , asyncssh .HostKeyNotVerifiable ):
664673 self .logger .error (
@@ -672,7 +681,7 @@ async def _init_ssh(self, init_dev_data=True, use_lock=True) -> None:
672681 f'Authentication failed to { self .address } . '
673682 'Not retrying to avoid locking out user. Please '
674683 'restart poller with proper authentication' )
675- self ._retry = False
684+ self ._retry -= 1
676685 else :
677686 self .logger .error ('Unable to connect to '
678687 f'{ self .address } :{ self .port } , { e } ' )
@@ -790,7 +799,7 @@ async def _ssh_gather(self, service_callback: Callable,
790799 cb_token .node_token = self .bootupTimestamp
791800
792801 timeout = timeout or self .cmd_timeout
793- async with self .limit_pipeline ( ):
802+ async with self .cmd_pacer ( self . per_cmd_auth ):
794803 for cmd in cmd_list :
795804 try :
796805 output = await asyncio .wait_for (self ._conn .run (cmd ),
@@ -1162,7 +1171,7 @@ async def _rest_gather(self, service_callback, cmd_list, cb_token,
11621171 output = []
11631172 status = 200 # status OK
11641173
1165- async with self .limit_pipeline ( ):
1174+ async with self .cmd_pacer ( self . per_cmd_auth ):
11661175 try :
11671176 async with aiohttp .ClientSession (
11681177 auth = auth , conn_timeout = self .connect_timeout ,
@@ -1309,7 +1318,7 @@ async def _init_rest(self):
13091318 url = "https://{0}:{1}/nclu/v1/rpc" .format (self .address , self .port )
13101319 headers = {"Content-Type" : "application/json" }
13111320
1312- async with self .limit_pipeline ( ):
1321+ async with self .cmd_pacer ( self . per_cmd_auth ):
13131322 try :
13141323 async with aiohttp .ClientSession (
13151324 auth = auth , timeout = self .cmd_timeout ,
@@ -1334,7 +1343,7 @@ async def _rest_gather(self, service_callback, cmd_list, cb_token,
13341343 url = "https://{0}:{1}/nclu/v1/rpc" .format (self .address , self .port )
13351344 headers = {"Content-Type" : "application/json" }
13361345
1337- async with self .limit_pipeline ( ):
1346+ async with self .cmd_pacer ( self . per_cmd_auth ):
13381347 try :
13391348 async with aiohttp .ClientSession (
13401349 auth = auth ,
@@ -1524,7 +1533,7 @@ async def _init_ssh(self, init_dev_data=True,
15241533 if self .is_connected and not self ._stdin :
15251534 self .logger .info (
15261535 f'Trying to create Persistent SSH for { self .hostname } ' )
1527- async with self .limit_pipeline ( ):
1536+ async with self .cmd_pacer ( self . per_cmd_auth ):
15281537 try :
15291538 self ._stdin , self ._stdout , self ._stderr = \
15301539 await self ._conn .open_session (term_type = 'xterm' )
@@ -1537,11 +1546,15 @@ async def _init_ssh(self, init_dev_data=True,
15371546 await self ._close_connection ()
15381547 self ._conn = None
15391548 self ._stdin = None
1540- self ._retry = False # No retry if escalation fails
1549+ self ._retry -= 1
15411550 if use_lock :
15421551 self .ssh_ready .release ()
15431552 return
1553+ # Reset number of retries on successful auth
1554+ self ._retry = self ._max_retries_on_auth_fail
15441555 except Exception as e :
1556+ if isinstance (e , asyncssh .misc .PermissionDenied ):
1557+ self ._retry -= 1
15451558 self .current_exception = e
15461559 self .logger .error ('Unable to create persistent SSH session'
15471560 f' for { self .hostname } due to { str (e )} ' )
@@ -1657,7 +1670,7 @@ async def _ssh_gather(self, service_callback, cmd_list, cb_token, oformat,
16571670 return
16581671
16591672 timeout = timeout or self .cmd_timeout
1660- async with self .limit_pipeline ( ):
1673+ async with self .cmd_pacer ( self . per_cmd_auth ):
16611674 for cmd in cmd_list :
16621675 try :
16631676 if self .slow_host :
@@ -1885,7 +1898,7 @@ async def _fetch_init_dev_data(self):
18851898 try :
18861899 res = []
18871900 # temporary hack to detect device info using ssh
1888- async with self .limit_pipeline ():
1901+ async with self .cmd_pacer ():
18891902 async with asyncssh .connect (
18901903 self .address , port = 22 , username = self .username ,
18911904 password = self .password , known_hosts = None ) as conn :
@@ -1922,7 +1935,7 @@ async def get_api_key(self):
19221935 url = f"https://{ self .address } :{ self .port } /api/?type=keygen&user=" \
19231936 f"{ self .username } &password={ self .password } "
19241937
1925- async with self .limit_pipeline ( ):
1938+ async with self .cmd_pacer ( self . per_cmd_auth ):
19261939 async with self ._session .get (url , timeout = self .connect_timeout ) \
19271940 as response :
19281941 status , xml = response .status , await response .text ()
@@ -1974,7 +1987,7 @@ def _extract_nos_version(self, data: str) -> None:
19741987 async def _init_rest (self ):
19751988 # In case of PANOS, getting here means REST is up
19761989 if not self ._session :
1977- async with self .limit_pipeline ( ):
1990+ async with self .cmd_pacer ( self . per_cmd_auth ):
19781991 try :
19791992 self ._session = aiohttp .ClientSession (
19801993 conn_timeout = self .connect_timeout ,
@@ -2010,7 +2023,7 @@ async def _rest_gather(self, service_callback, cmd_list, cb_token,
20102023 await service_callback (result , cb_token )
20112024 return
20122025
2013- async with self .limit_pipeline ( ):
2026+ async with self .cmd_pacer ( self . per_cmd_auth ):
20142027 try :
20152028 for cmd in cmd_list :
20162029 url_cmd = f"{ url } ?type=op&cmd={ cmd } &key={ self .api_key } "
0 commit comments