diff --git a/src/lerobot/motors/dc_motors_controller.py b/src/lerobot/motors/dc_motors_controller.py new file mode 100644 index 0000000000..702d0b54fc --- /dev/null +++ b/src/lerobot/motors/dc_motors_controller.py @@ -0,0 +1,322 @@ +import abc +import logging +from dataclasses import dataclass +from enum import Enum +from typing import Protocol, TypeAlias + +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +NameOrID: TypeAlias = str | int + +logger = logging.getLogger(__name__) + +class MotorNormMode(str, Enum): + PWM_DUTY_CYCLE = "pwm_duty_cycle" # 0 to 1 for PWM control + +@dataclass +class DCMotor: + id: int + model: str + norm_mode: MotorNormMode + protocol: str = "pwm" # pwm, i2c, can, serial + +class ProtocolHandler(Protocol): + """Protocol for different DC motor communication methods.""" + + def connect(self) -> None: + """Connect to the motor controller.""" + ... + + def disconnect(self) -> None: + """Disconnect from the motor controller.""" + ... + + def set_position(self, motor_id: int, position: float) -> None: + """Set motor position (0 to 1).""" + ... + + def set_velocity(self, motor_id: int, velocity: float, instant: bool = True) -> None: + """Set motor velocity (normalized -1 to 1).""" + ... + + def update_velocity(self, motor_id: int, max_step: float = 1.0) -> None: + """Update motor velocity.""" + ... + + def get_position(self, motor_id: int) -> float | None: + """Get current motor position if encoder available.""" + ... + + def get_velocity(self, motor_id: int) -> float: + """Get current motor velocity.""" + ... + + def get_pwm(self, motor_id: int) -> float: + """Get current PWM duty cycle.""" + ... + + def set_pwm(self, motor_id: int, duty_cycle: float) -> None: + """Set PWM duty cycle (0 to 1).""" + ... + + def enable_motor(self, motor_id: int) -> None: + """Enable motor.""" + ... + + def disable_motor(self, motor_id: int) -> None: + """Disable motor.""" + ... +class BaseDCMotorsController(abc.ABC): + """ + Abstract base class for DC motor controllers. + + Concrete implementations should inherit from this class and implement + the abstract methods for their specific protocol. + """ + + def __init__( + self, + config: dict | None = None, + motors: dict[str, DCMotor] | None = None, + protocol: str = "pwm", + ): + self.config = config or {} + self.motors = motors or {} + self.protocol = protocol + + self._id_to_name_dict = {m.id: motor for motor, m in self.motors.items()} + self._name_to_id_dict = {motor: m.id for motor, m in self.motors.items()} + + self.protocol_handler: ProtocolHandler | None = None + self._is_connected = False + + self._validate_motors() + + def __len__(self): + return len(self.motors) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Config: {self.config}\n" + f" Motors: {list(self.motors.keys())}\n" + f" Protocol: '{self.protocol}'\n" + ")" + ) + + def _validate_motors(self) -> None: + """Validate motor configuration.""" + if not self.motors: + raise ValueError("At least one motor must be specified.") + + # Check for duplicate IDs + ids = [m.id for m in self.motors.values()] + if len(ids) != len(set(ids)): + raise ValueError("Motor IDs must be unique.") + + @property + def is_connected(self) -> bool: + return self._is_connected + + def _get_motor_id(self, motor: NameOrID) -> int: + """Get motor ID from name or ID.""" + if isinstance(motor, int): + return motor + elif isinstance(motor, str): + if motor in self._name_to_id_dict: + return self._name_to_id_dict[motor] + else: + raise ValueError(f"Motor '{motor}' not found.") + else: + raise TypeError(f"Motor must be string or int, got {type(motor)}") + + @abc.abstractmethod + def _create_protocol_handler(self) -> ProtocolHandler: + """Create the appropriate protocol handler based on configuration.""" + pass + + ############################################################################################################################## + # Connection + ############################################################################################################################## + + def connect(self) -> None: + """Connect to the motor controller.""" + if self._is_connected: + logger.info(f"{self} is already connected.") + return + + self.protocol_handler = self._create_protocol_handler() + self.protocol_handler.connect() + self._is_connected = True + logger.info(f"{self} connected successfully.") + + def disconnect(self) -> None: + """Disconnect from the motor controller.""" + if not self._is_connected: + logger.info(f"{self} is not connected.") + return + + if self.protocol_handler: + self.protocol_handler.disconnect() + + self._is_connected = False + logger.info(f"{self} disconnected.") + + ############################################################################################################################## + # Position Functions + ############################################################################################################################## + + def get_position(self, motor: NameOrID) -> float | None: + """Get current motor position if encoder available.""" + if not self._is_connected: + logger.info(f"{self} is not connected.") + return None + + motor_id = self._get_motor_id(motor) + return self.protocol_handler.get_position(motor_id) + + def set_position(self, motor: NameOrID, position: float) -> None: + """Set motor position (0 to 1).""" + if not self._is_connected: + logger.info(f"{self} is not connected.") + return None + + motor_id = self._get_motor_id(motor) + self.protocol_handler.set_position(motor_id, position) + + ############################################################################################################################## + # Velocity Functions + ############################################################################################################################## + + def get_velocity(self, motor: NameOrID) -> float: + """Get current motor velocity.""" + if not self._is_connected: + logger.info(f"{self} is not connected.") + return None + + motor_id = self._get_motor_id(motor) + return self.protocol_handler.get_velocity(motor_id) + + def get_velocities(self) -> dict[NameOrID, float]: + """Get current motor velocities.""" + if not self._is_connected: + logger.info(f"{self} is not connected.") + return { } + + return {motor: self.get_velocity(motor) for motor in self.motors.keys()} + + def set_velocity(self, motor: NameOrID, velocity: float, normalize: bool = True,) -> None: + """ + Set motor velocity with ramp-up. + + Args: + motor: Motor name or ID + velocity: Target velocity (-1 to 1 if normalized, otherwise in RPM) + normalize: Whether to normalize the velocity + """ + if not self._is_connected: + logger.info(f"{self} is not connected.") + return + + motor_id = self._get_motor_id(motor) + if normalize: + velocity = max(-1.0, min(1.0, velocity)) # Clamp to [-1, 1] + + self.protocol_handler.set_velocity(motor_id, velocity) + logger.debug(f"Set motor {motor} velocity to {velocity}") + + def set_velocities(self, motors: dict[NameOrID, float], normalize: bool = True) -> None: + if not self._is_connected: + return + + """ + Set motor velocities. + + Args: + motors: Dictionary of motor names or IDs and target velocities + normalize: Whether to normalize the velocity + """ + for motor, velocity in motors.items(): + self.set_velocity(motor, velocity, normalize) + + ############################################################################################################################## + # Update velocity + ############################################################################################################################## + + def update_velocity(self, motor: NameOrID | None = None, max_step: float = 1.0) -> None: + """Update motor velocity.""" + if not self._is_connected: + logger.info(f"{self} is not connected.") + return + + if motor is None: + for motor_id in self._id_to_name_dict.keys(): + self.protocol_handler.update_velocity(motor_id, max_step) + else: + motor_id = self._get_motor_id(motor) + self.protocol_handler.update_velocity(motor_id, max_step) + + ############################################################################################################################## + # PWM Functions + ############################################################################################################################## + + def get_pwm(self, motor: NameOrID) -> float: + """Get current PWM duty cycle.""" + if not self._is_connected: + logger.info(f"{self} is not connected.") + return + + motor_id = self._get_motor_id(motor) + return self.protocol_handler.get_pwm(motor_id) + + def set_pwm(self, motor: NameOrID, duty_cycle: float) -> None: + """ + Set PWM duty cycle. + + Args: + motor: Motor name or ID + duty_cycle: PWM duty cycle (0 to 1) + """ + if not self._is_connected: + logger.info(f"{self} is not connected.") + return + + motor_id = self._get_motor_id(motor) + + # Clamp to [0, 1] + duty_cycle = max(0.0, min(1.0, duty_cycle)) + + self.protocol_handler.set_pwm(motor_id, duty_cycle) + logger.debug(f"Set motor {motor} PWM to {duty_cycle}") + + ############################################################################################################################## + # Enable/Disable Functions + ############################################################################################################################## + + def enable_motor(self, motor: NameOrID | None = None) -> None: + """Enable motor(s).""" + if not self._is_connected: + logger.info(f"{self} is not connected.") + return + + if motor is None: + # Enable all motors + for motor_id in self._id_to_name_dict.keys(): + self.protocol_handler.enable_motor(motor_id) + else: + motor_id = self._get_motor_id(motor) + self.protocol_handler.enable_motor(motor_id) + + def disable_motor(self, motor: NameOrID | None = None) -> None: + """Disable motor(s).""" + if not self._is_connected: + logger.info(f"{self} is not connected.") + return + + if motor is None: + # Disable all motors + for motor_id in self._id_to_name_dict.keys(): + self.protocol_handler.disable_motor(motor_id) + else: + motor_id = self._get_motor_id(motor) + self.protocol_handler.disable_motor(motor_id) diff --git a/src/lerobot/motors/dc_pwm/__init__.py b/src/lerobot/motors/dc_pwm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/lerobot/motors/dc_pwm/dc_pwm.py b/src/lerobot/motors/dc_pwm/dc_pwm.py new file mode 100644 index 0000000000..29e2dc4dbb --- /dev/null +++ b/src/lerobot/motors/dc_pwm/dc_pwm.py @@ -0,0 +1,437 @@ +import logging +from typing import Dict, Optional, List + +from lerobot.motors.dc_motors_controller import BaseDCMotorsController, DCMotor, ProtocolHandler + +logger = logging.getLogger(__name__) + + +# Pi 5 Hardware PWM Configuration +PI5_HARDWARE_PWM_PINS = { + "pwm0": [12], # PWM0 channels + "pwm1": [13], # PWM1 channels + "pwm2": [18], # PWM2 channels + "pwm3": [19], # PWM3 channels +} + +# Pi 5 All Available GPIO Pins (40-pin header) +PI5_ALL_GPIO_PINS = [ + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41 +] + +# Pi 5 Optimal Settings for DRV8874PWPR +PI5_OPTIMAL_FREQUENCY = 25000 # 25kHz - more compatible with gpiozero +PI5_MAX_FREQUENCY = 25000 # 25kHz - Pi 5 can handle higher frequencies +PI5_RESOLUTION = 12 # 12-bit resolution + +class PWMDCMotorsController(BaseDCMotorsController): + """PWM-based DC motor controller optimized for DRV8874PWPR H-bridge drivers.""" + + def __init__(self, config: dict | None = None, motors: dict[str, DCMotor] | None = None, protocol: str = "pwm"): + super().__init__(config, motors, protocol) + + def _create_protocol_handler(self) -> ProtocolHandler: + return PWMProtocolHandler(self.config, self.motors) + + +class PWMProtocolHandler(ProtocolHandler): + """ + PWM protocol handler optimized for DRV8874PWPR H-bridge motor drivers. + + DRV8874PWPR Features: + - IN1: PWM speed control (hardware PWM recommended) + - IN2: Direction control (regular GPIO) + - Built-in current limiting and thermal protection + - 25kHz PWM frequency optimal + + Configuration: + - in1_pins: IN1 pins (PWM speed control) + - in2_pins: IN2 pins (direction control) + - pwm_frequency: 25000Hz (optimal for DRV8874PWPR) + """ + + ############################################################################################################################## + # Configuration + ############################################################################################################################## + + def __init__(self, config: Dict, motors: Dict[str, DCMotor]): + self.config = config + self.in1_pins = config.get("in1_pins", []) + self.in2_pins = config.get("in2_pins", []) + self.enable_pins = config.get("enable_pins", []) + self.brake_pins = config.get("brake_pins", []) + self.pwm_frequency = config.get("pwm_frequency", PI5_OPTIMAL_FREQUENCY) + self.invert_direction = config.get("invert_direction", False) + self.invert_enable = config.get("invert_enable", False) + self.invert_brake = config.get("invert_brake", False) + + # Motor configuration and state tracking + self.motors: Dict[str, DCMotor] = motors + self.motor_states: Dict[int, Dict] = {} + self.in1_channels = {} + self.in2_channels = {} + self.enable_channels = {} + self.brake_channels = {} + + # Validate Pi 5 pins + self._validate_pi5_pins() + + # Import gpiozero + self._import_gpiozero() + + def _validate_pi5_pins(self): + """Validate that pins are valid GPIO pins on Pi 5.""" + all_hardware_pwm = [] + for pwm_pins in PI5_HARDWARE_PWM_PINS.values(): + all_hardware_pwm.extend(pwm_pins) + + # Validate IN1 pins (should be hardware PWM for best performance) + invalid_in1_pins = [pin for pin in self.in1_pins if pin not in all_hardware_pwm] + if invalid_in1_pins: + # logger.warning( + # f"IN1 pins {invalid_in1_pins} are not hardware PWM pins on Pi 5. " + # f"Hardware PWM pins: {all_hardware_pwm}" + # ) + pass + + # Validate IN2 pins (can be any GPIO) + invalid_in2_pins = [pin for pin in self.in2_pins if pin not in PI5_ALL_GPIO_PINS] + if invalid_in2_pins: + logger.warning( + f"IN2 pins {invalid_in2_pins} are not valid GPIO pins on Pi 5. " + f"Valid GPIO pins: {PI5_ALL_GPIO_PINS}" + ) + + # Check for pin conflicts + all_used_pins = set(self.in1_pins + self.in2_pins + self.enable_pins + self.brake_pins) + if len(all_used_pins) != len(self.in1_pins + self.in2_pins + self.enable_pins + self.brake_pins): + logger.warning("Duplicate pins detected in configuration") + + # Validate motor count + motor_count = len(self.in1_pins) + logger.info(f"Configuring {motor_count} DRV8874PWPR motors with gpiozero") + + def _import_gpiozero(self): + """Import gpiozero.""" + try: + import gpiozero + self.gpiozero = gpiozero + logger.info("Using gpiozero for DRV8874PWPR motor control") + + except ImportError: + raise ImportError( + "gpiozero not available. Install with: uv pip install gpiozero>=2.0" + ) + + def _setup_pwmled(self, pin: int, label: str) -> 'gpiozero.PWMLED': # type: ignore + """Safely set up a PWMLED on the given pin, with fallback to default frequency.""" + try: + return self.gpiozero.PWMLED(pin, frequency=self.pwm_frequency) + except Exception as e: + logger.warning(f"{label}: Failed with frequency {self.pwm_frequency}, retrying with default. ({e})") + try: + return self.gpiozero.PWMLED(pin) + except Exception as e2: + logger.error(f"{label}: Failed to setup PWMLED on pin {pin}: {e2}") + raise + + def _safe_close(self, channel: 'gpiozero.PWMLED', label: str) -> None: # type: ignore + """Safely close a PWMLED channel.""" + try: + channel.close() + except Exception as e: + logger.warning(f"Error closing {label}: {e}") + + ############################################################################################################################## + # Connection + ############################################################################################################################## + + def connect(self) -> None: + """Initialize gpiozero for DRV8874PWPR motor drivers with symmetric PWM on IN1 and IN2.""" + try: + for motor_id, (in1_pin, in2_pin) in enumerate(zip(self.in1_pins, self.in2_pins), start=1): + self.motor_states[motor_id] = { + "position": 0.0, + "velocity": 0.0, + "pwm": 0.0, + "enabled": False, + "brake_active": False, + "direction": 1 + } + + in1 = self._setup_pwmled(in1_pin, f"Motor {motor_id} IN1") + in1.off() + self.in1_channels[motor_id] = in1 + logger.debug(f"Motor {motor_id} IN1 setup on pin {in1_pin}") + + in2 = self._setup_pwmled(in2_pin, f"Motor {motor_id} IN2") + in2.off() + self.in2_channels[motor_id] = in2 + logger.debug(f"Motor {motor_id} IN2 setup on pin {in2_pin}") + + total_pins = len(self.in1_pins) + len(self.in2_pins) + logger.info(f"DRV8874PWPR setup complete: {len(self.in1_pins)} motors, {total_pins} GPIOs used") + logger.info(f"PWM frequency: {self.pwm_frequency} Hz") + except Exception as e: + logger.error(f"Motor driver setup failed: {e}") + raise RuntimeError("gpiozero hardware not available") + + def disconnect(self) -> None: + """Clean up gpiozero PWMLED channels for IN1 and IN2.""" + for motor_id, channel in self.in1_channels.items(): + self._safe_close(channel, f"IN1 (motor {motor_id})") + + for motor_id, channel in self.in2_channels.items(): + self._safe_close(channel, f"IN2 (motor {motor_id})") + + logger.info("DRV8874PWPR motor driver disconnected") + + ############################################################################################################################## + # Position Functions + ############################################################################################################################## + + def get_position(self, motor_id: int) -> Optional[float]: + """Get current motor position if encoder available.""" + return self.motor_states.get(motor_id, {}).get("position", 0.0) + + def set_position(self, motor_id: int, position: float) -> None: + """ + Set motor position (0 to 1). + Note: This is a simplified implementation. For precise position control, + you'd need encoders and PID control. + """ + if position < 0: + position = 0 + elif position > 1: + position = 1 + + self.motor_states[motor_id]["position"] = position + + # Convert position to PWM (simple linear mapping) + pwm_duty = position + self.set_pwm(motor_id, pwm_duty) + + ############################################################################################################################## + # Velocity Functions + ############################################################################################################################## + + def get_velocity(self, motor_id: int) -> float: + """Get current motor velocity.""" + return self.motor_states.get(motor_id, {}).get("velocity", 0.0) + + def set_velocity(self, motor_id: int, target_velocity: float, instant: bool = True) -> None: + """ + Set the target velocity for the motor (-1.0 to 1.0). + Actual velocity will be slewed toward this value in update_velocity(). + """ + target_velocity = max(-1.0, min(1.0, target_velocity)) # clamp + self.motor_states[motor_id]["target_velocity"] = target_velocity + + if instant: + self.update_velocity(motor_id, 1.0) + + def update_velocity(self, motor_id: int, max_step: float = 1.0) -> None: + """ + Gradually update the motor velocity toward its target using a slew-rate limiter. + Call this periodically (e.g., every 10–20 ms). + """ + state = self.motor_states[motor_id] + current = state.get("velocity", 0.0) + target = state.get("target_velocity", 0.0) + + # Apply slew-rate limit + if target > current: + new_velocity = min(current + max_step, target) + elif target < current: + new_velocity = max(current - max_step, target) + else: + new_velocity = target + + # Save new velocity + state["velocity"] = new_velocity + + # Convert to PWM duty cycle + pwm_duty = self._velocity_to_pwm(new_velocity) + state["pwm"] = pwm_duty + state["brake_active"] = False + + in1 = self.in1_channels.get(motor_id) + in2 = self.in2_channels.get(motor_id) + + if new_velocity > 0: + if in1: in1.value = pwm_duty + if in2: in2.off() + state["direction"] = 1 + + elif new_velocity < 0: + if in1: in1.off() + if in2: in2.value = pwm_duty + state["direction"] = -1 + + else: + if in1: in1.off() + if in2: in2.off() + state["direction"] = 0 + + ############################################################################################################################## + # PWM Functsions + ############################################################################################################################## + + def get_pwm(self, motor_id: int) -> float: + """Get current PWM duty cycle.""" + return self.motor_states.get(motor_id, {}).get("pwm", 0.0) + + def set_pwm(self, motor_id: int, duty_cycle: float) -> None: + """ + Set PWM duty cycle (0..0.98) respecting current direction. + Uses symmetric PWM: IN1 for forward, IN2 for reverse. + """ + + # Cap your duty to 0.98 (to avoid DRV8871's fixed off-time weirdness at 100%) + duty_cycle = max(0.0, min(0.98, duty_cycle)) + self.motor_states[motor_id]["pwm"] = duty_cycle + + direction = self.motor_states[motor_id].get("direction", 1) + in1 = self.in1_channels.get(motor_id) + in2 = self.in2_channels.get(motor_id) + + if not in1 or not in2: + logger.warning(f"Motor {motor_id} missing IN1/IN2 channel(s)") + return + + try: + if direction > 0: # forward + in1.value = duty_cycle + in2.off() + elif direction < 0: # reverse + in1.off() + in2.value = duty_cycle + else: + in1.off() + in2.off() + logger.debug(f"Motor {motor_id} PWM={duty_cycle:.3f} dir={'FWD' if direction>0 else 'REV' if direction<0 else 'STOP'}") + except Exception as e: + logger.warning(f"Error setting PWM for motor {motor_id}: {e}") + + ############################################################################################################################## + # Enable/Disable Functions + ############################################################################################################################## + + def enable_motor(self, motor_id: int) -> None: + """Enable motor.""" + self.motor_states[motor_id]["enabled"] = True + logger.debug(f"Motor {motor_id} enabled") + + def disable_motor(self, motor_id: int) -> None: + """Disable motor by setting PWM to 0.""" + self.set_pwm(motor_id, 0.0) + self.motor_states[motor_id]["enabled"] = False + logger.debug(f"Motor {motor_id} disabled") + + ############################################################################################################################## + # Helper methods for DRV8874PWPR-specific functionality + ############################################################################################################################## + + def _get_direction(self, motor_id: int) -> bool: + """Get motor direction.""" + if motor_id not in self.in2_channels: + return False + return self.in2_channels[motor_id].value == 1 + + def _set_direction(self, motor_id: int, forward: bool) -> None: + """ + Set motor direction for DRV8874PWPR. + This method updates the direction state and applies appropriate PWM. + """ + if motor_id not in self.in2_channels: + return + + # Apply direction inversion if configured + if self.invert_direction: + forward = not forward + + # Update direction state + self.motor_states[motor_id]["direction"] = 1 if forward else -1 + + try: + # Set IN2 for direction control + self.in2_channels[motor_id].on() if not forward else self.in2_channels[motor_id].off() + + # Re-apply current PWM with new direction + current_pwm = self.motor_states[motor_id].get("pwm", 0.0) + if current_pwm > 0: + self.set_pwm(motor_id, current_pwm) + + logger.debug(f"Motor {motor_id} direction set to {'forward' if forward else 'backward'}") + except Exception as e: + logger.warning(f"Error setting direction for motor {motor_id}: {e}") + + # DRV8874PWPR-specific convenience methods + def activate_brake(self, motor_id: int) -> None: + """ + Activate motor brake for DRV8874PWPR. + Brake mode: IN1 = HIGH, IN2 = HIGH + """ + in1 = self.in1_channels.get(motor_id) + in2 = self.in2_channels.get(motor_id) + + if not in1 or not in2: + logger.warning(f"Cannot activate brake: IN1 or IN2 not found for motor {motor_id}") + return + + try: + in1.on() + in2.on() + self.motor_states[motor_id]["brake_active"] = True + logger.debug(f"Motor {motor_id} brake activated (IN1=1, IN2=1)") + except Exception as e: + logger.warning(f"Error activating brake for motor {motor_id}: {e}") + + def release_brake(self, motor_id: int) -> None: + """ + Release motor brake for DRV8874PWPR. + Coast mode: IN1 = LOW, IN2 = LOW + """ + in1 = self.in1_channels.get(motor_id) + in2 = self.in2_channels.get(motor_id) + + if not in1 or not in2: + logger.warning(f"Cannot release brake: IN1 or IN2 not found for motor {motor_id}") + return + + try: + in1.off() + in2.off() + self.motor_states[motor_id]["brake_active"] = False + logger.debug(f"Motor {motor_id} brake released (IN1=0, IN2=0)") + except Exception as e: + logger.warning(f"Error releasing brake for motor {motor_id}: {e}") + + ############################################################################################################################## + # Helper methods for DRV8874PWPR-specific functionality + ############################################################################################################################## + + def _velocity_to_pwm(self, velocity: float) -> float: + """ + Convert normalized velocity (-1 to 1) into PWM duty cycle (0.0 to 1.0). + Tuned so that velocity=0.5 gives about 0.25 duty (half speed in real world). + """ + + # This code works for the 30RPM 10kg.cm torque motor. + v = abs(velocity) + + # Special case: stop = true 0 duty + if v == 0: + return 0.0 + + # Deadzone threshold (motor won’t spin below this duty) + deadzone = 0.1 + + # Exponent > 1 pushes mid values lower + exponent = 2.0 # quadratic curve, makes 0.5 input ≈ 0.25 duty + + # Map velocity into [deadzone, 1.0] + pwm = deadzone + (1 - deadzone) * (v ** exponent) + + return pwm diff --git a/tests/motors/test_dc_pwm.py b/tests/motors/test_dc_pwm.py new file mode 100644 index 0000000000..4ba2cd626d --- /dev/null +++ b/tests/motors/test_dc_pwm.py @@ -0,0 +1,516 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from collections.abc import Generator +from unittest.mock import MagicMock, patch + +import pytest + +from lerobot.motors.dc_motors_controller import DCMotor, MotorNormMode +from lerobot.motors.dc_pwm.dc_pwm import ( + PI5_OPTIMAL_FREQUENCY, + PWMDCMotorsController, + PWMProtocolHandler, +) + + +class MockPWMLED: + """Mock PWMLED for testing without hardware.""" + + def __init__(self, pin: int, frequency: int | None = None): + self.pin = pin + self.frequency = frequency + self._value = 0.0 + self.is_closed = False + + @property + def value(self) -> float: + return self._value + + @value.setter + def value(self, val: float): + self._value = max(0.0, min(1.0, val)) + + def on(self): + self._value = 1.0 + + def off(self): + self._value = 0.0 + + def close(self): + self.is_closed = True + + +@pytest.fixture +def mock_gpiozero(): + """Mock gpiozero module.""" + mock_gpio = MagicMock() + mock_gpio.PWMLED = MockPWMLED + return mock_gpio + + +@pytest.fixture +def dummy_motors() -> dict[str, DCMotor]: + return { + "motor_1": DCMotor(id=1, model="drv8874", norm_mode=MotorNormMode.PWM_DUTY_CYCLE), + "motor_2": DCMotor(id=2, model="drv8874", norm_mode=MotorNormMode.PWM_DUTY_CYCLE), + "motor_3": DCMotor(id=3, model="drv8874", norm_mode=MotorNormMode.PWM_DUTY_CYCLE), + } + + +@pytest.fixture +def pwm_config() -> dict: + return { + "in1_pins": [12, 13, 18], + "in2_pins": [19, 20, 21], + "pwm_frequency": PI5_OPTIMAL_FREQUENCY, + } + + +@pytest.fixture +def protocol_handler(pwm_config, dummy_motors, mock_gpiozero): + """Create a PWMProtocolHandler with mocked gpiozero.""" + # Patch sys.modules to intercept the import inside _import_gpiozero + with patch.dict(sys.modules, {"gpiozero": mock_gpiozero}): + handler = PWMProtocolHandler(pwm_config, dummy_motors) + return handler + + +@pytest.fixture(autouse=True) +def auto_patch_gpiozero(mock_gpiozero): + """Automatically patch gpiozero for all tests.""" + # This fixture patches sys.modules to intercept imports + with patch.dict(sys.modules, {"gpiozero": mock_gpiozero}): + yield + + +def test_controller_instantiation(pwm_config, dummy_motors): + """Test PWMDCMotorsController can be instantiated.""" + controller = PWMDCMotorsController(config=pwm_config, motors=dummy_motors, protocol="pwm") + assert controller.config == pwm_config + assert controller.motors == dummy_motors + assert controller.protocol == "pwm" + + +def test_controller_creates_protocol_handler(pwm_config, dummy_motors, mock_gpiozero): + """Test that controller creates the correct protocol handler.""" + with patch.dict(sys.modules, {"gpiozero": mock_gpiozero}): + controller = PWMDCMotorsController(config=pwm_config, motors=dummy_motors) + handler = controller._create_protocol_handler() + assert isinstance(handler, PWMProtocolHandler) + + +@pytest.mark.parametrize( + "in1_pins, in2_pins, expected_count", + [ + ([12], [19], 1), + ([12, 13], [19, 20], 2), + ([12, 13, 18], [19, 20, 21], 3), + ], +) +def test_protocol_handler_init(in1_pins, in2_pins, expected_count, dummy_motors, mock_gpiozero): + """Test PWMProtocolHandler initialization.""" + config = { + "in1_pins": in1_pins, + "in2_pins": in2_pins, + "pwm_frequency": PI5_OPTIMAL_FREQUENCY, + } + with patch.dict(sys.modules, {"gpiozero": mock_gpiozero}): + handler = PWMProtocolHandler(config, dummy_motors) + assert len(handler.in1_pins) == expected_count + assert len(handler.in2_pins) == expected_count + assert handler.pwm_frequency == PI5_OPTIMAL_FREQUENCY + + +def test_connect(protocol_handler): + """Test connect method sets up PWMLED channels.""" + protocol_handler.connect() + + assert len(protocol_handler.in1_channels) == 3 + assert len(protocol_handler.in2_channels) == 3 + assert all(isinstance(ch, MockPWMLED) for ch in protocol_handler.in1_channels.values()) + assert all(isinstance(ch, MockPWMLED) for ch in protocol_handler.in2_channels.values()) + + # Check that channels were initialized with correct pins + assert protocol_handler.in1_channels[1].pin == 12 + assert protocol_handler.in2_channels[1].pin == 19 + + +def test_disconnect(protocol_handler): + """Test disconnect method closes all channels.""" + protocol_handler.connect() + protocol_handler.disconnect() + + assert all(ch.is_closed for ch in protocol_handler.in1_channels.values()) + assert all(ch.is_closed for ch in protocol_handler.in2_channels.values()) + + +def test_get_position(protocol_handler): + """Test get_position returns stored position.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["position"] = 0.5 + + position = protocol_handler.get_position(1) + assert position == 0.5 + + # Test with uninitialized motor + position = protocol_handler.get_position(999) + assert position == 0.0 + + +def test_set_position(protocol_handler): + """Test set_position updates position and PWM.""" + protocol_handler.connect() + + # Set direction to forward first + protocol_handler.motor_states[1]["direction"] = 1 + + protocol_handler.set_position(1, 0.75) + assert protocol_handler.motor_states[1]["position"] == 0.75 + assert protocol_handler.motor_states[1]["pwm"] == 0.75 + + # Test clamping + protocol_handler.set_position(1, 1.5) + assert protocol_handler.motor_states[1]["position"] == 1.0 + + protocol_handler.set_position(1, -0.5) + assert protocol_handler.motor_states[1]["position"] == 0.0 + + +def test_get_velocity(protocol_handler): + """Test get_velocity returns stored velocity.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["velocity"] = 0.5 + + velocity = protocol_handler.get_velocity(1) + assert velocity == 0.5 + + # Test with uninitialized motor + velocity = protocol_handler.get_velocity(999) + assert velocity == 0.0 + + +def test_set_velocity(protocol_handler): + """Test set_velocity sets target velocity.""" + protocol_handler.connect() + + protocol_handler.set_velocity(1, 0.5, instant=True) + assert protocol_handler.motor_states[1]["target_velocity"] == 0.5 + assert protocol_handler.motor_states[1]["velocity"] == 0.5 + + # Test clamping + protocol_handler.set_velocity(1, 1.5, instant=False) + assert protocol_handler.motor_states[1]["target_velocity"] == 1.0 + + protocol_handler.set_velocity(1, -1.5, instant=False) + assert protocol_handler.motor_states[1]["target_velocity"] == -1.0 + + +@pytest.mark.parametrize( + "current, target, max_step, expected", + [ + (0.0, 1.0, 0.5, 0.5), # Ramp up + (1.0, 0.0, 0.5, 0.5), # Ramp down + (0.0, 0.5, 1.0, 0.5), # Instant update + (0.0, 0.0, 1.0, 0.0), # No change + ], +) +def test_update_velocity(current, target, max_step, expected, protocol_handler): + """Test update_velocity with slew rate limiting.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["velocity"] = current + protocol_handler.motor_states[1]["target_velocity"] = target + + protocol_handler.update_velocity(1, max_step) + + assert protocol_handler.motor_states[1]["velocity"] == expected + + +def test_update_velocity_forward(protocol_handler): + """Test update_velocity sets correct GPIO states for forward motion.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["target_velocity"] = 0.5 + protocol_handler.motor_states[1]["velocity"] = 0.0 + + protocol_handler.update_velocity(1, 1.0) + + in1 = protocol_handler.in1_channels[1] + in2 = protocol_handler.in2_channels[1] + assert in1.value > 0 # PWM on + assert in2.value == 0.0 # OFF + assert protocol_handler.motor_states[1]["direction"] == 1 + + +def test_update_velocity_reverse(protocol_handler): + """Test update_velocity sets correct GPIO states for reverse motion.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["target_velocity"] = -0.5 + protocol_handler.motor_states[1]["velocity"] = 0.0 + + protocol_handler.update_velocity(1, 1.0) + + in1 = protocol_handler.in1_channels[1] + in2 = protocol_handler.in2_channels[1] + assert in1.value == 0.0 # OFF + assert in2.value > 0 # PWM on + assert protocol_handler.motor_states[1]["direction"] == -1 + + +def test_update_velocity_stop(protocol_handler): + """Test update_velocity stops motor when velocity is zero.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["target_velocity"] = 0.0 + protocol_handler.motor_states[1]["velocity"] = 0.5 + + protocol_handler.update_velocity(1, 1.0) + + in1 = protocol_handler.in1_channels[1] + in2 = protocol_handler.in2_channels[1] + assert in1.value == 0.0 + assert in2.value == 0.0 + assert protocol_handler.motor_states[1]["direction"] == 0 + + +def test_get_pwm(protocol_handler): + """Test get_pwm returns stored PWM value.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["pwm"] = 0.75 + + pwm = protocol_handler.get_pwm(1) + assert pwm == 0.75 + + # Test with uninitialized motor + pwm = protocol_handler.get_pwm(999) + assert pwm == 0.0 + + +def test_set_pwm_forward(protocol_handler): + """Test set_pwm with forward direction.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["direction"] = 1 + + protocol_handler.set_pwm(1, 0.5) + + in1 = protocol_handler.in1_channels[1] + in2 = protocol_handler.in2_channels[1] + assert in1.value == 0.5 # PWM value set to duty_cycle + assert in2.value == 0.0 # OFF + assert protocol_handler.motor_states[1]["pwm"] == 0.5 + + +def test_set_pwm_reverse(protocol_handler): + """Test set_pwm with reverse direction.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["direction"] = -1 + + protocol_handler.set_pwm(1, 0.5) + + in1 = protocol_handler.in1_channels[1] + in2 = protocol_handler.in2_channels[1] + assert in1.value == 0.0 # OFF + assert in2.value == 0.5 # PWM value set to duty_cycle + assert protocol_handler.motor_states[1]["pwm"] == 0.5 + + +def test_set_pwm_stop(protocol_handler): + """Test set_pwm stops motor when direction is 0.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["direction"] = 0 + + protocol_handler.set_pwm(1, 0.5) + + in1 = protocol_handler.in1_channels[1] + in2 = protocol_handler.in2_channels[1] + assert in1.value == 0.0 + assert in2.value == 0.0 + + +def test_set_pwm_clamping(protocol_handler): + """Test set_pwm clamps duty cycle to 0.98.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["direction"] = 1 + + protocol_handler.set_pwm(1, 1.0) + assert protocol_handler.motor_states[1]["pwm"] == 0.98 + + protocol_handler.set_pwm(1, -0.1) + assert protocol_handler.motor_states[1]["pwm"] == 0.0 + + +def test_enable_motor(protocol_handler): + """Test enable_motor sets enabled flag.""" + protocol_handler.connect() + + protocol_handler.enable_motor(1) + assert protocol_handler.motor_states[1]["enabled"] is True + + +def test_disable_motor(protocol_handler): + """Test disable_motor sets PWM to 0 and disables.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["pwm"] = 0.5 + protocol_handler.motor_states[1]["direction"] = 0 # Set to stop/neutral + + protocol_handler.disable_motor(1) + + assert protocol_handler.motor_states[1]["pwm"] == 0.0 + assert protocol_handler.motor_states[1]["enabled"] is False + in1 = protocol_handler.in1_channels[1] + in2 = protocol_handler.in2_channels[1] + # When direction=0, both channels are turned off + assert in1.value == 0.0 + assert in2.value == 0.0 + + +def test_activate_brake(protocol_handler): + """Test activate_brake sets both IN1 and IN2 high.""" + protocol_handler.connect() + + protocol_handler.activate_brake(1) + + in1 = protocol_handler.in1_channels[1] + in2 = protocol_handler.in2_channels[1] + assert in1.value == 1.0 + assert in2.value == 1.0 + assert protocol_handler.motor_states[1]["brake_active"] is True + + +def test_release_brake(protocol_handler): + """Test release_brake sets both IN1 and IN2 low.""" + protocol_handler.connect() + protocol_handler.motor_states[1]["brake_active"] = True + + protocol_handler.release_brake(1) + + in1 = protocol_handler.in1_channels[1] + in2 = protocol_handler.in2_channels[1] + assert in1.value == 0.0 + assert in2.value == 0.0 + assert protocol_handler.motor_states[1]["brake_active"] is False + + +@pytest.mark.parametrize( + "velocity, expected_pwm", + [ + (0.0, 0.0), + (0.5, 0.325), # deadzone + (1-deadzone) * 0.5^2 = 0.1 + 0.9 * 0.25 + (1.0, 1.0), # deadzone + (1-deadzone) * 1^2 = 0.1 + 0.9 * 1 + ], +) +def test_velocity_to_pwm(velocity, expected_pwm, protocol_handler): + """Test _velocity_to_pwm conversion.""" + protocol_handler.connect() + + pwm = protocol_handler._velocity_to_pwm(velocity) + assert abs(pwm - expected_pwm) < 0.01 # Allow small floating point differences + + +def test_controller_connect_disconnect(pwm_config, dummy_motors, mock_gpiozero): + """Test controller connect and disconnect.""" + with patch.dict(sys.modules, {"gpiozero": mock_gpiozero}): + controller = PWMDCMotorsController(config=pwm_config, motors=dummy_motors) + controller.connect() + + assert controller.is_connected is True + assert controller.protocol_handler is not None + + controller.disconnect() + assert controller.is_connected is False + + +def test_controller_get_set_position(pwm_config, dummy_motors, mock_gpiozero): + """Test controller position methods.""" + with patch.dict(sys.modules, {"gpiozero": mock_gpiozero}): + controller = PWMDCMotorsController(config=pwm_config, motors=dummy_motors) + controller.connect() + + controller.set_position("motor_1", 0.5) + position = controller.get_position("motor_1") + assert position == 0.5 + + +def test_controller_get_set_velocity(pwm_config, dummy_motors, mock_gpiozero): + """Test controller velocity methods.""" + with patch.dict(sys.modules, {"gpiozero": mock_gpiozero}): + controller = PWMDCMotorsController(config=pwm_config, motors=dummy_motors) + controller.connect() + + controller.set_velocity("motor_1", 0.5) + velocity = controller.get_velocity("motor_1") + assert velocity == 0.5 + + +def test_controller_get_set_pwm(pwm_config, dummy_motors, mock_gpiozero): + """Test controller PWM methods.""" + with patch.dict(sys.modules, {"gpiozero": mock_gpiozero}): + controller = PWMDCMotorsController(config=pwm_config, motors=dummy_motors) + controller.connect() + + # Set direction first + controller.protocol_handler.motor_states[1]["direction"] = 1 + + controller.set_pwm("motor_1", 0.5) + pwm = controller.get_pwm("motor_1") + assert pwm == 0.5 + + +def test_controller_enable_disable(pwm_config, dummy_motors, mock_gpiozero): + """Test controller enable/disable methods.""" + with patch.dict(sys.modules, {"gpiozero": mock_gpiozero}): + controller = PWMDCMotorsController(config=pwm_config, motors=dummy_motors) + controller.connect() + + controller.enable_motor("motor_1") + assert controller.protocol_handler.motor_states[1]["enabled"] is True + + controller.disable_motor("motor_1") + assert controller.protocol_handler.motor_states[1]["enabled"] is False + + +def test_setup_pwmled_fallback(protocol_handler): + """Test _setup_pwmled falls back to default frequency on error.""" + protocol_handler.connect() + + # This test verifies the fallback logic exists + # The actual fallback is tested by the fact that connect() works + assert len(protocol_handler.in1_channels) > 0 + + +def test_validate_pi5_pins(protocol_handler): + """Test pin validation.""" + # This should not raise for valid pins + protocol_handler._validate_pi5_pins() + + # Test with invalid IN2 pin + protocol_handler.in2_pins = [999] + # Should log warning but not raise + protocol_handler._validate_pi5_pins() + + +def test_invert_direction(pwm_config, dummy_motors, mock_gpiozero): + """Test direction inversion configuration.""" + pwm_config["invert_direction"] = True + + with patch.dict(sys.modules, {"gpiozero": mock_gpiozero}): + handler = PWMProtocolHandler(pwm_config, dummy_motors) + handler.connect() + handler.motor_states[1]["direction"] = 1 + + # Test _set_direction with inversion + handler._set_direction(1, forward=True) + # Should be inverted + assert handler.motor_states[1]["direction"] == -1