Skip to content

Commit 13b792e

Browse files
committed
feat(mrp): add fire-and-forget remote option and clearer reconnect errors
Add an opt-in fire-and-forget path for MRP remote commands while keeping defaults unchanged; allow overriding command timeouts. Treat closed transports during send as ConnectionLostError with an explicit reconnect hint and convert response timeouts to OperationTimeoutError. Cover the new behaviors in mrp protocol/interface tests.
1 parent 91a5489 commit 13b792e

File tree

5 files changed

+196
-25
lines changed

5 files changed

+196
-25
lines changed

pyatv/protocols/mrp/__init__.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
_LOGGER = logging.getLogger(__name__)
7373

7474
_DEFAULT_SKIP_TIME = 15
75+
_DEFAULT_COMMAND_TIMEOUT = 5.0
7576

7677
# Source: https://github.com/Daij-Djan/DDHidLib/blob/master/usb_hid_usages.txt
7778
_KEY_LOOKUP = {
@@ -338,8 +339,24 @@ def __init__(
338339
self.psm = psm
339340
self.protocol = protocol
340341

341-
async def _send_command(self, command, **kwargs):
342-
resp = await self.protocol.send_and_receive(messages.command(command, **kwargs))
342+
async def _send_command(
343+
self,
344+
command,
345+
*,
346+
wait_for_response: bool = True,
347+
timeout: Optional[float] = None,
348+
**kwargs,
349+
):
350+
message = messages.command(command, **kwargs)
351+
if not wait_for_response:
352+
await self.protocol.send(message)
353+
return
354+
355+
resp = await self.protocol.send_and_receive(
356+
message,
357+
timeout=_DEFAULT_COMMAND_TIMEOUT if timeout is None else timeout,
358+
wait_for_response=True,
359+
)
343360
inner = resp.inner()
344361

345362
if inner.sendError == protobuf.SendError.NoError:
@@ -352,6 +369,19 @@ async def _send_command(self, command, **kwargs):
352369
f"{protobuf.HandlerReturnStatus.Enum.Name(inner.handlerReturnStatus)}"
353370
)
354371

372+
async def send_command(
373+
self,
374+
command,
375+
*,
376+
wait_for_response: bool = True,
377+
timeout: Optional[float] = None,
378+
**kwargs,
379+
) -> None:
380+
"""Send an arbitrary command with optional response handling overrides."""
381+
await self._send_command(
382+
command, wait_for_response=wait_for_response, timeout=timeout, **kwargs
383+
)
384+
355385
async def up(self, action: InputAction = InputAction.SingleTap) -> None:
356386
"""Press key up."""
357387
await _send_hid_key(self.protocol, "up", action)

pyatv/protocols/mrp/connection.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,18 @@ def send(self, message: protobuf.ProtocolMessage) -> None:
121121
log_binary(_LOGGER, self._log_str + ">> Send", Encrypted=serialized)
122122

123123
data = write_variant(len(serialized)) + serialized
124-
self._transport.write(data)
124+
if self._transport is None:
125+
raise exceptions.ConnectionLostError(
126+
"connection is closed; reconnect required"
127+
)
128+
129+
try:
130+
self._transport.write(data)
131+
except (ConnectionResetError, BrokenPipeError, OSError) as ex:
132+
self._transport = None
133+
raise exceptions.ConnectionLostError(
134+
"connection was lost while sending; reconnect required"
135+
) from ex
125136
log_protobuf(_LOGGER, self._log_str + ">> Send: Protobuf", message)
126137

127138
def send_raw(self, data):
@@ -132,7 +143,18 @@ def send_raw(self, data):
132143
log_binary(_LOGGER, self._log_str + ">> Send raw", Encrypted=data)
133144

134145
data = write_variant(len(data)) + data
135-
self._transport.write(data)
146+
if self._transport is None:
147+
raise exceptions.ConnectionLostError(
148+
"connection is closed; reconnect required"
149+
)
150+
151+
try:
152+
self._transport.write(data)
153+
except (ConnectionResetError, BrokenPipeError, OSError) as ex:
154+
self._transport = None
155+
raise exceptions.ConnectionLostError(
156+
"connection was lost while sending; reconnect required"
157+
) from ex
136158

137159
def data_received(self, data):
138160
"""Message was received from device."""

pyatv/protocols/mrp/protocol.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -222,28 +222,47 @@ async def _enable_encryption(self) -> None:
222222
)
223223
self.connection.enable_encryption(output_key, input_key)
224224

225+
def _ensure_send_possible(self) -> None:
226+
if self._state in [ProtocolState.CONNECTED, ProtocolState.READY]:
227+
if not self.connection.connected:
228+
raise exceptions.ConnectionLostError(
229+
"connection is closed; reconnect required"
230+
)
231+
return
232+
233+
if self._state == ProtocolState.STOPPED:
234+
raise exceptions.ConnectionLostError(
235+
"connection is closed; reconnect required"
236+
)
237+
238+
raise exceptions.InvalidStateError(self._state.name)
239+
225240
async def send(self, message: protobuf.ProtocolMessage) -> None:
226241
"""Send a message and expect no response."""
227-
if self._state not in [
228-
ProtocolState.CONNECTED,
229-
ProtocolState.READY,
230-
]:
231-
raise exceptions.InvalidStateError(self._state.name)
242+
self._ensure_send_possible()
232243

233-
self.connection.send(message)
244+
try:
245+
self.connection.send(message)
246+
except exceptions.ConnectionLostError:
247+
raise
248+
except Exception as ex:
249+
raise exceptions.ConnectionLostError(
250+
"connection was lost while sending; reconnect required"
251+
) from ex
234252

235253
async def send_and_receive(
236254
self,
237255
message: protobuf.ProtocolMessage,
238256
generate_identifier: bool = True,
239-
timeout: float = 5.0,
240-
) -> protobuf.ProtocolMessage:
257+
timeout: Optional[float] = 5.0,
258+
wait_for_response: bool = True,
259+
) -> Optional[protobuf.ProtocolMessage]:
241260
"""Send a message and wait for a response."""
242-
if self._state not in [
243-
ProtocolState.CONNECTED,
244-
ProtocolState.READY,
245-
]:
246-
raise exceptions.InvalidStateError(self._state.name)
261+
self._ensure_send_possible()
262+
263+
if not wait_for_response:
264+
await self.send(message)
265+
return None
247266

248267
# Some messages will respond with the same identifier as used in the
249268
# corresponding request. Others will not and one example is the crypto
@@ -258,11 +277,18 @@ async def send_and_receive(
258277
else:
259278
identifier = "type_" + str(message.type)
260279

261-
self.connection.send(message)
280+
try:
281+
self.connection.send(message)
282+
except exceptions.ConnectionLostError:
283+
raise
284+
except Exception as ex:
285+
raise exceptions.ConnectionLostError(
286+
"connection was lost while sending; reconnect required"
287+
) from ex
262288
return await self._receive(identifier, timeout)
263289

264290
async def _receive(
265-
self, identifier: str, timeout: float
291+
self, identifier: str, timeout: Optional[float]
266292
) -> protobuf.ProtocolMessage:
267293
semaphore = asyncio.Semaphore(value=0)
268294
self._outstanding[identifier] = OutstandingMessage(
@@ -271,15 +297,28 @@ async def _receive(
271297

272298
try:
273299
# The connection instance will dispatch the message
274-
async with async_timeout.timeout(timeout):
300+
if timeout is None:
275301
await semaphore.acquire()
302+
else:
303+
async with async_timeout.timeout(timeout):
304+
await semaphore.acquire()
276305

306+
except asyncio.TimeoutError as ex:
307+
del self._outstanding[identifier]
308+
raise exceptions.OperationTimeoutError(
309+
f"no response received within {timeout} seconds"
310+
) from ex
277311
except Exception:
278312
del self._outstanding[identifier]
279313
raise
280314

281-
response = self._outstanding[identifier].response
282-
del self._outstanding[identifier]
315+
outstanding = self._outstanding.pop(identifier, None)
316+
if outstanding is None:
317+
raise exceptions.ConnectionLostError(
318+
"connection closed before response was received; reconnect required"
319+
)
320+
321+
response = outstanding.response
283322
return response
284323

285324
def message_received(self, message: protobuf.ProtocolMessage, _) -> None:

tests/protocols/mrp/test_mrp_interface.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
"""Unit tests for interface implementations in pyatv.protocols.mrp."""
22

3+
import asyncio
34
import datetime
45
import math
56
from typing import Any, Dict
6-
from unittest.mock import Mock, PropertyMock
7+
from unittest.mock import AsyncMock, Mock, PropertyMock
78

89
import pytest
910

1011
from pyatv import exceptions
1112
from pyatv.core import UpdatedState
1213
from pyatv.interface import ClientSessionManager, OutputDevice
13-
from pyatv.protocols.mrp import MrpAudio, MrpMetadata, messages, player_state, protobuf
14+
from pyatv.protocols.mrp import (
15+
MrpAudio,
16+
MrpMetadata,
17+
MrpRemoteControl,
18+
messages,
19+
player_state,
20+
protobuf,
21+
)
22+
from pyatv.protocols.mrp.protobuf import CommandInfo_pb2
1423
from pyatv.settings import InfoSettings
1524

1625
from tests.utils import faketime
@@ -374,3 +383,14 @@ async def test_metadata_position_calculation(metadata, playing_metadata, player_
374383
player_state.playback_state = protobuf.PlaybackState.Playing
375384
playing_metadata["playbackRate"] = 0.0
376385
assert (await metadata.playing()).position == ELAPSED_TIME
386+
387+
388+
async def test_remote_fire_and_forget_command():
389+
protocol = AsyncMock()
390+
psm = Mock(spec=player_state.PlayerStateManager)
391+
remote = MrpRemoteControl(asyncio.get_running_loop(), psm, protocol)
392+
393+
await remote.send_command(CommandInfo_pb2.Play, wait_for_response=False)
394+
395+
protocol.send.assert_awaited_once()
396+
protocol.send_and_receive.assert_not_awaited()

tests/protocols/mrp/test_protocol.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@
99
from pyatv.auth.hap_srp import SRPAuthHandler
1010
from pyatv.conf import ManualService
1111
from pyatv.const import Protocol
12-
from pyatv.protocols.mrp.connection import MrpConnection
12+
from pyatv import exceptions
13+
from pyatv.protocols.mrp import messages, protobuf
14+
from pyatv.protocols.mrp.connection import AbstractMrpConnection, MrpConnection
1315
from pyatv.protocols.mrp.protocol import (
1416
HEARTBEAT_INTERVAL,
1517
HEARTBEAT_RETRIES,
18+
ProtocolState,
1619
MrpProtocol,
1720
heartbeat_loop,
1821
)
@@ -22,6 +25,36 @@
2225
from tests.utils import total_sleep_time, until
2326

2427

28+
class DummyConnection(AbstractMrpConnection):
29+
"""Minimal MRP connection used for protocol unit tests."""
30+
31+
def __init__(self, connected: bool = True):
32+
super().__init__()
33+
self._connected = connected
34+
self.sent = []
35+
self.listener = None
36+
37+
async def connect(self) -> None: # pragma: no cover - not used in tests
38+
self._connected = True
39+
40+
def enable_encryption(self, output_key: bytes, input_key: bytes) -> None:
41+
return
42+
43+
@property
44+
def connected(self) -> bool:
45+
return self._connected
46+
47+
def close(self) -> None:
48+
self._connected = False
49+
50+
def send(self, message: protobuf.ProtocolMessage) -> None:
51+
if not self._connected:
52+
raise exceptions.ConnectionLostError(
53+
"connection is closed; reconnect required"
54+
)
55+
self.sent.append(message)
56+
57+
2558
@pytest_asyncio.fixture
2659
async def mrp_atv():
2760
atv = FakeAppleTV(asyncio.get_running_loop())
@@ -61,3 +94,30 @@ async def test_heartbeat_fail_closes_connection(stub_sleep):
6194
assert total_sleep_time() == HEARTBEAT_INTERVAL
6295

6396
protocol.connection.close.assert_called_once()
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_send_and_receive_fire_and_forget():
101+
connection = DummyConnection()
102+
service = ManualService("mrp_id", Protocol.MRP, 0, {})
103+
protocol = MrpProtocol(connection, SRPAuthHandler(), service, InfoSettings())
104+
protocol._state = ProtocolState.CONNECTED
105+
106+
message = messages.create(protobuf.GENERIC_MESSAGE)
107+
result = await protocol.send_and_receive(
108+
message, wait_for_response=False, timeout=1
109+
)
110+
111+
assert result is None
112+
assert connection.sent == [message]
113+
114+
115+
@pytest.mark.asyncio
116+
async def test_send_raises_when_connection_closed():
117+
connection = DummyConnection(connected=False)
118+
service = ManualService("mrp_id", Protocol.MRP, 0, {})
119+
protocol = MrpProtocol(connection, SRPAuthHandler(), service, InfoSettings())
120+
protocol._state = ProtocolState.CONNECTED
121+
122+
with pytest.raises(exceptions.ConnectionLostError):
123+
await protocol.send(messages.create(protobuf.GENERIC_MESSAGE))

0 commit comments

Comments
 (0)