Skip to content
106 changes: 59 additions & 47 deletions python-avd/pyavd/_cv/client/async_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,59 +121,64 @@
max_retries (int): Maximum number of retry attempts for Status.UNAVAILABLE. Total attempts = 1 + max_retries.
initial_delay (int): Initial delay in seconds before the first retry.
factor (int): Multiplier for the delay in subsequent retries.
list_field (str): Name of the parameter to be split if Status.RESOURCE_EXHAUSTED is received.
iter_field (str): Name of the parameter to be split if Status.RESOURCE_EXHAUSTED is received.
min_items_for_splitting_attempt (int): Minimum length of the item that we'll still try to split.
"""

max_retries: int
initial_delay: int
factor: int
list_field: str | None
iter_field: str | None
min_items_for_splitting_attempt: int
func: Callable
func_signature: Signature
bound_arguments: BoundArguments
current_arguments_dict: dict
permitted_field_types: ClassVar[list[type]] = [list, set, tuple]

def __init__(
self,
max_retries: int = 5,
initial_delay: int = 1,
factor: int = 2,
list_field: str | None = None,
iter_field: str | None = None,
min_items_for_splitting_attempt: int = 2,
) -> None:
self.max_retries = max_retries
self.initial_delay = initial_delay
self.factor = factor
self.list_field = list_field
self.iter_field = iter_field
self.min_items_for_splitting_attempt = max(2, min_items_for_splitting_attempt)

def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
self.func = func
self.func_signature = signature(func)

if self.list_field:
if not (return_annotation := self._is_list_annotation(self.func_signature.return_annotation, strict=True))[0]:
if self.iter_field:
if not (return_annotation := self._is_sized_iterable_annotation(self.func_signature.return_annotation, strict=True, allowed_types=[list]))[0]:
msg = (
f"GRPCRequestHandler decorator is unable to bind to the function '{func.__name__}' with the 'list_field' argument. "
f"GRPCRequestHandler decorator is unable to bind to the function '{func.__name__}' with the 'iter_field' argument. "
f"Expected a return type of 'list'. Got '{return_annotation[1]}'."
)
raise TypeError(msg)

# Verify that `self.list_field` is listed in parameters of the decorated function
if self.list_field not in (func_parameters := self.func_signature.parameters.keys()):
# Verify that `self.iter_field` is listed in parameters of the decorated function
if self.iter_field not in (func_parameters := self.func_signature.parameters.keys()):
msg = (
f"{self.__class__.__name__} decorator is unable to find the list_field '{self.list_field}' "
f"{self.__class__.__name__} decorator is unable to find the iter_field '{self.iter_field}' "
f"in the given arguments to '{self.func.__name__}'. Found: '{list(func_parameters)}'."
)
raise KeyError(msg)

# Verify that annotation of `self.list_field` is a `list` (or a `UnionType` with `list` being one of the arguments)
if not (list_field_annotation := self._is_list_annotation(self.func_signature.parameters[self.list_field].annotation))[0]:
# Verify that annotation of `self.iter_field` is `list`, `set`, `tuple` or a `UnionType` with `list`, `set` or `tuple` being one of the arguments
if not (
iter_field_annotation := self._is_sized_iterable_annotation(
self.func_signature.parameters[self.iter_field].annotation, allowed_types=self.permitted_field_types
)
)[0]:
msg = (
f"{self.__class__.__name__} decorator expected the type of the list_field '{self.list_field}' in function '{self.func.__name__}' "
f"to be defined as a list. Got '{list_field_annotation[1]}' (type '{type(list_field_annotation[1])}')."
f"{self.__class__.__name__} decorator expected the type of the iter_field '{self.iter_field}' in function '{self.func.__name__}' "
f"to be defined as a list, set or tuple. Got '{iter_field_annotation[1]}' (type '{type(iter_field_annotation[1])}')."
)
raise TypeError(msg)

Expand All @@ -184,22 +189,23 @@
return wrapper

@staticmethod
def _is_list_annotation(annotation: Any, strict: bool = False) -> tuple[bool, Any]:
def _is_sized_iterable_annotation(annotation: Any, allowed_types: list[type], strict: bool = False) -> tuple[bool, Any]:
"""
Check if provided annotation is a `list`.
Check if provided annotation matches any type specified in `allowed_types`.

Default `strict: False` will also match 'types.UnionType' with included `list`.
Default `strict: False` will also match 'allowed_types.UnionType'.
"""
_string_based_annotation = (
list
if (
(isinstance(annotation, str) and annotation.startswith("list"))
or (not strict and get_origin(annotation) is UnionType and any(get_origin(arg) is list for arg in get_args(annotation)))
)
else annotation
)

return _string_based_annotation is list or get_origin(annotation) is list, _string_based_annotation
_string_based_annotation = None
for permitted_type in allowed_types:
if (isinstance(annotation, str) and annotation.startswith(permitted_type.__name__)) or (
not strict and get_origin(annotation) is UnionType and any(get_origin(arg) in allowed_types for arg in get_args(annotation))
):
_string_based_annotation = permitted_type
break
if _string_based_annotation is None:
_string_based_annotation = annotation

return _string_based_annotation in allowed_types or get_origin(annotation) in allowed_types, _string_based_annotation

async def _execute_single_call_with_retries(self, call_args: tuple, call_kwargs: dict) -> None:
"""Executes a single call to self.func with retry logic for gRPC UNAVAILABLE."""
Expand Down Expand Up @@ -257,27 +263,27 @@
# Required by ruff
return None

async def _execute_with_splitting(self, original_call_args: tuple, original_call_kwargs: dict) -> Any:

Check failure on line 266 in python-avd/pyavd/_cv/client/async_decorators.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Refactor this function to reduce its Cognitive Complexity from 17 to the 15 allowed.

See more on https://sonarcloud.io/project/issues?id=aristanetworks_avd&issues=AZrEOZboZNLDOeseEG4m&open=AZrEOZboZNLDOeseEG4m&pullRequest=6196
func_name = self.func.__name__

if not (self.list_field and self.func_signature):
# No list_field configured for splitting, execute the call directly (with retries)
if not (self.iter_field and self.func_signature):
# No iter_field configured for splitting, execute the call directly (with retries)
return await self._execute_single_call_with_retries(original_call_args, original_call_kwargs)

bound_arguments = self.func_signature.bind(*original_call_args, **original_call_kwargs)
current_arguments_dict = bound_arguments.arguments

list_value: list = current_arguments_dict.get(self.list_field, [])
if not isinstance(list_value, list):
iter_value: list | set | tuple = current_arguments_dict.get(self.iter_field, [])
if not any(isinstance(iter_value, allowed_type) for allowed_type in self.permitted_field_types) or not iter_value:
msg = (
f"{self.__class__.__name__} decorator expected the value of the list_field '{self.list_field}' for function '{func_name}' "
f"to be a list. Got '{type(list_value)}'."
f"{self.__class__.__name__} decorator expected the value of the iter_field '{self.iter_field}' for function '{func_name}' "
f"to be a non-empty list, set or tuple. Got '{iter_value}' of a type '{type(iter_value)}'."
)
raise TypeError(msg)

LOGGER.debug("%s: Preparing call for '%s' for list_field '%s' with %s item(s).", self.__class__.__name__, func_name, self.list_field, len(list_value))
LOGGER.debug("%s: Preparing call for '%s' for iter_field '%s' with %s item(s).", self.__class__.__name__, func_name, self.iter_field, len(iter_value))

if len(list_value) < self.min_items_for_splitting_attempt:
if len(iter_value) < self.min_items_for_splitting_attempt:
# No need to try/except if we cannot split the list.
return await self._execute_single_call_with_retries(original_call_args, original_call_kwargs)

Expand All @@ -288,44 +294,50 @@
# At minimum try to split in two.
# The double negatives make // round up instead of down.
ratio = max(2, -(-e.size // e.max_size))
chunk_size = len(list_value) // ratio
chunk_size = len(iter_value) // ratio
LOGGER.info(
"%s: Message size %s exceeded the max of %s for '%s' on list_field '%s'. Attempting to split %s items.",
"%s: Message size %s exceeded the max of %s for '%s' on iter_field '%s'. Attempting to split %s items.",
self.__class__.__name__,
e.size,
e.max_size,
func_name,
self.list_field,
len(list_value),
self.iter_field,
len(iter_value),
)
# Use case where ratio is too high leading to the chuck_size being calculated as zero
if chunk_size == 0 and len(list_value) > 0:
if chunk_size == 0 and len(iter_value) > 0:
chunk_size = 1

planned_attempts_qty = int((len(list_value) / chunk_size) + (1 if len(list_value) % chunk_size else 0))
planned_attempts_qty = int((len(iter_value) / chunk_size) + (1 if len(iter_value) % chunk_size else 0))

LOGGER.info(
"%s: Splitting list_field '%s' for '%s' into %s calls with up to %s items each.",
"%s: Splitting iter_field '%s' for '%s' into %s calls with up to %s items each.",
self.__class__.__name__,
self.list_field,
self.iter_field,
func_name,
planned_attempts_qty,
chunk_size,
)

# For every chunk we call ourselves recursively, so we can catch any further needs of splitting.
aggregated_results = []
for chunk_id, chunk in enumerate(batch(list_value, chunk_size)):
# Identify type of the iterable
iter_type = list
if isinstance(iter_value, set):
iter_type = set
if isinstance(iter_value, tuple):
iter_type = tuple
for chunk_id, chunk in enumerate(batch(iter_value, chunk_size, iter_type)):
LOGGER.info(
"%s: Processing chunk %s/%s for '%s' with %s item(s) from list_field '%s'.",
"%s: Processing chunk %s/%s for '%s' with %s item(s) from iter_field '%s'.",
self.__class__.__name__,
chunk_id + 1,
planned_attempts_qty,
func_name,
len(chunk),
self.list_field,
self.iter_field,
)
current_arguments_dict[self.list_field] = chunk
current_arguments_dict[self.iter_field] = chunk

aggregated_results.extend(await self._execute_with_splitting(bound_arguments.args, bound_arguments.kwargs))

Expand Down
8 changes: 4 additions & 4 deletions python-avd/pyavd/_cv/client/configlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ async def set_configlet_container(
return response.value

@LimitCvVersion(min_ver="2024.2.0")
@GRPCRequestHandler(list_field="containers")
@GRPCRequestHandler(iter_field="containers")
async def set_configlet_containers(
self: CVClientProtocol,
workspace_id: str,
Expand Down Expand Up @@ -248,7 +248,7 @@ async def delete_configlet_container(

return response.value

@GRPCRequestHandler(list_field="configlet_ids")
@GRPCRequestHandler(iter_field="configlet_ids")
async def get_configlets(
self: CVClientProtocol,
workspace_id: str,
Expand Down Expand Up @@ -358,7 +358,7 @@ async def set_configlet_from_file(
return response.value

@LimitCvVersion(min_ver="2024.2.0")
@GRPCRequestHandler(list_field="configlets")
@GRPCRequestHandler(iter_field="configlets")
async def set_configlets_from_files(
self: CVClientProtocol,
workspace_id: str,
Expand Down Expand Up @@ -432,7 +432,7 @@ async def set_configlets_from_files( # noqa: F811 - Redefining with decorator.
LOGGER.info("set_configlets_from_files: Batch %s", index)
configlet_configs.extend(await gather(*batch_coroutines))

@GRPCRequestHandler(list_field="configlet_ids")
@GRPCRequestHandler(iter_field="configlet_ids")
async def delete_configlets(
self: CVClientProtocol,
workspace_id: str,
Expand Down
2 changes: 1 addition & 1 deletion python-avd/pyavd/_cv/client/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ async def get_topology_studio_inputs(
)
return topology_inputs

@GRPCRequestHandler()
@GRPCRequestHandler(iter_field="device_inputs")
async def set_topology_studio_inputs(
self: CVClientProtocol,
workspace_id: str,
Expand Down
6 changes: 3 additions & 3 deletions python-avd/pyavd/_cv/client/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def get_tags(

return tags

@GRPCRequestHandler()
@GRPCRequestHandler(iter_field="tags")
async def set_tags(
self: CVClientProtocol,
workspace_id: str,
Expand Down Expand Up @@ -218,7 +218,7 @@ async def get_tag_assignments(

return tag_assignments

@GRPCRequestHandler()
@GRPCRequestHandler(iter_field="tag_assignments")
async def set_tag_assignments(
self: CVClientProtocol,
workspace_id: str,
Expand Down Expand Up @@ -256,7 +256,7 @@ async def set_tag_assignments(

return [response.key async for response in responses]

@GRPCRequestHandler()
@GRPCRequestHandler(iter_field="tag_assignments")
async def delete_tag_assignments(
self: CVClientProtocol,
workspace_id: str,
Expand Down
12 changes: 8 additions & 4 deletions python-avd/pyavd/_utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
from __future__ import annotations

from itertools import islice
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
from collections.abc import Generator, Iterable


def batch(iterable: Iterable, size: int) -> Generator[Iterable]:
"""Returns a Generator of lists containing 'size' items. The final list may be shorter."""
T = TypeVar("T")
Batch_Type = type[list] | type[set] | type[tuple]


def batch(iterable: Iterable[T], size: int, batch_type: Batch_Type = list) -> Generator[list[T] | set[T] | tuple[T]]:
"""Returns a Generator of lists, sets or tuples containing 'size' items. The final yielded iterator may be shorter."""
iterator = iter(iterable)
while batch := list(islice(iterator, size)):
yield batch
yield batch_type(batch)
Loading