| 
3 | 3 | 
 
  | 
4 | 4 | """Management of configuration."""  | 
5 | 5 | 
 
  | 
 | 6 | +import asyncio  | 
6 | 7 | import logging  | 
7 | 8 | import pathlib  | 
8 | 9 | from collections.abc import Mapping, Sequence  | 
 | 10 | +from dataclasses import is_dataclass  | 
9 | 11 | from datetime import timedelta  | 
10 |  | -from typing import Any, Final  | 
 | 12 | +from typing import Any, Final, Literal, TypeGuard, overload  | 
11 | 13 | 
 
  | 
12 | 14 | import marshmallow  | 
13 |  | -from frequenz.channels import Broadcast, Receiver  | 
 | 15 | +from frequenz.channels import Broadcast, Receiver, ReceiverStoppedError  | 
14 | 16 | from frequenz.channels.experimental import WithPrevious  | 
15 | 17 | from marshmallow import Schema, ValidationError  | 
16 | 18 | from typing_extensions import override  | 
@@ -313,6 +315,99 @@ def new_receiver(  # pylint: disable=too-many-arguments  | 
313 | 315 |         return receiver  | 
314 | 316 | 
 
  | 
315 | 317 | 
 
  | 
 | 318 | +@overload  | 
 | 319 | +async def wait_for_first(  | 
 | 320 | +    receiver: Receiver[DataclassT | Exception | None],  | 
 | 321 | +    /,  | 
 | 322 | +    *,  | 
 | 323 | +    receiver_name: str | None = None,  | 
 | 324 | +    allow_none: Literal[False] = False,  | 
 | 325 | +    timeout: timedelta = timedelta(minutes=1),  | 
 | 326 | +) -> DataclassT: ...  | 
 | 327 | + | 
 | 328 | + | 
 | 329 | +@overload  | 
 | 330 | +async def wait_for_first(  | 
 | 331 | +    receiver: Receiver[DataclassT | Exception | None],  | 
 | 332 | +    /,  | 
 | 333 | +    *,  | 
 | 334 | +    receiver_name: str | None = None,  | 
 | 335 | +    allow_none: Literal[True] = True,  | 
 | 336 | +    timeout: timedelta = timedelta(minutes=1),  | 
 | 337 | +) -> DataclassT | None: ...  | 
 | 338 | + | 
 | 339 | + | 
 | 340 | +async def wait_for_first(  | 
 | 341 | +    receiver: Receiver[DataclassT | Exception | None],  | 
 | 342 | +    /,  | 
 | 343 | +    *,  | 
 | 344 | +    receiver_name: str | None = None,  | 
 | 345 | +    allow_none: bool = False,  | 
 | 346 | +    timeout: timedelta = timedelta(minutes=1),  | 
 | 347 | +) -> DataclassT | None:  | 
 | 348 | +    """Receive the first configuration.  | 
 | 349 | +
  | 
 | 350 | +    Args:  | 
 | 351 | +        receiver: The receiver to receive the first configuration from.  | 
 | 352 | +        receiver_name: The name of the receiver, used for logging. If `None`, the  | 
 | 353 | +            string representation of the receiver will be used.  | 
 | 354 | +        allow_none: Whether consider a `None` value as a valid configuration.  | 
 | 355 | +        timeout: The timeout in seconds to wait for the first configuration.  | 
 | 356 | +
  | 
 | 357 | +    Returns:  | 
 | 358 | +        The first configuration received.  | 
 | 359 | +
  | 
 | 360 | +    Raises:  | 
 | 361 | +        asyncio.TimeoutError: If the first configuration is not received within the  | 
 | 362 | +            timeout.  | 
 | 363 | +        ReceiverStoppedError: If the receiver is stopped before the first configuration  | 
 | 364 | +            is received.  | 
 | 365 | +    """  | 
 | 366 | +    if receiver_name is None:  | 
 | 367 | +        receiver_name = str(receiver)  | 
 | 368 | + | 
 | 369 | +    # We need this type guard because we can't use a TypeVar for isinstance checks or  | 
 | 370 | +    # match cases.  | 
 | 371 | +    def is_config_class(value: DataclassT | Exception | None) -> TypeGuard[DataclassT]:  | 
 | 372 | +        return is_dataclass(value) if value is not None else False  | 
 | 373 | + | 
 | 374 | +    _logger.info(  | 
 | 375 | +        "%s: Waiting %s seconds for the first configuration to arrive...",  | 
 | 376 | +        receiver_name,  | 
 | 377 | +        timeout.total_seconds(),  | 
 | 378 | +    )  | 
 | 379 | +    try:  | 
 | 380 | +        async with asyncio.timeout(timeout.total_seconds()):  | 
 | 381 | +            async for config in receiver:  | 
 | 382 | +                match config:  | 
 | 383 | +                    case None:  | 
 | 384 | +                        if allow_none:  | 
 | 385 | +                            return None  | 
 | 386 | +                        _logger.error(  | 
 | 387 | +                            "%s: Received empty configuration, waiting again for "  | 
 | 388 | +                            "a first configuration to be set.",  | 
 | 389 | +                            receiver_name,  | 
 | 390 | +                        )  | 
 | 391 | +                    case Exception() as error:  | 
 | 392 | +                        _logger.error(  | 
 | 393 | +                            "%s: Error while receiving the first configuration, "  | 
 | 394 | +                            "will keep waiting for an update: %s.",  | 
 | 395 | +                            receiver_name,  | 
 | 396 | +                            error,  | 
 | 397 | +                        )  | 
 | 398 | +                    case config if is_config_class(config):  | 
 | 399 | +                        _logger.info("%s: Received first configuration.", receiver_name)  | 
 | 400 | +                        return config  | 
 | 401 | +                    case unexpected:  | 
 | 402 | +                        assert (  | 
 | 403 | +                            False  | 
 | 404 | +                        ), f"{receiver_name}: Unexpected value received: {unexpected!r}."  | 
 | 405 | +    except asyncio.TimeoutError:  | 
 | 406 | +        _logger.error("%s: No configuration received in time.", receiver_name)  | 
 | 407 | +        raise  | 
 | 408 | +    raise ReceiverStoppedError(receiver)  | 
 | 409 | + | 
 | 410 | + | 
316 | 411 | def _not_equal_with_logging(  | 
317 | 412 |     *,  | 
318 | 413 |     key: str | Sequence[str],  | 
 | 
0 commit comments