Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Added `pybricksdev download` command to download Python scripts to hubs without running them.
Supports BLE, USB, and SSH connections. ([pybricksdev#107])

[pybricksdev#107]: https://github.com/pybricks/pybricksdev/issues/107

## [1.0.1] - 2025-02-20

### Fixed
Expand Down
96 changes: 95 additions & 1 deletion pybricksdev/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,100 @@ def is_pybricks_usb(dev):
await hub.disconnect()


class Download(Tool):
def add_parser(self, subparsers: argparse._SubParsersAction):
parser = subparsers.add_parser(
"download",
help="upload a Pybricks program without running it",
)
parser.tool = self
parser.add_argument(
"conntype",
metavar="<connection type>",
help="connection type: %(choices)s",
choices=["ble", "usb", "ssh"],
)
parser.add_argument(
"file",
metavar="<file>",
help="path to a MicroPython script or `-` for stdin",
type=argparse.FileType(),
)
parser.add_argument(
"-n",
"--name",
metavar="<name>",
required=False,
help="hostname or IP address for SSH connection; "
"Bluetooth device name or Bluetooth address for BLE connection; "
"serial port name for USB connection",
)

async def run(self, args: argparse.Namespace):
# Pick the right connection
if args.conntype == "ssh":
from pybricksdev.connections.ev3dev import EV3Connection

# So it's an ev3dev
if args.name is None:
print("--name is required for SSH connections", file=sys.stderr)
exit(1)

device_or_address = socket.gethostbyname(args.name)
hub = EV3Connection(device_or_address)
elif args.conntype == "ble":
from pybricksdev.ble import find_device as find_ble
from pybricksdev.connections.pybricks import PybricksHubBLE

# It is a Pybricks Hub with BLE. Device name or address is given.
print(f"Searching for {args.name or 'any hub with Pybricks service'}...")
device_or_address = await find_ble(args.name)
hub = PybricksHubBLE(device_or_address)
elif args.conntype == "usb":
from usb.core import find as find_usb

from pybricksdev.connections.pybricks import PybricksHubUSB
from pybricksdev.usb import (
LEGO_USB_VID,
MINDSTORMS_INVENTOR_USB_PID,
SPIKE_ESSENTIAL_USB_PID,
SPIKE_PRIME_USB_PID,
)

def is_pybricks_usb(dev):
return (
(dev.idVendor == LEGO_USB_VID)
and (
dev.idProduct
in [
SPIKE_PRIME_USB_PID,
SPIKE_ESSENTIAL_USB_PID,
MINDSTORMS_INVENTOR_USB_PID,
]
)
and dev.product.endswith("Pybricks")
)

device_or_address = find_usb(custom_match=is_pybricks_usb)

if device_or_address is not None:
hub = PybricksHubUSB(device_or_address)
else:
from pybricksdev.connections.lego import REPLHub

hub = REPLHub()
else:
raise ValueError(f"Unknown connection type: {args.conntype}")

# Connect to the address and upload the script without running it
await hub.connect()
try:
with _get_script_path(args.file) as script_path:
await hub.download(script_path)
finally:
await hub.disconnect()


class Flash(Tool):
def add_parser(self, subparsers: argparse._SubParsersAction):
parser = subparsers.add_parser(
Expand Down Expand Up @@ -459,7 +553,7 @@ def main():
help="the tool to use",
)

for tool in Compile(), Run(), Flash(), DFU(), OAD(), LWP3(), Udev():
for tool in Compile(), Run(), Download(), Flash(), DFU(), OAD(), LWP3(), Udev():
tool.add_parser(subparsers)

argcomplete.autocomplete(parser)
Expand Down
70 changes: 47 additions & 23 deletions pybricksdev/connections/pybricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,50 @@ async def stop_user_program(self) -> None:
response=True,
)

async def download(self, script_path: str) -> None:
"""
Downloads a script to the hub without running it.

This method handles both compilation and downloading of the script.
For Pybricks hubs, it compiles the script to MPY format and downloads it
using the Pybricks protocol.

Args:
script_path: Path to the Python script to download.

Raises:
RuntimeError: If the hub is not connected or if the hub type is not supported.
ValueError: If the compiled program is too large to fit on the hub.
"""
if self.connection_state_observable.value != ConnectionState.CONNECTED:
raise RuntimeError("not connected")

# since Pybricks profile v1.2.0, the hub will tell us which file format(s) it supports
if not (
self._capability_flags
& (
HubCapabilityFlag.USER_PROG_MULTI_FILE_MPY6
| HubCapabilityFlag.USER_PROG_MULTI_FILE_MPY6_1_NATIVE
)
):
raise RuntimeError(
"Hub is not compatible with any of the supported file formats"
)

# no support for native modules unless one of the flags below is set
abi = 6

if (
self._capability_flags
& HubCapabilityFlag.USER_PROG_MULTI_FILE_MPY6_1_NATIVE
):
abi = (6, 1)

# Compile the script to mpy format
mpy = await compile_multi_file(script_path, abi)
# Download without running
await self.download_user_program(mpy)

async def run(
self,
py_path: Optional[str] = None,
Expand Down Expand Up @@ -506,31 +550,11 @@ async def run(
await self._legacy_run(py_path, wait)
return

# since Pybricks profile v1.2.0, the hub will tell us which file format(s) it supports
if not (
self._capability_flags
& (
HubCapabilityFlag.USER_PROG_MULTI_FILE_MPY6
| HubCapabilityFlag.USER_PROG_MULTI_FILE_MPY6_1_NATIVE
)
):
raise RuntimeError(
"Hub is not compatible with any of the supported file formats"
)

# no support for native modules unless one of the flags below is set
abi = 6

if (
self._capability_flags
& HubCapabilityFlag.USER_PROG_MULTI_FILE_MPY6_1_NATIVE
):
abi = (6, 1)

# Download the program if a path is provided
if py_path is not None:
mpy = await compile_multi_file(py_path, abi)
await self.download_user_program(mpy)
await self.download(py_path)

# Start the program
await self.start_user_program()

if wait:
Expand Down
182 changes: 182 additions & 0 deletions tests/connections/test_pybricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""Tests for the pybricks connection module."""

import asyncio
import contextlib
import os
import tempfile
from unittest.mock import AsyncMock, PropertyMock, patch

import pytest
from reactivex.subject import Subject

from pybricksdev.connections.pybricks import (
ConnectionState,
HubCapabilityFlag,
HubKind,
PybricksHubBLE,
StatusFlag,
)


class TestPybricksHub:
"""Tests for the PybricksHub base class functionality."""

@pytest.mark.asyncio
async def test_download_modern_protocol(self):
"""Test downloading with modern protocol and capability flags."""
hub = PybricksHubBLE("mock_device")
hub._mpy_abi_version = 6
hub._client = AsyncMock()
hub.get_capabilities = AsyncMock(return_value={"pybricks": {"mpy": True}})
hub.download_user_program = AsyncMock()
type(hub.connection_state_observable).value = PropertyMock(
return_value=ConnectionState.CONNECTED
)
hub._capability_flags = HubCapabilityFlag.USER_PROG_MULTI_FILE_MPY6

with contextlib.ExitStack() as stack:
# Create and manage temporary file
temp = stack.enter_context(
tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False)
)
temp.write("print('test')")
temp_path = temp.name
stack.callback(os.unlink, temp_path)

await hub.download(temp_path)
hub.download_user_program.assert_called_once()

@pytest.mark.asyncio
async def test_download_legacy_firmware(self):
"""Test downloading with legacy firmware."""
hub = PybricksHubBLE("mock_device")
hub._mpy_abi_version = None # Legacy firmware
hub._client = AsyncMock()
hub.download_user_program = AsyncMock()
hub.hub_kind = HubKind.BOOST
type(hub.connection_state_observable).value = PropertyMock(
return_value=ConnectionState.CONNECTED
)
hub._capability_flags = HubCapabilityFlag.USER_PROG_MULTI_FILE_MPY6

with contextlib.ExitStack() as stack:
# Create and manage temporary file
temp = stack.enter_context(
tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False)
)
temp.write("print('test')")
temp_path = temp.name
stack.callback(os.unlink, temp_path)

await hub.download(temp_path)
hub.download_user_program.assert_called_once()

@pytest.mark.asyncio
async def test_download_unsupported_capabilities(self):
"""Test downloading when hub doesn't support required capabilities."""
hub = PybricksHubBLE("mock_device")
hub._mpy_abi_version = 6
hub._client = AsyncMock()
hub.get_capabilities = AsyncMock(return_value={"pybricks": {"mpy": False}})
type(hub.connection_state_observable).value = PropertyMock(
return_value=ConnectionState.CONNECTED
)
hub._capability_flags = 0

with contextlib.ExitStack() as stack:
# Create and manage temporary file
temp = stack.enter_context(
tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False)
)
temp.write("print('test')")
temp_path = temp.name
stack.callback(os.unlink, temp_path)

with pytest.raises(
RuntimeError,
match="Hub is not compatible with any of the supported file formats",
):
await hub.download(temp_path)

@pytest.mark.asyncio
async def test_download_compile_error(self):
"""Test handling compilation errors."""
hub = PybricksHubBLE("mock_device")
hub._mpy_abi_version = 6
hub._client = AsyncMock()
hub.get_capabilities = AsyncMock(return_value={"pybricks": {"mpy": True}})
type(hub.connection_state_observable).value = PropertyMock(
return_value=ConnectionState.CONNECTED
)
hub._capability_flags = HubCapabilityFlag.USER_PROG_MULTI_FILE_MPY6
hub._max_user_program_size = 1000 # Set a reasonable size limit

with contextlib.ExitStack() as stack:
# Create and manage temporary file
temp = stack.enter_context(
tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False)
)
temp.write("print('test' # Missing closing parenthesis")
temp_path = temp.name
stack.callback(os.unlink, temp_path)

# Mock compile_multi_file to raise SyntaxError
stack.enter_context(
patch(
"pybricksdev.connections.pybricks.compile_multi_file",
side_effect=SyntaxError("invalid syntax"),
)
)

with pytest.raises(SyntaxError, match="invalid syntax"):
await hub.download(temp_path)

@pytest.mark.asyncio
async def test_run_modern_protocol(self):
"""Test running a program with modern protocol."""
hub = PybricksHubBLE("mock_device")
hub._mpy_abi_version = None # Use modern protocol
hub._client = AsyncMock()
hub.client = AsyncMock()
hub.get_capabilities = AsyncMock(return_value={"pybricks": {"mpy": True}})
hub.download_user_program = AsyncMock()
hub.start_user_program = AsyncMock()
hub.write_gatt_char = AsyncMock()
type(hub.connection_state_observable).value = PropertyMock(
return_value=ConnectionState.CONNECTED
)
hub._capability_flags = HubCapabilityFlag.USER_PROG_MULTI_FILE_MPY6
hub.hub_kind = HubKind.BOOST

# Mock the status observable to simulate program start and stop
status_subject = Subject()
hub.status_observable = status_subject
hub._stdout_line_queue = asyncio.Queue()
hub._enable_line_handler = True

with contextlib.ExitStack() as stack:
# Create and manage temporary file
temp = stack.enter_context(
tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False)
)
temp.write("print('test')")
temp_path = temp.name
stack.callback(os.unlink, temp_path)

# Start the run task
run_task = asyncio.create_task(hub.run(temp_path))

# Simulate program start
await asyncio.sleep(0.1)
status_subject.on_next(StatusFlag.USER_PROGRAM_RUNNING)

# Simulate program stop after a short delay
await asyncio.sleep(0.1)
status_subject.on_next(0) # Clear all flags

# Wait for run task to complete
await run_task

# Verify the expected calls were made
hub.download_user_program.assert_called_once()
hub.start_user_program.assert_called_once()
Loading