diff --git a/openadapt/build.py b/openadapt/build.py index 1b0a82a45..1fefe8a9c 100644 --- a/openadapt/build.py +++ b/openadapt/build.py @@ -16,7 +16,7 @@ import urllib.request import gradio_client -import oa_pynput +import pynput import pycocotools import pydicom import pyqttoast @@ -34,7 +34,7 @@ def build_pyinstaller() -> None: """Build the application using PyInstaller.""" additional_packages_to_install = [ - oa_pynput, + pynput, pydicom, spacy_alignments, gradio_client, @@ -275,6 +275,8 @@ def main() -> None: create_macos_dmg() elif sys.platform == "win32": create_windows_installer() + else: + print(f"WARNING: openadapt.build is not yet supported on {sys.platform=}") if __name__ == "__main__": diff --git a/openadapt/build_utils.py b/openadapt/build_utils.py index 0cde71576..0d7228ad3 100644 --- a/openadapt/build_utils.py +++ b/openadapt/build_utils.py @@ -17,13 +17,15 @@ def get_root_dir_path() -> pathlib.Path: if not path.exists(): path.mkdir(parents=True, exist_ok=True) return path - else: + elif sys.platform == "win32": # if windows, get the path to the %APPDATA% directory and set the path # for all user preferences path = pathlib.Path.home() / "AppData" / "Roaming" / "openadapt" if not path.exists(): path.mkdir(parents=True, exist_ok=True) return path + else: + print(f"WARNING: openadapt.build_utils is not yet supported on {sys.platform=}") def is_running_from_executable() -> bool: diff --git a/openadapt/capture/__init__.py b/openadapt/capture/__init__.py index 756bed34f..243c00c48 100644 --- a/openadapt/capture/__init__.py +++ b/openadapt/capture/__init__.py @@ -9,6 +9,8 @@ from . import _macos as impl elif sys.platform == "win32": from . import _windows as impl +elif sys.platform.startswith("linux"): + from . import _linux as impl else: raise Exception(f"Unsupported platform: {sys.platform}") diff --git a/openadapt/capture/_linux.py b/openadapt/capture/_linux.py new file mode 100644 index 000000000..20475f182 --- /dev/null +++ b/openadapt/capture/_linux.py @@ -0,0 +1,140 @@ +import subprocess +import os +from datetime import datetime +from sys import platform +import pyaudio +import wave + +from openadapt.config import CAPTURE_DIR_PATH + + +class Capture: + """Capture the screen, audio, and camera on Linux.""" + + def __init__(self) -> None: + """Initialize the capture object.""" + if not platform.startswith("linux"): + assert platform == "linux", platform + + self.is_recording = False + self.audio_out = None + self.video_out = None + self.audio_stream = None + self.audio_frames = [] + + # Initialize PyAudio + self.audio = pyaudio.PyAudio() + + def get_screen_resolution(self) -> tuple: + """Get the screen resolution dynamically using xrandr.""" + try: + # Get screen resolution using xrandr + output = subprocess.check_output( + "xrandr | grep '*' | awk '{print $1}'", shell=True + ) + resolution = output.decode("utf-8").strip() + width, height = resolution.split("x") + return int(width), int(height) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to get screen resolution: {e}") + + def start(self, audio: bool = True, camera: bool = False) -> None: + """Start capturing the screen, audio, and camera. + + Args: + audio (bool, optional): Whether to capture audio (default: True). + camera (bool, optional): Whether to capture the camera (default: False). + """ + if self.is_recording: + raise RuntimeError("Recording is already in progress") + + self.is_recording = True + capture_dir = CAPTURE_DIR_PATH + if not os.path.exists(capture_dir): + os.mkdir(capture_dir) + + # Get the screen resolution dynamically + screen_width, screen_height = self.get_screen_resolution() + + # Start video capture using ffmpeg + video_filename = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + ".mp4" + self.video_out = os.path.join(capture_dir, video_filename) + self._start_video_capture(screen_width, screen_height) + + # Start audio capture + if audio: + audio_filename = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + ".wav" + self.audio_out = os.path.join(capture_dir, audio_filename) + self._start_audio_capture() + + def _start_video_capture(self, width: int, height: int) -> None: + """Start capturing the screen using ffmpeg with the dynamic resolution.""" + cmd = [ + "ffmpeg", + "-f", + "x11grab", # Capture X11 display + "-video_size", + f"{width}x{height}", # Use dynamic screen resolution + "-framerate", + "30", # Set frame rate + "-i", + ":0.0", # Capture from display 0 + "-c:v", + "libx264", # Video codec + "-preset", + "ultrafast", # Speed/quality tradeoff + "-y", + self.video_out, # Output file + ] + self.video_proc = subprocess.Popen(cmd) + + def _start_audio_capture(self) -> None: + """Start capturing audio using PyAudio.""" + self.audio_stream = self.audio.open( + format=pyaudio.paInt16, + channels=2, + rate=44100, + input=True, + frames_per_buffer=1024, + stream_callback=self._audio_callback, + ) + self.audio_frames = [] + self.audio_stream.start_stream() + + def _audio_callback( + self, in_data: bytes, frame_count: int, time_info: dict, status: int + ) -> tuple: + """Callback function to process audio data.""" + self.audio_frames.append(in_data) + return (None, pyaudio.paContinue) + + def stop(self) -> None: + """Stop capturing the screen, audio, and camera.""" + if self.is_recording: + # Stop the video capture + self.video_proc.terminate() + + # Stop audio capture + if self.audio_stream: + self.audio_stream.stop_stream() + self.audio_stream.close() + self.audio.terminate() + self.save_audio() + + self.is_recording = False + + def save_audio(self) -> None: + """Save the captured audio to a WAV file.""" + if self.audio_out: + with wave.open(self.audio_out, "wb") as wf: + wf.setnchannels(2) + wf.setsampwidth(self.audio.get_sample_size(pyaudio.paInt16)) + wf.setframerate(44100) + wf.writeframes(b"".join(self.audio_frames)) + + +if __name__ == "__main__": + capture = Capture() + capture.start(audio=True, camera=False) + input("Press enter to stop") + capture.stop() diff --git a/openadapt/capture/_macos.py b/openadapt/capture/_macos.py index 65529910c..d2c00e7d4 100644 --- a/openadapt/capture/_macos.py +++ b/openadapt/capture/_macos.py @@ -23,10 +23,7 @@ class Capture: def __init__(self) -> None: """Initialize the capture object.""" - if platform != "darwin": - raise NotImplementedError( - "This is the macOS implementation, please use the Windows version" - ) + assert platform == "darwin", platform objc.options.structs_indexable = True diff --git a/openadapt/capture/_windows.py b/openadapt/capture/_windows.py index ab400c950..ad09e48b1 100644 --- a/openadapt/capture/_windows.py +++ b/openadapt/capture/_windows.py @@ -21,10 +21,8 @@ def __init__(self, pid: int = 0) -> None: pid (int, optional): The process ID of the window to capture. Defaults to 0 (the entire screen) """ - if platform != "win32": - raise NotImplementedError( - "This is the Windows implementation, please use the macOS version" - ) + assert platform == "win32", platform + self.is_recording = False self.video_out = None self.audio_out = None diff --git a/openadapt/models.py b/openadapt/models.py index b2286b812..2f42f17fb 100644 --- a/openadapt/models.py +++ b/openadapt/models.py @@ -9,7 +9,7 @@ import sys from bs4 import BeautifulSoup -from oa_pynput import keyboard +from pynput import keyboard from PIL import Image, ImageChops import numpy as np import sqlalchemy as sa @@ -649,9 +649,9 @@ def to_prompt_dict( "title", "help", ] - if sys.platform == "win32": + if sys.platform != "darwin": logger.warning( - "key_suffixes have not yet been defined on Windows." + "key_suffixes have not yet been defined on {sys.platform=}." "You can help by uncommenting the lines below and pasting " "the contents of the window_dict into a new GitHub Issue." ) diff --git a/openadapt/playback.py b/openadapt/playback.py index c909ee535..388b709b5 100644 --- a/openadapt/playback.py +++ b/openadapt/playback.py @@ -1,6 +1,6 @@ """Utilities for playing back ActionEvents.""" -from oa_pynput import keyboard, mouse +from pynput import keyboard, mouse from openadapt.common import KEY_EVENTS, MOUSE_EVENTS from openadapt.custom_logger import logger diff --git a/openadapt/plotting.py b/openadapt/plotting.py index 03c6a5b0c..06eacb612 100644 --- a/openadapt/plotting.py +++ b/openadapt/plotting.py @@ -435,8 +435,10 @@ def plot_performance( if view_file: if sys.platform == "darwin": os.system(f"open {fpath}") - else: + elif sys.platform == "win32": os.system(f"start {fpath}") + else: + os.system(f"xdg-open {fpath}") else: plt.savefig(BytesIO(), format="png") # save fig to void if view_file: diff --git a/openadapt/record.py b/openadapt/record.py index eef25c7c8..4740da8c9 100644 --- a/openadapt/record.py +++ b/openadapt/record.py @@ -20,7 +20,7 @@ import time import tracemalloc -from oa_pynput import keyboard, mouse +from pynput import keyboard, mouse from pympler import tracker import av @@ -65,6 +65,9 @@ stop_sequence_detected = False ws_server_instance = None +# TODO XXX replace with utils.get_monitor_dims() once fixed +monitor_width, monitor_height = utils.take_screenshot().size + def collect_stats(performance_snapshots: list[tracemalloc.Snapshot]) -> None: """Collects and appends performance snapshots using tracemalloc. @@ -138,7 +141,7 @@ def process_events( perf_q: sq.SynchronizedQueue, recording: Recording, terminate_processing: multiprocessing.Event, - started_counter: multiprocessing.Value, + started_event: threading.Event, num_screen_events: multiprocessing.Value, num_action_events: multiprocessing.Value, num_window_events: multiprocessing.Value, @@ -157,7 +160,7 @@ def process_events( perf_q: A queue for collecting performance data. recording: The recording object. terminate_processing: An event to signal the termination of the process. - started_counter: Value to increment once started. + started_event: Event to set once started. num_screen_events: A counter for the number of screen events. num_action_events: A counter for the number of action events. num_window_events: A counter for the number of window events. @@ -177,8 +180,7 @@ def process_events( while not terminate_processing.is_set() or not event_q.empty(): event = event_q.get() if not started: - with started_counter.get_lock(): - started_counter.value += 1 + started_event.set() started = True logger.trace(f"{event=}") assert event.type in EVENT_TYPES, event @@ -371,7 +373,7 @@ def write_events( perf_q: sq.SynchronizedQueue, recording: Recording, terminate_processing: multiprocessing.Event, - started_counter: multiprocessing.Value, + started_event: multiprocessing.Event, pre_callback: Callable[[float], dict] | None = None, post_callback: Callable[[dict], None] | None = None, ) -> None: @@ -385,7 +387,7 @@ def write_events( perf_q: A queue for collecting performance data. recording: The recording object. terminate_processing: An event to signal the termination of the process. - started_counter: Value to increment once started. + started_event: Event to increment once started. pre_callback: Optional function to call before main loop. Takes recording timestamp as only argument, returns a state dict. post_callback: Optional function to call after main loop. Takes state dict as @@ -422,8 +424,7 @@ def write_events( for _ in range(num_processed): progress.update() if not started: - with started_counter.get_lock(): - started_counter.value += 1 + started_event.set() started = True try: event = write_q.get_nowait() @@ -460,10 +461,8 @@ def video_pre_callback(db: crud.SaSession, recording: Recording) -> dict[str, An dict[str, Any]: The updated state. """ video_file_path = video.get_video_file_path(recording.timestamp) - # TODO XXX replace with utils.get_monitor_dims() once fixed - width, height = utils.take_screenshot().size video_container, video_stream, video_start_timestamp = ( - video.initialize_video_writer(video_file_path, width, height) + video.initialize_video_writer(video_file_path, monitor_width, monitor_height) ) crud.update_video_start_time(db, recording, video_start_timestamp) return { @@ -577,7 +576,7 @@ def trigger_action_event( event_q.put(Event(utils.get_timestamp(), "action", action_event_args)) -def on_move(event_q: queue.Queue, x: int, y: int, injected: bool) -> None: +def on_move(event_q: queue.Queue, x: int, y: int, injected: bool = False) -> None: """Handles the 'move' event. Args: @@ -603,7 +602,7 @@ def on_click( y: int, button: mouse.Button, pressed: bool, - injected: bool, + injected: bool = False, ) -> None: """Handles the 'click' event. @@ -638,7 +637,7 @@ def on_scroll( y: int, dx: int, dy: int, - injected: bool, + injected: bool = False, ) -> None: """Handles the 'scroll' event. @@ -705,7 +704,7 @@ def read_screen_events( event_q: queue.Queue, terminate_processing: multiprocessing.Event, recording: Recording, - started_counter: multiprocessing.Value, + started_event: threading.Event, # TODO: throttle # max_cpu_percent: float = 50.0, # Maximum allowed CPU percent # max_memory_percent: float = 50.0, # Maximum allowed memory percent @@ -717,7 +716,7 @@ def read_screen_events( event_q: A queue for adding screen events. terminate_processing: An event to signal the termination of the process. recording: The recording object. - started_counter: Value to increment once started. + started_event: Event to set once started. """ utils.set_start_time(recording.timestamp) @@ -729,8 +728,7 @@ def read_screen_events( logger.warning("Screenshot was None") continue if not started: - with started_counter.get_lock(): - started_counter.value += 1 + started_event.set() started = True event_q.put(Event(utils.get_timestamp(), "screen", screenshot)) logger.info("Done") @@ -741,7 +739,7 @@ def read_window_events( event_q: queue.Queue, terminate_processing: multiprocessing.Event, recording: Recording, - started_counter: multiprocessing.Value, + started_event: threading.Event, ) -> None: """Read window events and add them to the event queue. @@ -749,7 +747,7 @@ def read_window_events( event_q: A queue for adding window events. terminate_processing: An event to signal the termination of the process. recording: The recording object. - started_counter: Value to increment once started. + started_event: Event to set once started. """ utils.set_start_time(recording.timestamp) @@ -762,8 +760,7 @@ def read_window_events( continue if not started: - with started_counter.get_lock(): - started_counter.value += 1 + started_event.set() started = True if window_data["title"] != prev_window_data.get("title") or window_data[ @@ -796,7 +793,7 @@ def performance_stats_writer( perf_q: sq.SynchronizedQueue, recording: Recording, terminate_processing: multiprocessing.Event, - started_counter: multiprocessing.Value, + started_event: multiprocessing.Event, ) -> None: """Write performance stats to the database. @@ -806,7 +803,7 @@ def performance_stats_writer( perf_q: A queue for collecting performance data. recording: The recording object. terminate_processing: An event to signal the termination of the process. - started_counter: Value to increment once started. + started_event: Event to set once started. """ utils.set_start_time(recording.timestamp) @@ -816,8 +813,7 @@ def performance_stats_writer( session = crud.get_new_session(read_and_write=True) while not terminate_processing.is_set() or not perf_q.empty(): if not started: - with started_counter.get_lock(): - started_counter.value += 1 + started_event.set() started = True try: event_type, start_time, end_time = perf_q.get_nowait() @@ -838,7 +834,7 @@ def memory_writer( recording: Recording, terminate_processing: multiprocessing.Event, record_pid: int, - started_counter: multiprocessing.Value, + started_event: multiprocessing.Event, ) -> None: """Writes memory usage statistics to the database. @@ -847,7 +843,7 @@ def memory_writer( terminate_processing (multiprocessing.Event): The event used to terminate the process. record_pid (int): The process ID to monitor memory usage for. - started_counter: Value to increment once started. + started_event: Event to set once started. Returns: None @@ -862,8 +858,7 @@ def memory_writer( session = crud.get_new_session(read_and_write=True) while not terminate_processing.is_set(): if not started: - with started_counter.get_lock(): - started_counter.value += 1 + started_event.set() started = True memory_usage_bytes = 0 @@ -928,7 +923,7 @@ def read_keyboard_events( event_q: queue.Queue, terminate_processing: multiprocessing.Event, recording: Recording, - started_counter: multiprocessing.Value, + started_event: threading.Event, ) -> None: """Reads keyboard events and adds them to the event queue. @@ -937,7 +932,7 @@ def read_keyboard_events( terminate_processing (multiprocessing.Event): The event to signal termination of event reading. recording (Recording): The recording object. - started_counter: Value to increment once started. + started_event: Event to set once started. Returns: None @@ -949,7 +944,7 @@ def read_keyboard_events( def on_press( event_q: queue.Queue, key: keyboard.Key | keyboard.KeyCode, - injected: bool, + injected: bool = False, ) -> None: """Event handler for key press events. @@ -1000,7 +995,7 @@ def on_press( def on_release( event_q: queue.Queue, key: keyboard.Key | keyboard.KeyCode, - injected: bool, + injected: bool = False, ) -> None: """Event handler for key release events. @@ -1028,8 +1023,7 @@ def on_release( # NOTE: listener may not have actually started by now # TODO: handle race condition, e.g. by sending synthetic events from main thread - with started_counter.get_lock(): - started_counter.value += 1 + started_event.set() terminate_processing.wait() keyboard_listener.stop() @@ -1039,7 +1033,7 @@ def read_mouse_events( event_q: queue.Queue, terminate_processing: multiprocessing.Event, recording: Recording, - started_counter: multiprocessing.Value, + started_event: threading.Event, ) -> None: """Reads mouse events and adds them to the event queue. @@ -1047,7 +1041,7 @@ def read_mouse_events( event_q: The event queue to add the mouse events to. terminate_processing: The event to signal termination of event reading. recording: The recording object. - started_counter: Value to increment once started. + started_event: Event to set once started. Returns: None @@ -1063,8 +1057,7 @@ def read_mouse_events( # NOTE: listener may not have actually started by now # TODO: handle race condition, e.g. by sending synthetic events from main thread - with started_counter.get_lock(): - started_counter.value += 1 + started_event.set() terminate_processing.wait() mouse_listener.stop() @@ -1073,14 +1066,14 @@ def read_mouse_events( def record_audio( recording: Recording, terminate_processing: multiprocessing.Event, - started_counter: multiprocessing.Value, + started_event: multiprocessing.Event, ) -> None: """Record audio narration during the recording and store data in database. Args: recording: The recording object. terminate_processing: An event to signal the termination of the process. - started_counter: Value to increment once started. + started_event: Event to set once started. """ utils.configure_logging(logger, LOG_LEVEL) utils.set_start_time(recording.timestamp) @@ -1110,8 +1103,7 @@ def audio_callback( # NOTE: listener may not have actually started by now # TODO: handle race condition, e.g. by sending synthetic events from main thread - with started_counter.get_lock(): - started_counter.value += 1 + started_event.set() terminate_processing.wait() audio_stream.stop() @@ -1222,7 +1214,7 @@ def run_browser_event_server( event_q: queue.Queue, terminate_processing: Event, recording: Recording, - started_counter: multiprocessing.Value, + started_event: threading.Event, ) -> None: """Run the browser event server. @@ -1230,7 +1222,7 @@ def run_browser_event_server( event_q: A queue for adding browser events. terminate_processing: An event to signal the termination of the process. recording: The recording object. - started_counter: Value to increment once started. + started_event: Event to set once started. Returns: None @@ -1253,8 +1245,7 @@ def run_server() -> None: ) as server: ws_server_instance = server logger.info("WebSocket server started") - with started_counter.get_lock(): - started_counter.value += 1 + started_event.set() server.serve_forever() # Start the server in a separate thread @@ -1327,12 +1318,17 @@ def record( perf_q = sq.SynchronizedQueue() if terminate_processing is None: terminate_processing = multiprocessing.Event() - started_counter = multiprocessing.Value("i", 0) task_by_name = {} + task_started_events = {} window_event_reader = threading.Thread( target=read_window_events, - args=(event_q, terminate_processing, recording, started_counter), + args=( + event_q, + terminate_processing, + recording, + task_started_events.setdefault("window_event_reader", threading.Event()), + ), ) window_event_reader.start() task_by_name["window_event_reader"] = window_event_reader @@ -1340,28 +1336,50 @@ def record( if config.RECORD_BROWSER_EVENTS: browser_event_reader = threading.Thread( target=run_browser_event_server, - args=(event_q, terminate_processing, recording, started_counter), + args=( + event_q, + terminate_processing, + recording, + task_started_events.setdefault( + "browser_event_reader", threading.Event() + ), + ), ) browser_event_reader.start() task_by_name["browser_event_reader"] = browser_event_reader screen_event_reader = threading.Thread( target=read_screen_events, - args=(event_q, terminate_processing, recording, started_counter), + args=( + event_q, + terminate_processing, + recording, + task_started_events.setdefault("screen_event_reader", threading.Event()), + ), ) screen_event_reader.start() task_by_name["screen_event_reader"] = screen_event_reader keyboard_event_reader = threading.Thread( target=read_keyboard_events, - args=(event_q, terminate_processing, recording, started_counter), + args=( + event_q, + terminate_processing, + recording, + task_started_events.setdefault("keyboard_event_reader", threading.Event()), + ), ) keyboard_event_reader.start() task_by_name["keyboard_event_reader"] = keyboard_event_reader mouse_event_reader = threading.Thread( target=read_mouse_events, - args=(event_q, terminate_processing, recording, started_counter), + args=( + event_q, + terminate_processing, + recording, + task_started_events.setdefault("mouse_event_reader", threading.Event()), + ), ) mouse_event_reader.start() task_by_name["mouse_event_reader"] = mouse_event_reader @@ -1384,7 +1402,7 @@ def record( perf_q, recording, terminate_processing, - started_counter, + task_started_events.setdefault("event_processor", threading.Event()), num_screen_events, num_action_events, num_window_events, @@ -1405,7 +1423,9 @@ def record( perf_q, recording, terminate_processing, - started_counter, + task_started_events.setdefault( + "screen_event_writer", multiprocessing.Event() + ), ), ) screen_event_writer.start() @@ -1422,7 +1442,9 @@ def record( perf_q, recording, terminate_processing, - started_counter, + task_started_events.setdefault( + "browser_event_writer", multiprocessing.Event() + ), ), ) browser_event_writer.start() @@ -1438,7 +1460,9 @@ def record( perf_q, recording, terminate_processing, - started_counter, + task_started_events.setdefault( + "action_event_writer", multiprocessing.Event() + ), ), ) action_event_writer.start() @@ -1454,7 +1478,9 @@ def record( perf_q, recording, terminate_processing, - started_counter, + task_started_events.setdefault( + "window_event_writer", multiprocessing.Event() + ), ), ) window_event_writer.start() @@ -1471,7 +1497,7 @@ def record( perf_q, recording, terminate_processing, - started_counter, + task_started_events.setdefault("video_writer", multiprocessing.Event()), video_pre_callback, video_post_callback, ), @@ -1485,7 +1511,9 @@ def record( args=( recording, terminate_processing, - started_counter, + task_started_events.setdefault( + "audio_event_writer", multiprocessing.Event() + ), ), ) audio_recorder.start() @@ -1498,7 +1526,9 @@ def record( perf_q, recording, terminate_perf_event, - started_counter, + task_started_events.setdefault( + "perf_stats_writer", multiprocessing.Event() + ), ), ) perf_stats_writer.start() @@ -1512,7 +1542,7 @@ def record( recording, terminate_perf_event, record_pid, - started_counter, + task_started_events.setdefault("mem_writer", multiprocessing.Event()), ), ) mem_writer.start() @@ -1530,9 +1560,16 @@ def record( expected_starts = len(task_by_name) logger.info(f"{expected_starts=}") while True: - if started_counter.value >= expected_starts: + started_tasks = sum(event.is_set() for event in task_started_events.values()) + if started_tasks >= expected_starts: break - time.sleep(0.1) # Sleep to reduce busy waiting + waiting_for = [ + task for task, event in task_started_events.items() if not event.is_set() + ] + logger.info(f"Waiting for tasks to start: {waiting_for}") + logger.info(f"Started tasks: {started_tasks}/{expected_starts}") + time.sleep(1) # Sleep to reduce busy waiting + for _ in range(5): logger.info("*" * 40) logger.info("All readers and writers have started. Waiting for input events...") diff --git a/openadapt/strategies/base.py b/openadapt/strategies/base.py index de114331b..98fd6dc71 100644 --- a/openadapt/strategies/base.py +++ b/openadapt/strategies/base.py @@ -4,7 +4,7 @@ from pprint import pformat import time -from oa_pynput import keyboard, mouse +from pynput import keyboard, mouse import numpy as np from openadapt import adapters, models, playback, utils diff --git a/openadapt/utils.py b/openadapt/utils.py index 9fbfd720c..a457dc2a5 100644 --- a/openadapt/utils.py +++ b/openadapt/utils.py @@ -12,6 +12,8 @@ import importlib.metadata import inspect import os +import re +import subprocess import sys import threading import time @@ -20,6 +22,7 @@ from jinja2 import Environment, FileSystemLoader from PIL import Image, ImageEnhance from posthog import Posthog +import multiprocessing_utils from openadapt.build_utils import is_running_from_executable, redirect_stdout_stderr from openadapt.custom_logger import logger @@ -52,12 +55,17 @@ # TODO: move to constants.py EMPTY = (None, [], {}, "") -SCT = mss.mss() +# TODO: move to config.py +DEFAULT_DOUBLE_CLICK_INTERVAL_SECONDS = 0.5 +DEFAULT_DOUBLE_CLICK_DISTANCE_PIXELS = 5 _logger_lock = threading.Lock() _start_time = None _start_perf_counter = None +# Process-local storage for MSS instances +_process_local = multiprocessing_utils.local() + def configure_logging(logger: logger, log_level: str) -> None: """Configure the logging settings for OpenAdapt. @@ -214,6 +222,22 @@ def override_double_click_interval_seconds( get_double_click_interval_seconds.override_value = override_value +def get_linux_setting(gnome_command: str, kde_command: str, default_value: int) -> int: + """Try to get a setting from GNOME or KDE, falling back to a default value.""" + try: + # Try GNOME first + output = subprocess.check_output(gnome_command, shell=True).decode().strip() + return int(output) + except (subprocess.CalledProcessError, ValueError): + try: + # If GNOME fails, try KDE + output = subprocess.check_output(kde_command, shell=True).decode().strip() + return int(output) + except (subprocess.CalledProcessError, ValueError): + # If both fail, return the default value + return default_value + + def get_double_click_interval_seconds() -> float: """Get the double click interval in seconds. @@ -231,8 +255,56 @@ def get_double_click_interval_seconds() -> float: from ctypes import windll return windll.user32.GetDoubleClickTime() / 1000 + elif sys.platform.startswith("linux"): + gnome_cmd = "gsettings get org.gnome.desktop.peripherals.mouse double-click" + kde_cmd = "kreadconfig5 --group KDE --key DoubleClickInterval" + value = get_linux_setting( + gnome_cmd, kde_cmd, DEFAULT_DOUBLE_CLICK_INTERVAL_SECONDS * 1000 + ) + return value / 1000 # Convert from milliseconds to seconds else: - raise Exception(f"Unsupported {sys.platform=}") + raise Exception(f"Unsupported platform: {sys.platform}") + + +def get_linux_device_id(device_name: str) -> int | None: + """Get the device ID for a device containing the given name. + + Args: + device_name (str): The name to search for in device listings. + + Returns: + Optional[int]: The device ID if found, None otherwise. + """ + try: + output = subprocess.check_output(["xinput", "list"], text=True) + match = re.search(f"\\b{re.escape(device_name)}\\b.*?id=(\\d+)", output) + if match: + return int(match.group(1)) + except (subprocess.CalledProcessError, FileNotFoundError): + pass + return None + + +def get_xinput_property(device_id: int, property_name: str) -> int | None: + """Get a specific property value from xinput for a given device. + + Args: + device_id (int): The ID of the device. + property_name (str): The name of the property to search for. + + Returns: + Optional[int]: The property value if found, None otherwise. + """ + try: + output = subprocess.check_output( + ["xinput", "list-props", str(device_id)], text=True + ) + match = re.search(rf"{property_name} \((\d+)\):\s+(\d+)", output) + if match: + return int(match.group(2)) + except (subprocess.CalledProcessError, FileNotFoundError): + pass + return None def get_double_click_distance_pixels() -> int: @@ -259,8 +331,22 @@ def get_double_click_distance_pixels() -> int: if x != y: logger.warning(f"{x=} != {y=}") return max(x, y) + elif sys.platform.startswith("linux"): + device_id = get_linux_device_id("Mouse") + if device_id is not None: + value = get_xinput_property(device_id, "libinput Scrolling Pixel Distance") + if value is not None: + return value + return DEFAULT_DOUBLE_CLICK_DISTANCE_PIXELS else: - raise Exception(f"Unsupported {sys.platform=}") + raise Exception(f"Unsupported platform: {sys.platform}") + + +def get_process_local_sct() -> mss.mss: + """Retrieve or create the `mss` instance for the current thread.""" + if not hasattr(_process_local, "sct"): + _process_local.sct = mss.mss() + return _process_local.sct def get_monitor_dims() -> tuple[int, int]: @@ -270,7 +356,7 @@ def get_monitor_dims() -> tuple[int, int]: tuple[int, int]: The width and height of the monitor. """ # TODO XXX: replace with get_screenshot().size and remove get_scale_ratios? - monitor = SCT.monitors[0] + monitor = get_process_local_sct().monitors[0] monitor_width = monitor["width"] monitor_height = monitor["height"] return monitor_width, monitor_height @@ -420,8 +506,9 @@ def take_screenshot() -> Image.Image: PIL.Image: The screenshot image. """ # monitor 0 is all in one - monitor = SCT.monitors[0] - sct_img = SCT.grab(monitor) + sct = get_process_local_sct() + monitor = sct.monitors[0] + sct_img = sct.grab(monitor) image = Image.frombytes("RGB", sct_img.size, sct_img.bgra, "raw", "BGRX") return image diff --git a/openadapt/window/__init__.py b/openadapt/window/__init__.py index fe0cb9e9f..f146d367a 100644 --- a/openadapt/window/__init__.py +++ b/openadapt/window/__init__.py @@ -11,9 +11,10 @@ if sys.platform == "darwin": from . import _macos as impl -elif sys.platform in ("win32", "linux"): - # TODO: implement Linux +elif sys.platform == "win32": from . import _windows as impl +elif sys.platform.startswith("linux"): + from . import _linux as impl else: raise Exception(f"Unsupported platform: {sys.platform}") @@ -59,6 +60,7 @@ def get_active_window_state(read_window_data: bool) -> dict | None: or None if the state is not available. """ # TODO: save window identifier (a window's title can change, or + # multiple windows can have the same title) try: return impl.get_active_window_state(read_window_data) except Exception as exc: diff --git a/openadapt/window/_linux.py b/openadapt/window/_linux.py new file mode 100644 index 000000000..dbf329dc1 --- /dev/null +++ b/openadapt/window/_linux.py @@ -0,0 +1,183 @@ +import xcffib +import xcffib.xproto +import pickle +import time + +from openadapt.custom_logger import logger + +# Global X server connection +_conn = None + + +def get_x_server_connection() -> xcffib.Connection: + """Get or create a global connection to the X server. + + Returns: + xcffib.Connection: A global connection object. + """ + global _conn + if _conn is None: + _conn = xcffib.connect() + return _conn + + +def get_active_window_meta() -> dict | None: + """Retrieve metadata of the active window using a persistent X server connection. + + Returns: + dict or None: A dictionary containing metadata of the active window. + """ + try: + conn = get_x_server_connection() + root = conn.get_setup().roots[0].root + + # Get the _NET_ACTIVE_WINDOW atom + atom = ( + conn.core.InternAtom(False, len("_NET_ACTIVE_WINDOW"), "_NET_ACTIVE_WINDOW") + .reply() + .atom + ) + + # Fetch the active window ID + active_window = conn.core.GetProperty( + False, root, atom, xcffib.xproto.Atom.WINDOW, 0, 1 + ).reply() + if not active_window.value_len: + return None + + # Convert the value to a proper bytes object + window_id_bytes = b"".join(active_window.value) # Concatenate bytes + window_id = int.from_bytes(window_id_bytes, byteorder="little") + + # Get window geometry + geom = conn.core.GetGeometry(window_id).reply() + + return { + "window_id": window_id, + "x": geom.x, + "y": geom.y, + "width": geom.width, + "height": geom.height, + "title": get_window_title(conn, window_id), + } + except Exception as exc: + logger.warning(f"Failed to retrieve active window metadata: {exc}") + return None + + +def get_window_title(conn: xcffib.Connection, window_id: int) -> str: + """Retrieve the title of a given window. + + Args: + conn (xcffib.Connection): X server connection. + window_id (int): The ID of the window. + + Returns: + str: The title of the window, or an empty string if unavailable. + """ + try: + # Attempt to fetch _NET_WM_NAME + atom_net_wm_name = ( + conn.core.InternAtom(False, len("_NET_WM_NAME"), "_NET_WM_NAME") + .reply() + .atom + ) + title_property = conn.core.GetProperty( + False, window_id, atom_net_wm_name, xcffib.xproto.Atom.STRING, 0, 1024 + ).reply() + if title_property.value_len > 0: + title_bytes = b"".join(title_property.value) # Convert using b"".join() + return title_bytes.decode("utf-8") + + # Fallback to WM_NAME + atom_wm_name = ( + conn.core.InternAtom(False, len("WM_NAME"), "WM_NAME").reply().atom + ) + title_property = conn.core.GetProperty( + False, window_id, atom_wm_name, xcffib.xproto.Atom.STRING, 0, 1024 + ).reply() + if title_property.value_len > 0: + title_bytes = b"".join(title_property.value) # Convert using b"".join() + return title_bytes.decode("utf-8") + except Exception as exc: + logger.warning(f"Failed to retrieve window title: {exc}") + return "" + + +def get_active_window_state(read_window_data: bool) -> dict | None: + """Get the state of the active window. + + Args: + read_window_data (bool): Whether to include detailed data about the window. + + Returns: + dict or None: A dictionary containing the state of the active window. + """ + meta = get_active_window_meta() + if not meta: + return None + + if read_window_data: + data = get_window_data(meta) + else: + data = {} + + state = { + "title": meta.get("title", ""), + "left": meta.get("x", 0), + "top": meta.get("y", 0), + "width": meta.get("width", 0), + "height": meta.get("height", 0), + "window_id": meta.get("window_id", 0), + "meta": meta, + "data": data, + } + try: + pickle.dumps(state, protocol=pickle.HIGHEST_PROTOCOL) + except Exception as exc: + logger.warning(f"{exc=}") + state.pop("data") + return state + + +def get_window_data(meta: dict) -> dict: + """Retrieve detailed data for the active window. + + Args: + meta (dict): Metadata of the active window. + + Returns: + dict: Detailed data of the window. + """ + # TODO: implement, e.g. with pyatspi + return {} + + +def get_active_element_state(x: int, y: int) -> dict | None: + """Get the state of the active element at the specified coordinates. + + Args: + x (int): The x-coordinate of the element. + y (int): The y-coordinate of the element. + + Returns: + dict or None: A dictionary containing the state of the active element. + """ + # Placeholder: Implement element-level state retrieval if necessary. + return {"x": x, "y": y, "state": "placeholder"} + + +def main() -> None: + """Test function for retrieving and inspecting the state of the active window.""" + time.sleep(1) + + state = get_active_window_state(read_window_data=True) + print(state) + pickle.dumps(state, protocol=pickle.HIGHEST_PROTOCOL) + import ipdb + + ipdb.set_trace() # noqa: E702 + + +if __name__ == "__main__": + main() diff --git a/poetry.lock b/poetry.lock index 9cf91a77a..7adb445b4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -3915,6 +3915,23 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} +[[package]] +name = "multiprocessing-utils" +version = "0.4" +description = "Multiprocessing utils (shared locks and thread locals)" +optional = false +python-versions = "*" +files = [ + {file = "multiprocessing_utils-0.4-py2.py3-none-any.whl", hash = "sha256:c232e0bbc6ba753ca7a0df5d49b0cc4e26454635d4f373f5133c551aec8c27ee"}, + {file = "multiprocessing_utils-0.4.tar.gz", hash = "sha256:43281d5e017d9b3f3e6114762c21f10c2a2b0392837c5096380e6c413ae79b6c"}, +] + +[package.dependencies] +six = "*" + +[package.extras] +test = ["tox"] + [[package]] name = "murmurhash" version = "1.0.10" @@ -4155,24 +4172,6 @@ dev = ["black", "pre-commit", "tox"] doc = ["m2r2", "sphinx"] test = ["pytest", "pytest-cov"] -[[package]] -name = "oa-pynput" -version = "1.7.7" -description = "Monitor and control user input devices" -optional = false -python-versions = "*" -files = [ - {file = "oa_pynput-1.7.7-py2.py3-none-any.whl", hash = "sha256:c1ee3d910d108fb216ddcac0ee01ab7b928f9a9307a47afda5b9a69fbb9da5a7"}, - {file = "oa_pynput-1.7.7.tar.gz", hash = "sha256:d20e2e93fee874dadc634b127e4a6278653fdf9165affe6eaf466e4e4807b9e8"}, -] - -[package.dependencies] -evdev = {version = ">=1.3", markers = "sys_platform in \"linux\""} -pyobjc-framework-ApplicationServices = {version = ">=8.0", markers = "sys_platform == \"darwin\""} -pyobjc-framework-Quartz = {version = ">=8.0", markers = "sys_platform == \"darwin\""} -python-xlib = {version = ">=0.17", markers = "sys_platform in \"linux\""} -six = "*" - [[package]] name = "onnxruntime" version = "1.20.0" @@ -5047,7 +5046,6 @@ description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs optional = false python-versions = ">=3.8" files = [ - {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, ] @@ -5058,7 +5056,6 @@ description = "A collection of ASN.1-based protocols modules" optional = false python-versions = ">=3.8" files = [ - {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, ] @@ -5572,6 +5569,23 @@ cffi = ">=1.4.1" docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"] tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"] +[[package]] +name = "pynput" +version = "1.7.7" +description = "Monitor and control user input devices" +optional = false +python-versions = "*" +files = [ + {file = "pynput-1.7.7-py2.py3-none-any.whl", hash = "sha256:afc43f651684c98818de048abc76adf9f2d3d797083cb07c1f82be764a2d44cb"}, +] + +[package.dependencies] +evdev = {version = ">=1.3", markers = "sys_platform in \"linux\""} +pyobjc-framework-ApplicationServices = {version = ">=8.0", markers = "sys_platform == \"darwin\""} +pyobjc-framework-Quartz = {version = ">=8.0", markers = "sys_platform == \"darwin\""} +python-xlib = {version = ">=0.17", markers = "sys_platform in \"linux\""} +six = "*" + [[package]] name = "pyobjc-core" version = "10.3.1" @@ -8971,6 +8985,19 @@ files = [ {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, ] +[[package]] +name = "xcffib" +version = "1.5.0" +description = "A drop in replacement for xpyb, an XCB python binding" +optional = false +python-versions = "*" +files = [ + {file = "xcffib-1.5.0.tar.gz", hash = "sha256:a95c9465f2f97b4fcede606bd1e08407a32df71cb760fd57bfe53677db691acc"}, +] + +[package.dependencies] +cffi = ">=1.1.0" + [[package]] name = "yarl" version = "1.17.1" @@ -9124,4 +9151,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "3.10.x" -content-hash = "fbb9a8c0ac03708a131f06d1d3f7086d7718dacbf03d199b70e2df76e23640dd" +content-hash = "5ee4ca5a50f3fc9e3e2b43028b9b6dcb8318e58bf0e4cf01a4df32d9d5c425c2" diff --git a/pyproject.toml b/pyproject.toml index 009fa7b44..bd6410f9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ alembic = "1.8.1" black = "^24.8.0" pygetwindow = { version = "<0.0.5", markers = "sys_platform == 'win32'" } pywin32 = { version = "306", markers = "sys_platform == 'win32'" } +xcffib = { version = "1.5.0", markers = "sys_platform == 'linux'" } ascii-magic = "2.3.0" bokeh = "2.4.3" clipboard = "0.0.4" @@ -75,7 +76,6 @@ pyobjc-framework-avfoundation = { version = "^9.2", markers = "sys_platform == ' fastapi = "^0.111.1" screen-recorder-sdk = { version = "^1.3.0", markers = "sys_platform == 'win32'" } pyaudio = { version = "^0.2.13", markers = "sys_platform == 'win32'" } -oa-pynput = "^1.7.7" oa-atomacos = { version = "3.2.0", markers = "sys_platform == 'darwin'" } presidio-image-redactor = "^0.0.48" pywebview = "^4.2.2" @@ -109,6 +109,8 @@ tokencost = "^0.1.12" numba = "^0.60.0" llvmlite = "^0.43.0" ell-ai = "^0.0.14" +pynput = "^1.7.7" +multiprocessing-utils = "^0.4" [tool.pytest.ini_options] filterwarnings = [ # suppress warnings starting from "setuptools>=67.3"