Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"frequenz-api-common >= 0.6.3, < 0.7.0",
"grpcio >= 1.66.2, < 2",
"frequenz-channels >= 1.0.0, < 2",
"frequenz-client-base >= 0.7.0, < 0.8.0",
"frequenz-client-base >= 0.8.0, < 0.9.0",
"frequenz-client-common >= 0.1.0, < 0.3.0",
"frequenz-api-electricity-trading >= 0.2.4, < 1",
"protobuf >= 5.28.0, < 6",
Expand Down
53 changes: 17 additions & 36 deletions src/frequenz/client/electricity_trading/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
import logging
from datetime import datetime, timezone
from decimal import Decimal, InvalidOperation
from typing import TYPE_CHECKING, Any, Awaitable, cast
from typing import Awaitable, cast

import grpc

# pylint: disable=no-member
from frequenz.api.electricity_trading.v1 import electricity_trading_pb2
from frequenz.api.electricity_trading.v1 import (
electricity_trading_pb2,
electricity_trading_pb2_grpc,
)
from frequenz.api.electricity_trading.v1.electricity_trading_pb2_grpc import (
ElectricityTradingServiceStub,
)
Expand All @@ -23,7 +26,6 @@
from frequenz.client.base.streaming import GrpcStreamBroadcaster
from frequenz.client.common.pagination import Params
from google.protobuf import field_mask_pb2, struct_pb2
from typing_extensions import override

from ._types import (
DeliveryArea,
Expand All @@ -45,12 +47,6 @@
UpdateOrder,
)

if TYPE_CHECKING:
from frequenz.api.electricity_trading.v1.electricity_trading_pb2_grpc import (
ElectricityTradingServiceAsyncStub,
)


_logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -91,7 +87,7 @@ def validate_decimal_places(value: Decimal, decimal_places: int, name: str) -> N
) from exc


class Client(BaseApiClient):
class Client(BaseApiClient[ElectricityTradingServiceStub]):
"""Electricity trading client."""

_instances: dict[tuple[str, str | None], "Client"] = {}
Expand Down Expand Up @@ -133,10 +129,11 @@ def __init__(
if not hasattr(
self, "_initialized"
): # Prevent re-initialization of existing instances
super().__init__(server_url, connect=connect)
self._stub: ElectricityTradingServiceAsyncStub | None = None
if connect:
self._create_stub()
super().__init__(
server_url,
connect=connect,
create_stub=ElectricityTradingServiceStub,
)
self._initialized = True

self._gridpool_orders_streams: dict[
Expand All @@ -162,28 +159,8 @@ def __init__(

self._metadata = (("key", auth_key),) if auth_key else ()

def _create_stub(self) -> None:
"""Create a new gRPC stub for the Electricity Trading service."""
stub: Any = ElectricityTradingServiceStub(self.channel)
self._stub = stub

@override
def connect(self, server_url: str | None = None) -> None:
"""Connect to the server, possibly using a new URL.

If the client is already connected and the URL is the same as the previous URL,
this method does nothing. If you want to force a reconnection, you can call
[disconnect()][frequenz.client.base.client.BaseApiClient.disconnect] first.

Args:
server_url: The URL of the server to connect to. If not provided, the
previously used URL is used.
"""
super().connect(server_url)
self._create_stub()

@property
def stub(self) -> ElectricityTradingServiceAsyncStub:
def stub(self) -> electricity_trading_pb2_grpc.ElectricityTradingServiceAsyncStub:
"""
Get the gRPC stub for the Electricity Trading service.

Expand All @@ -195,7 +172,11 @@ def stub(self) -> ElectricityTradingServiceAsyncStub:
"""
if self._stub is None:
raise ClientNotConnected(server_url=self.server_url, operation="stub")
return self._stub
# This type: ignore is needed because we need to cast the sync stub to
# the async stub, but we can't use cast because the async stub doesn't
# actually exists to the eyes of the interpreter, it only exists for the
# type-checker, so it can only be used for type hints.
return self._stub # type: ignore

async def stream_gridpool_orders(
# pylint: disable=too-many-arguments, too-many-positional-arguments
Expand Down
Loading