Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pyatv/core/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,24 @@ def __init__(self):
"""Initialize a new FacadeRemoteControl instance."""
super().__init__(interface.RemoteControl, DEFAULT_PRIORITIES)

# pylint: disable=invalid-name
@shield.guard
async def send_command(
self,
command,
*,
wait_for_response: bool = True,
timeout: Optional[float] = None,
**kwargs,
) -> None:
"""Send a protocol-specific command (fire-and-forget or with response)."""
return await self.relay("send_command")(
command=command,
wait_for_response=wait_for_response,
timeout=timeout,
**kwargs,
)

# pylint: disable=invalid-name
@shield.guard
async def up(self, action: InputAction = InputAction.SingleTap) -> None:
Expand Down
11 changes: 11 additions & 0 deletions pyatv/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,17 @@ async def menu(self, action: InputAction = InputAction.SingleTap) -> None:
"""Press key menu."""
raise exceptions.NotSupportedError()

async def send_command(
self,
command,
*,
wait_for_response: bool = True,
timeout: Optional[float] = None,
**kwargs,
) -> None:
"""Send a protocol-specific command (fire-and-forget or with response)."""
raise exceptions.NotSupportedError()

@feature(12, "VolumeUp", "Increase volume (deprecated: use Audio.volume_up).")
async def volume_up(self) -> None:
"""Press key volume up.
Expand Down
34 changes: 32 additions & 2 deletions pyatv/protocols/mrp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
_LOGGER = logging.getLogger(__name__)

_DEFAULT_SKIP_TIME = 15
_DEFAULT_COMMAND_TIMEOUT = 5.0

# Source: https://github.com/Daij-Djan/DDHidLib/blob/master/usb_hid_usages.txt
_KEY_LOOKUP = {
Expand Down Expand Up @@ -338,8 +339,24 @@ def __init__(
self.psm = psm
self.protocol = protocol

async def _send_command(self, command, **kwargs):
resp = await self.protocol.send_and_receive(messages.command(command, **kwargs))
async def _send_command(
self,
command,
*,
wait_for_response: bool = True,
timeout: Optional[float] = None,
**kwargs,
):
message = messages.command(command, **kwargs)
if not wait_for_response:
await self.protocol.send(message)
return

resp = await self.protocol.send_and_receive(
message,
timeout=_DEFAULT_COMMAND_TIMEOUT if timeout is None else timeout,
wait_for_response=True,
)
inner = resp.inner()

if inner.sendError == protobuf.SendError.NoError:
Expand All @@ -352,6 +369,19 @@ async def _send_command(self, command, **kwargs):
f"{protobuf.HandlerReturnStatus.Enum.Name(inner.handlerReturnStatus)}"
)

async def send_command(
self,
command,
*,
wait_for_response: bool = True,
timeout: Optional[float] = None,
**kwargs,
) -> None:
"""Send an arbitrary command with optional response handling overrides."""
await self._send_command(
command, wait_for_response=wait_for_response, timeout=timeout, **kwargs
)

async def up(self, action: InputAction = InputAction.SingleTap) -> None:
"""Press key up."""
await _send_hid_key(self.protocol, "up", action)
Expand Down
26 changes: 24 additions & 2 deletions pyatv/protocols/mrp/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,18 @@ def send(self, message: protobuf.ProtocolMessage) -> None:
log_binary(_LOGGER, self._log_str + ">> Send", Encrypted=serialized)

data = write_variant(len(serialized)) + serialized
self._transport.write(data)
if self._transport is None:
raise exceptions.ConnectionLostError(
"connection is closed; reconnect required"
)

try:
self._transport.write(data)
except (ConnectionResetError, BrokenPipeError, OSError) as ex:
self._transport = None
raise exceptions.ConnectionLostError(
"connection was lost while sending; reconnect required"
) from ex
log_protobuf(_LOGGER, self._log_str + ">> Send: Protobuf", message)

def send_raw(self, data):
Expand All @@ -132,7 +143,18 @@ def send_raw(self, data):
log_binary(_LOGGER, self._log_str + ">> Send raw", Encrypted=data)

data = write_variant(len(data)) + data
self._transport.write(data)
if self._transport is None:
raise exceptions.ConnectionLostError(
"connection is closed; reconnect required"
)

try:
self._transport.write(data)
except (ConnectionResetError, BrokenPipeError, OSError) as ex:
self._transport = None
raise exceptions.ConnectionLostError(
"connection was lost while sending; reconnect required"
) from ex

def data_received(self, data):
"""Message was received from device."""
Expand Down
75 changes: 57 additions & 18 deletions pyatv/protocols/mrp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,28 +220,47 @@ async def _enable_encryption(self) -> None:
)
self.connection.enable_encryption(output_key, input_key)

def _ensure_send_possible(self) -> None:
if self._state in [ProtocolState.CONNECTED, ProtocolState.READY]:
if not self.connection.connected:
raise exceptions.ConnectionLostError(
"connection is closed; reconnect required"
)
return

if self._state == ProtocolState.STOPPED:
raise exceptions.ConnectionLostError(
"connection is closed; reconnect required"
)

raise exceptions.InvalidStateError(self._state.name)

async def send(self, message: protobuf.ProtocolMessage) -> None:
"""Send a message and expect no response."""
if self._state not in [
ProtocolState.CONNECTED,
ProtocolState.READY,
]:
raise exceptions.InvalidStateError(self._state.name)
self._ensure_send_possible()

self.connection.send(message)
try:
self.connection.send(message)
except exceptions.ConnectionLostError:
raise
except Exception as ex:
raise exceptions.ConnectionLostError(
"connection was lost while sending; reconnect required"
) from ex

async def send_and_receive(
self,
message: protobuf.ProtocolMessage,
generate_identifier: bool = True,
timeout: float = 5.0,
) -> protobuf.ProtocolMessage:
timeout: Optional[float] = 5.0,
wait_for_response: bool = True,
) -> Optional[protobuf.ProtocolMessage]:
"""Send a message and wait for a response."""
if self._state not in [
ProtocolState.CONNECTED,
ProtocolState.READY,
]:
raise exceptions.InvalidStateError(self._state.name)
self._ensure_send_possible()

if not wait_for_response:
await self.send(message)
return None

# Some messages will respond with the same identifier as used in the
# corresponding request. Others will not and one example is the crypto
Expand All @@ -256,11 +275,18 @@ async def send_and_receive(
else:
identifier = "type_" + str(message.type)

self.connection.send(message)
try:
self.connection.send(message)
except exceptions.ConnectionLostError:
raise
except Exception as ex:
raise exceptions.ConnectionLostError(
"connection was lost while sending; reconnect required"
) from ex
return await self._receive(identifier, timeout)

async def _receive(
self, identifier: str, timeout: float
self, identifier: str, timeout: Optional[float]
) -> protobuf.ProtocolMessage:
semaphore = asyncio.Semaphore(value=0)
self._outstanding[identifier] = OutstandingMessage(
Expand All @@ -269,15 +295,28 @@ async def _receive(

try:
# The connection instance will dispatch the message
async with async_timeout(timeout):
if timeout is None:
await semaphore.acquire()
else:
async with async_timeout(timeout):
await semaphore.acquire()

except asyncio.TimeoutError as ex:
del self._outstanding[identifier]
raise exceptions.OperationTimeoutError(
f"no response received within {timeout} seconds"
) from ex
except Exception:
del self._outstanding[identifier]
raise

response = self._outstanding[identifier].response
del self._outstanding[identifier]
outstanding = self._outstanding.pop(identifier, None)
if outstanding is None:
raise exceptions.ConnectionLostError(
"connection closed before response was received; reconnect required"
)

response = outstanding.response
return response

def message_received(self, message: protobuf.ProtocolMessage, _) -> None:
Expand Down
24 changes: 22 additions & 2 deletions tests/protocols/mrp/test_mrp_interface.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
"""Unit tests for interface implementations in pyatv.protocols.mrp."""

import asyncio
import datetime
import math
from typing import Any, Dict
from unittest.mock import Mock, PropertyMock
from unittest.mock import AsyncMock, Mock, PropertyMock

import pytest

from pyatv import exceptions
from pyatv.core import UpdatedState
from pyatv.interface import ClientSessionManager, OutputDevice
from pyatv.protocols.mrp import MrpAudio, MrpMetadata, messages, player_state, protobuf
from pyatv.protocols.mrp import (
MrpAudio,
MrpMetadata,
MrpRemoteControl,
messages,
player_state,
protobuf,
)
from pyatv.protocols.mrp.protobuf import CommandInfo_pb2
from pyatv.settings import InfoSettings

from tests.utils import faketime
Expand Down Expand Up @@ -374,3 +383,14 @@ async def test_metadata_position_calculation(metadata, playing_metadata, player_
player_state.playback_state = protobuf.PlaybackState.Playing
playing_metadata["playbackRate"] = 0.0
assert (await metadata.playing()).position == ELAPSED_TIME


async def test_remote_fire_and_forget_command():
protocol = AsyncMock()
psm = Mock(spec=player_state.PlayerStateManager)
remote = MrpRemoteControl(asyncio.get_running_loop(), psm, protocol)

await remote.send_command(CommandInfo_pb2.Play, wait_for_response=False)

protocol.send.assert_awaited_once()
protocol.send_and_receive.assert_not_awaited()
62 changes: 61 additions & 1 deletion tests/protocols/mrp/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
from pyatv.auth.hap_srp import SRPAuthHandler
from pyatv.conf import ManualService
from pyatv.const import Protocol
from pyatv.protocols.mrp.connection import MrpConnection
from pyatv import exceptions
from pyatv.protocols.mrp import messages, protobuf
from pyatv.protocols.mrp.connection import AbstractMrpConnection, MrpConnection
from pyatv.protocols.mrp.protocol import (
HEARTBEAT_INTERVAL,
HEARTBEAT_RETRIES,
ProtocolState,
MrpProtocol,
heartbeat_loop,
)
Expand All @@ -22,6 +25,36 @@
from tests.utils import total_sleep_time, until


class DummyConnection(AbstractMrpConnection):
"""Minimal MRP connection used for protocol unit tests."""

def __init__(self, connected: bool = True):
super().__init__()
self._connected = connected
self.sent = []
self.listener = None

async def connect(self) -> None: # pragma: no cover - not used in tests
self._connected = True

def enable_encryption(self, output_key: bytes, input_key: bytes) -> None:
return

@property
def connected(self) -> bool:
return self._connected

def close(self) -> None:
self._connected = False

def send(self, message: protobuf.ProtocolMessage) -> None:
if not self._connected:
raise exceptions.ConnectionLostError(
"connection is closed; reconnect required"
)
self.sent.append(message)


@pytest_asyncio.fixture
async def mrp_atv():
atv = FakeAppleTV(asyncio.get_running_loop())
Expand Down Expand Up @@ -61,3 +94,30 @@ async def test_heartbeat_fail_closes_connection(stub_sleep):
assert total_sleep_time() == HEARTBEAT_INTERVAL

protocol.connection.close.assert_called_once()


@pytest.mark.asyncio
async def test_send_and_receive_fire_and_forget():
connection = DummyConnection()
service = ManualService("mrp_id", Protocol.MRP, 0, {})
protocol = MrpProtocol(connection, SRPAuthHandler(), service, InfoSettings())
protocol._state = ProtocolState.CONNECTED

message = messages.create(protobuf.GENERIC_MESSAGE)
result = await protocol.send_and_receive(
message, wait_for_response=False, timeout=1
)

assert result is None
assert connection.sent == [message]


@pytest.mark.asyncio
async def test_send_raises_when_connection_closed():
connection = DummyConnection(connected=False)
service = ManualService("mrp_id", Protocol.MRP, 0, {})
protocol = MrpProtocol(connection, SRPAuthHandler(), service, InfoSettings())
protocol._state = ProtocolState.CONNECTED

with pytest.raises(exceptions.ConnectionLostError):
await protocol.send(messages.create(protobuf.GENERIC_MESSAGE))