Skip to content

Commit 5ad9db1

Browse files
test: add regression tests
Signed-off-by: Matthias Wende <[email protected]>
1 parent 58a4a86 commit 5ad9db1

File tree

1 file changed

+246
-0
lines changed

1 file changed

+246
-0
lines changed

tests/test_order_book.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
# License: MIT
2+
# Copyright © 2025 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for the receive_public_order_book rpc."""
5+
6+
import asyncio
7+
import datetime as dt
8+
from collections.abc import AsyncIterator
9+
from datetime import datetime
10+
from typing import Any
11+
12+
import grpc
13+
import pytest
14+
from frequenz.api.common.v1.grid import delivery_area_pb2, delivery_duration_pb2
15+
from frequenz.api.common.v1.market import power_pb2, price_pb2
16+
from frequenz.api.common.v1.types import decimal_pb2
17+
from frequenz.api.electricity_trading.v1 import (
18+
electricity_trading_pb2,
19+
electricity_trading_pb2_grpc,
20+
)
21+
from google.protobuf import timestamp_pb2
22+
from grpc.aio import ServicerContext
23+
24+
from frequenz.client.electricity_trading import (
25+
Client,
26+
DeliveryArea,
27+
EnergyMarketCodeType,
28+
PublicOrder,
29+
)
30+
31+
START_TIME = datetime.fromisoformat("2023-01-01T12:00:00+00:00")
32+
START_TIME_PB = timestamp_pb2.Timestamp(seconds=1672574400)
33+
CREATE_TIME = datetime.fromisoformat("2023-01-01T12:00:00+00:00")
34+
CREATE_TIME_PB = timestamp_pb2.Timestamp(seconds=1672574400)
35+
MODIFICATION_TIME = datetime.fromisoformat("2023-01-01T12:00:00+00:00")
36+
MODIFICATION_TIME_PB = timestamp_pb2.Timestamp(seconds=1672574400)
37+
38+
39+
class MockElectricityTradingService(
40+
electricity_trading_pb2_grpc.ElectricityTradingServiceServicer
41+
):
42+
"""A mock gRPC service to simulate historic vs real-time streams."""
43+
44+
@staticmethod
45+
def _construct_public_order_book_record(
46+
order_id: int,
47+
) -> electricity_trading_pb2.PublicOrderBookRecord:
48+
return electricity_trading_pb2.PublicOrderBookRecord(
49+
id=order_id,
50+
delivery_area=delivery_area_pb2.DeliveryArea(
51+
code="XYZ",
52+
code_type=delivery_area_pb2.EnergyMarketCodeType.ENERGY_MARKET_CODE_TYPE_EUROPE_EIC,
53+
),
54+
delivery_period=delivery_duration_pb2.DeliveryPeriod(
55+
start=START_TIME_PB,
56+
duration=delivery_duration_pb2.DeliveryDuration.DELIVERY_DURATION_15,
57+
),
58+
type=electricity_trading_pb2.OrderType.ORDER_TYPE_LIMIT,
59+
side=electricity_trading_pb2.MarketSide.MARKET_SIDE_BUY,
60+
price=price_pb2.Price(
61+
amount=decimal_pb2.Decimal(value="100.00"),
62+
currency=price_pb2.Price.Currency.CURRENCY_EUR,
63+
),
64+
quantity=power_pb2.Power(mw=decimal_pb2.Decimal(value="5.00")),
65+
execution_option=(
66+
electricity_trading_pb2.OrderExecutionOption.ORDER_EXECUTION_OPTION_AON
67+
),
68+
create_time=CREATE_TIME_PB,
69+
update_time=MODIFICATION_TIME_PB,
70+
)
71+
72+
def _mark_as_unimplemented(self, context: Any) -> None:
73+
"""Set the context to UNIMPLEMENTED."""
74+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
75+
context.set_details("Method not implemented in the mock servicer.")
76+
77+
async def ReceivePublicOrderBookStream( # pylint: disable=invalid-overridden-method
78+
self,
79+
request: electricity_trading_pb2.ReceivePublicOrderBookStreamRequest,
80+
context: ServicerContext[
81+
electricity_trading_pb2.ReceivePublicOrderBookStreamRequest,
82+
electricity_trading_pb2.ReceivePublicOrderBookStreamResponse,
83+
],
84+
) -> AsyncIterator[electricity_trading_pb2.ReceivePublicOrderBookStreamResponse]:
85+
"""Send different data based on whether start_time is set."""
86+
is_historic = request.HasField("start_time") and request.start_time.seconds > 0
87+
88+
if is_historic:
89+
yield electricity_trading_pb2.ReceivePublicOrderBookStreamResponse(
90+
public_order_book_records=[
91+
self._construct_public_order_book_record(1),
92+
self._construct_public_order_book_record(2),
93+
]
94+
)
95+
return
96+
97+
yield electricity_trading_pb2.ReceivePublicOrderBookStreamResponse(
98+
public_order_book_records=[
99+
self._construct_public_order_book_record(9),
100+
]
101+
)
102+
await asyncio.sleep(5)
103+
104+
# --- Placeholder implementations for ALL other abstract methods ---
105+
106+
async def CancelAllGridpoolOrders( # pylint: disable=invalid-overridden-method
107+
self, request: Any, context: Any
108+
) -> electricity_trading_pb2.CancelAllGridpoolOrdersResponse:
109+
"""Handle CancelAllGridpoolOrders request."""
110+
self._mark_as_unimplemented(context)
111+
return electricity_trading_pb2.CancelAllGridpoolOrdersResponse()
112+
113+
async def CancelGridpoolOrder( # pylint: disable=invalid-overridden-method
114+
self, request: Any, context: Any
115+
) -> electricity_trading_pb2.CancelGridpoolOrderResponse:
116+
"""Handle CancelGridpoolOrder request."""
117+
self._mark_as_unimplemented(context)
118+
return electricity_trading_pb2.CancelGridpoolOrderResponse()
119+
120+
async def CreateGridpoolOrder( # pylint: disable=invalid-overridden-method
121+
self, request: Any, context: Any
122+
) -> electricity_trading_pb2.CreateGridpoolOrderResponse:
123+
"""Handle CreateGridpoolOrder request."""
124+
self._mark_as_unimplemented(context)
125+
return electricity_trading_pb2.CreateGridpoolOrderResponse()
126+
127+
async def GetGridpoolOrder( # pylint: disable=invalid-overridden-method
128+
self, request: Any, context: Any
129+
) -> electricity_trading_pb2.GetGridpoolOrderResponse:
130+
"""Handle GetGridpoolOrder request."""
131+
self._mark_as_unimplemented(context)
132+
return electricity_trading_pb2.GetGridpoolOrderResponse()
133+
134+
async def ListGridpoolOrders( # pylint: disable=invalid-overridden-method
135+
self, request: Any, context: Any
136+
) -> electricity_trading_pb2.ListGridpoolOrdersResponse:
137+
"""Handle ListGridpoolOrders request."""
138+
self._mark_as_unimplemented(context)
139+
return electricity_trading_pb2.ListGridpoolOrdersResponse()
140+
141+
async def ListGridpoolTrades( # pylint: disable=invalid-overridden-method
142+
self, request: Any, context: Any
143+
) -> electricity_trading_pb2.ListGridpoolTradesResponse:
144+
"""Handle ListGridpoolTrades request."""
145+
self._mark_as_unimplemented(context)
146+
return electricity_trading_pb2.ListGridpoolTradesResponse()
147+
148+
async def ReceiveGridpoolOrdersStream( # pylint: disable=invalid-overridden-method
149+
self, request: Any, context: Any
150+
) -> AsyncIterator[electricity_trading_pb2.ReceiveGridpoolOrdersStreamResponse]:
151+
"""Handle ReceiveGridpoolOrdersStream request."""
152+
self._mark_as_unimplemented(context)
153+
# The if False: part ensures that the yield statement is never actually executed.
154+
# The result is an empty—asynchronous generator that satisfies the type system and
155+
# the gRPC framework's requirements for an unimplemented streaming method.
156+
if False: # pylint: disable=using-constant-test
157+
yield
158+
159+
async def ReceiveGridpoolTradesStream( # pylint: disable=invalid-overridden-method
160+
self, request: Any, context: Any
161+
) -> AsyncIterator[electricity_trading_pb2.ReceiveGridpoolTradesStreamResponse]:
162+
"""Handle ReceiveGridpoolTradesStream request."""
163+
self._mark_as_unimplemented(context)
164+
if False: # pylint: disable=using-constant-test
165+
yield
166+
167+
async def ReceivePublicTradesStream( # pylint: disable=invalid-overridden-method
168+
self, request: Any, context: Any
169+
) -> AsyncIterator[electricity_trading_pb2.ReceivePublicTradesStreamResponse]:
170+
"""Handle ReceivePublicTradesStream request."""
171+
self._mark_as_unimplemented(context)
172+
if False: # pylint: disable=using-constant-test
173+
yield
174+
175+
async def UpdateGridpoolOrder( # pylint: disable=invalid-overridden-method
176+
self, request: Any, context: Any
177+
) -> electricity_trading_pb2.UpdateGridpoolOrderResponse:
178+
"""Handle UpdateGridpoolOrder request."""
179+
self._mark_as_unimplemented(context)
180+
return electricity_trading_pb2.UpdateGridpoolOrderResponse()
181+
182+
183+
@pytest.fixture
184+
async def mock_server() -> AsyncIterator[str]:
185+
"""Set up and tear down a mock gRPC server for the test session."""
186+
servicer = MockElectricityTradingService()
187+
server = grpc.aio.server()
188+
electricity_trading_pb2_grpc.add_ElectricityTradingServiceServicer_to_server(
189+
servicer, server
190+
)
191+
port = server.add_insecure_port("[::]:0")
192+
address = f"[::1]:{port}"
193+
await server.start()
194+
try:
195+
yield address
196+
finally:
197+
await server.stop(0)
198+
199+
200+
@pytest.mark.asyncio
201+
async def test_concurrent_historic_and_realtime_streams(mock_server: str) -> None:
202+
"""Verify that historic and real-time streams from one client instance are distinct."""
203+
client = Client(server_url=f"grpc://{mock_server}?ssl=false")
204+
205+
delivery_area = DeliveryArea(
206+
code="DE-TENNET", code_type=EnergyMarketCodeType.EUROPE_EIC
207+
)
208+
209+
end_time = dt.datetime.now(dt.timezone.utc)
210+
start_time = end_time - dt.timedelta(hours=1)
211+
212+
historic_stream = client.receive_public_order_book(
213+
delivery_area=delivery_area, start_time=start_time, end_time=end_time
214+
)
215+
realtime_stream = client.receive_public_order_book(delivery_area=delivery_area)
216+
217+
historic_orders_received: list[PublicOrder] = []
218+
realtime_orders_received: list[PublicOrder] = []
219+
220+
async def consume_historic() -> None:
221+
"""Consume all items from the historic stream."""
222+
async for batch in historic_stream.new_receiver():
223+
historic_orders_received.extend(batch)
224+
225+
async def consume_realtime() -> None:
226+
"""Consume the first item from the real-time stream."""
227+
async for batch in realtime_stream.new_receiver():
228+
realtime_orders_received.extend(batch)
229+
break
230+
231+
try:
232+
await asyncio.wait_for(
233+
asyncio.gather(consume_historic(), consume_realtime()), timeout=2.0
234+
)
235+
except asyncio.TimeoutError:
236+
pytest.fail(
237+
"Test timed out. The streams did not produce the expected data in time."
238+
)
239+
240+
assert (
241+
len(historic_orders_received) == 2
242+
), "Historic stream should receive a batch of 2"
243+
assert [order.public_order_id for order in historic_orders_received] == [1, 2]
244+
245+
assert len(realtime_orders_received) == 1, "Real-time stream should receive 1 item"
246+
assert realtime_orders_received[0].public_order_id == 9

0 commit comments

Comments
 (0)