diff --git a/.gitignore b/.gitignore
index 90ff2ae34..30c3caabc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,7 @@ html/
venv/
venv*/
__pycache__/
+.pytest_cache/
.cache/
# BciPy files and directories
@@ -39,3 +40,4 @@ bcipy/simulator/tests/resource/
data/
bids/
!bcipy/simulator/data
+
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 48ca96c20..3d2f47480 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ The next major BciPy release is here! All features included from release canidat
- Refactor `helpers` into `io` and `core` #362
- Refactor to use `pyproject.toml` for installs #367
- Language Model Refactor to use lm-toolkit #381 #390
+- Gaze Model integration #384
- Simulator
- Multimodal support #385
- Replay feature #376
@@ -25,6 +26,10 @@ The next major BciPy release is here! All features included from release canidat
- `EDFlib-Python` #362
- Remove
- `pyedflib` #362
+ - Drop support for python 3.8 #391
+- General documentation improvements
+ - README updates #391
+ - Drop Twitter links #391
# 2.0.1-rc.4
diff --git a/LICENSE.md b/LICENSE.md
index 907d40b62..341ee96fe 100644
--- a/LICENSE.md
+++ b/LICENSE.md
@@ -1,33 +1,11 @@
-BciPy Copyright 2021 (CAMBI)(“Licensor”)
+Copyright 2025 (CAMBI)("Licensor")
-Hippocratic License Version Number: 2.1.
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
-Purpose. The purpose of this License is for the Licensor named above to permit the Licensee (as defined below) broad permission, if consistent with Human Rights Laws and Human Rights Principles (as each is defined below), to use and work with the Software (as defined below) within the full scope of Licensor’s copyright and patent rights, if any, in the Software, while ensuring attribution and protecting the Licensor from liability.
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
-Permission and Conditions. The Licensor grants permission by this license (“License”), free of charge, to the extent of Licensor’s rights under applicable copyright and patent law, to any person or entity (the “Licensee”) obtaining a copy of this software and associated documentation files (the “Software”), to do everything with the Software that would otherwise infringe (i) the Licensor’s copyright in the Software or (ii) any patent claims to the Software that the Licensor can license or becomes able to license, subject to all of the following terms and conditions:
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
-* Acceptance. This License is automatically offered to every person and entity subject to its terms and conditions. Licensee accepts this License and agrees to its terms and conditions by taking any action with the Software that, absent this License, would infringe any intellectual property right held by Licensor.
+3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
-* Notice. Licensee must ensure that everyone who gets a copy of any part of this Software from Licensee, with or without changes, also receives the License and the above copyright notice (and if included by the Licensor, patent, trademark and attribution notice). Licensee must cause any modified versions of the Software to carry prominent notices stating that Licensee changed the Software. For clarity, although Licensee is free to create modifications of the Software and distribute only the modified portion created by Licensee with additional or different terms, the portion of the Software not modified must be distributed pursuant to this License. If anyone notifies Licensee in writing that Licensee has not complied with this Notice section, Licensee can keep this License by taking all practical steps to comply within 30 days after the notice. If Licensee does not do so, Licensee’s License (and all rights licensed hereunder) shall end immediately.
-
-* Compliance with Human Rights Principles and Human Rights Laws.
-
- 1. Human Rights Principles.
-
- (a) Licensee is advised to consult the articles of the United Nations Universal Declaration of Human Rights and the United Nations Global Compact that define recognized principles of international human rights (the “Human Rights Principles”). Licensee shall use the Software in a manner consistent with Human Rights Principles.
-
- (b) Unless the Licensor and Licensee agree otherwise, any dispute, controversy, or claim arising out of or relating to (i) Section 1(a) regarding Human Rights Principles, including the breach of Section 1(a), termination of this License for breach of the Human Rights Principles, or invalidity of Section 1(a) or (ii) a determination of whether any Law is consistent or in conflict with Human Rights Principles pursuant to Section 2, below, shall be settled by arbitration in accordance with the Hague Rules on Business and Human Rights Arbitration (the “Rules”); provided, however, that Licensee may elect not to participate in such arbitration, in which event this License (and all rights licensed hereunder) shall end immediately. The number of arbitrators shall be one unless the Rules require otherwise.
-
- Unless both the Licensor and Licensee agree to the contrary: (1) All documents and information concerning the arbitration shall be public and may be disclosed by any party; (2) The repository referred to under Article 43 of the Rules shall make available to the public in a timely manner all documents concerning the arbitration which are communicated to it, including all submissions of the parties, all evidence admitted into the record of the proceedings, all transcripts or other recordings of hearings and all orders, decisions and awards of the arbitral tribunal, subject only to the arbitral tribunal's powers to take such measures as may be necessary to safeguard the integrity of the arbitral process pursuant to Articles 18, 33, 41 and 42 of the Rules; and (3) Article 26(6) of the Rules shall not apply.
-
- 2. Human Rights Laws. The Software shall not be used by any person or entity for any systems, activities, or other uses that violate any Human Rights Laws. “Human Rights Laws” means any applicable laws, regulations, or rules (collectively, “Laws”) that protect human, civil, labor, privacy, political, environmental, security, economic, due process, or similar rights; provided, however, that such Laws are consistent and not in conflict with Human Rights Principles (a dispute over the consistency or a conflict between Laws and Human Rights Principles shall be determined by arbitration as stated above). Where the Human Rights Laws of more than one jurisdiction are applicable or in conflict with respect to the use of the Software, the Human Rights Laws that are most protective of the individuals or groups harmed shall apply.
-
- 3. Indemnity. Licensee shall hold harmless and indemnify Licensor (and any other contributor) against all losses, damages, liabilities, deficiencies, claims, actions, judgments, settlements, interest, awards, penalties, fines, costs, or expenses of whatever kind, including Licensor’s reasonable attorneys’ fees, arising out of or relating to Licensee’s use of the Software in violation of Human Rights Laws or Human Rights Principles.
-
-* Failure to Comply. Any failure of Licensee to act according to the terms and conditions of this License is both a breach of the License and an infringement of the intellectual property rights of the Licensor (subject to exceptions under Laws, e.g., fair use). In the event of a breach or infringement, the terms and conditions of this License may be enforced by Licensor under the Laws of any jurisdiction to which Licensee is subject. Licensee also agrees that the Licensor may enforce the terms and conditions of this License against Licensee through specific performance (or similar remedy under Laws) to the extent permitted by Laws. For clarity, except in the event of a breach of this License, infringement, or as otherwise stated in this License, Licensor may not terminate this License with Licensee.
-
-* Enforceability and Interpretation. If any term or provision of this License is determined to be invalid, illegal, or unenforceable by a court of competent jurisdiction, then such invalidity, illegality, or unenforceability shall not affect any other term or provision of this License or invalidate or render unenforceable such term or provision in any other jurisdiction; provided, however, subject to a court modification pursuant to the immediately following sentence, if any term or provision of this License pertaining to Human Rights Laws or Human Rights Principles is deemed invalid, illegal, or unenforceable against Licensee by a court of competent jurisdiction, all rights in the Software granted to Licensee shall be deemed null and void as between Licensor and Licensee. Upon a determination that any term or provision is invalid, illegal, or unenforceable, to the extent permitted by Laws, the court may modify this License to affect the original purpose that the Software be used in compliance with Human Rights Principles and Human Rights Laws as closely as possible. The language in this License shall be interpreted as to its fair meaning and not strictly for or against any party.
-
-* Disclaimer. TO THE FULL EXTENT ALLOWED BY LAW, THIS SOFTWARE COMES “AS IS,” WITHOUT ANY WARRANTY, EXPRESS OR IMPLIED, AND LICENSOR AND ANY OTHER CONTRIBUTOR SHALL NOT BE LIABLE TO ANYONE FOR ANY DAMAGES OR OTHER LIABILITY ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THIS LICENSE, UNDER ANY KIND OF LEGAL CLAIM.
-
-This Hippocratic License is an Ethical Source license (https://ethicalsource.dev) and is offered for use by licensors and licensees at their own risk, on an “AS IS” basis, and with no warranties express or implied, to the maximum extent permitted by Laws.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
index 4b908fd6f..f100d5b45 100644
--- a/README.md
+++ b/README.md
@@ -3,21 +3,93 @@
[](https://github.com/CAMBI-tech/BciPy/actions/workflows/main.yml)
[](https://app.codacy.com/gh/CAMBI-tech/BciPy/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
[](https://github.com/CAMBI-tech/BciPy/fork)
-[](https://twitter.com/cambi_tech)
+[](https://opensource.org/licenses/BSD-3-Clause)
-BciPy is a library for conducting Brain-Computer Interface experiments in Python. It functions as a standalone application for experimental data collection or you can take the tools you need and start coding your own system. See our official BciPy documentation including affiliations and more context information [here](https://bcipy.github.io/).
+[](https://cambi.tech)
-It will run on the latest windows (10, 11), linux (ubuntu 22.04) and macos (Sonoma). Other versions may work as well, but are not guaranteed. To see supported versions and operating systems as of this release see here: [BciPy Builds](https://github.com/CAMBI-tech/BciPy/actions/workflows/main.yml).
+BciPy is a library for conducting Brain-Computer Interface experiments in Python. It is designed to be modular and extensible, allowing researchers to easily add new paradigms, models, and processing methods. The focus of BciPy is on paradigms for communication and control, including Rapid Serial Visual Presentation (RSVP) and Matrix Speller. See our official documentation including affiliations and more context information [here](https://bcipy.github.io/).
-*Please cite us when using!*
+BciPy is released open-source under the BSD-3 clause. Please refer to [LICENSE.md](LICENSE.md).
+
+**If you use BciPy in your research, please cite the following manuscript:**
```text
Memmott, T., Koçanaoğulları, A., Lawhead, M., Klee, D., Dudy, S., Fried-Oken, M., & Oken, B. (2021). BciPy: brain–computer interface software in Python. Brain-Computer Interfaces, 1-18.
```
+## Table of Contents
+
+- [BciPy: Brain-Computer Interface Software in Python](#bcipy-brain-computer-interface-software-in-python)
+ - [Table of Contents](#table-of-contents)
+ - [Dependencies](#dependencies)
+ - [Linux](#linux)
+ - [Windows](#windows)
+ - [Mac](#mac)
+ - [Installation](#installation)
+ - [BciPy Setup](#bcipy-setup)
+ - [Editable Install and GUI usage](#editable-install-and-gui-usage)
+ - [PyPi Install](#pypi-install)
+ - [Make install](#make-install)
+ - [Usage](#usage)
+ - [Package Usage](#package-usage)
+ - [GUI Usage](#gui-usage)
+ - [Client Usage](#client-usage)
+ - [General Usage](#general-usage)
+ - [Running Experiments or Tasks via Command Line](#running-experiments-or-tasks-via-command-line)
+ - [Options](#options)
+ - [Train a Signal Model via Command Line](#train-a-signal-model-via-command-line)
+ - [Basic Commands Signal Model Training](#basic-commands-signal-model-training)
+ - [Visualize ERP data from a session with Target / Non-Target labels via Command Line](#visualize-erp-data-from-a-session-with-target--non-target-labels-via-command-line)
+ - [Basic Commands ERP Viz](#basic-commands-erp-viz)
+ - [BciPy Simulator](#bcipy-simulator)
+ - [Running the Simulator](#running-the-simulator)
+ - [Basic Commands Simulator](#basic-commands-simulator)
+ - [Other Options](#other-options)
+ - [Core Modules](#core-modules)
+ - [Top-Level Modules Overview](#top-level-modules-overview)
+ - [`Acquisition`](#acquisition)
+ - [`Core`](#core)
+ - [`Display`](#display)
+ - [`Feedback`](#feedback)
+ - [`GUI`](#gui)
+ - [`Helpers`](#helpers)
+ - [`IO`](#io)
+ - [`Language`](#language)
+ - [`Signal`](#signal)
+ - [`Simulator`](#simulator)
+ - [`Task`](#task)
+ - [Entry Point and Configuration Modules](#entry-point-and-configuration-modules)
+ - [`main.py`](#mainpy)
+ - [`parameters/`](#parameters)
+ - [`config.py`](#configpy)
+ - [`static/`](#static)
+ - [Paradigms](#paradigms)
+ - [RSVPKeyboard](#rsvpkeyboard)
+ - [Matrix Speller](#matrix-speller)
+ - [Offset Determination and Correction](#offset-determination-and-correction)
+ - [What is a Static Offset?](#what-is-a-static-offset)
+ - [How to Determine the Offset](#how-to-determine-the-offset)
+ - [Running Offset Determination](#running-offset-determination)
+ - [Applying the Offset Correction](#applying-the-offset-correction)
+ - [Using Make for Offset Determination](#using-make-for-offset-determination)
+ - [Additional Resources](#additional-resources)
+ - [Glossary](#glossary)
+ - [Scientific Publications using BciPy](#scientific-publications-using-bcipy)
+ - [2025](#2025)
+ - [2024](#2024)
+ - [2023](#2023)
+ - [2022](#2022)
+ - [2021](#2021)
+ - [2020](#2020)
+ - [Contributions Welcome](#contributions-welcome)
+ - [Contribution Guidelines](#contribution-guidelines)
+ - [Contributors](#contributors)
+
## Dependencies
-This project requires Python 3.8 or 3.9. Please see notes below for additional OS specific dependencies before installation can be completed and reference our documentation/FAQs for more information:
+This project requires Python 3.9, 3.10 or 3.11.
+
+It will run on the latest windows (10, 11), linux (ubuntu 22.04) and macos (Sonoma). Other versions may work as well, but are not guaranteed. To see supported versions and operating systems as of this release see our GitHub builds: [BciPy Builds](https://github.com/CAMBI-tech/BciPy/actions/workflows/main.yml). Please see notes below for additional OS specific dependencies before installation can be completed and reference our documentation here:
### Linux
@@ -27,183 +99,285 @@ You will need to install the prerequisites defined in `scripts\shell\linux_requi
If you are using a Windows machine, you will need to install the [Microsoft Visual C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/).
-
### Mac
-If you are using a Mac, you will need to install XCode and enable command line tools. `xcode-select --install`. If using an m1/2 chip, you will need to use the install script in `scripts/shell/m2chip_install.sh` to install the prerequisites. You may also need to use the Rosetta terminal to run the install script, but this has not been necessary in our testing using m2 chips.
+If you are using a Mac, you will need to install XCode and enable command line tools. `xcode-select --install`. If using an m1/2 chip, you may need to use the install script in `scripts/shell/m2chip_install.sh` to install the prerequisites. You may also need to use the Rosetta terminal to run the install script, but this has not been necessary in our testing using m2 chips.
-If using zsh, instead of bash, you may encounter a segementation fault when running BciPy. This is due to an issue in a dependeancy of psychopy with no known fix as of yet. Please use bash instead of zsh for now.
+If using zsh, instead of bash, you may encounter a segmentation fault when running BciPy. This is due to an issue in a dependency of psychopy with no known fix as of yet. Please use bash instead of zsh for now.
## Installation
### BciPy Setup
-In order to run BciPy on your computer, after following the dependencies above, you will need to install the BciPy package.
+In order to run BciPy on your computer, after ensuring the OS dependencies above are met, you can proceed to install the BciPy package.
+
+#### Editable Install and GUI usage
-To install for use locally and use of the GUI:
+If wanting to run the GUI or make changes to the code, you will need to install BciPy in editable mode. This will ensure that all dependencies are installed and the package is linked to your local directory. This will allow you to make changes to the code and see them reflected in your local installation without needing to reinstall the package.
1. Git clone .
2. Change directory in your terminal to the repo directory.
-3. [Optional] Install the kenlm language model package. `pip install kenlm==0.1 --global-option="--max_order=12"`.
-4. Install BciPy in development mode. `pip install -e .`
+3. Install BciPy in development mode.
-If wanting the latest version from PyPi and to build using modules:
+ ```sh
+ pip install -e .
+ ```
+
+#### PyPi Install
+
+If you do not want to run the GUI or make changes to the code, you can install BciPy from PyPi. This will install the package and all dependencies, but will not link it to your local directory. This means that any changes you make to the code will not be reflected in your local installation. This is the recommended installation method if wanting to use the modules without making changes to the BciPy code.
+
+```sh
+pip install bcipy
+```
-1. `pip install bcipy`
+#### Make install
Alternately, if [Make](http://www.mingw.org/) is installed, you may run the follow command to install:
```sh
-# install in development mode
+# install in development mode with all testing and demo dependencies
make dev-install
```
+## Usage
+
+The BciPy package may be used in two ways: via the command line interface (CLI) or via the graphical user interface (GUI). The CLI is useful for running experiments, training models, and visualizing data without needing to run the GUI. The GUI is useful for running experiments, editing parameters and training models with a more user-friendly interface.
+
+### Package Usage
+
+To run the package, you will need to import the modules you want to use. For example, to run the system info module, you can run the following:
+
+```python
+from bcipy.helpers import system_utils
+system_utils.get_system_info()
+```
+
+### GUI Usage
+
+Run the following command in your terminal to start the BciPy GUI:
+
+```sh
+python bcipy/gui/BCInterface.py
+```
+
+Alternately, if Make is installed, you may run the following command to start the GUI from the BciPy root directory:
+
+```sh
+make bci-gui
+```
+
### Client Usage
-#### Run an experiment protocol or task
+Once BciPy is installed, it can be used via the command line interface. This is useful for running experiments, training models, and visualizing data without needing to run the GUI.
-Invoke an experiment protocol or task directly using command line utility `bcipy`.
+#### General Usage
-Use the help flag to see other available input options: `bcipy --help`
+Use the help flag to explore all available options:
-You can pass it attributes with flags, if desired.
+```sh
+bcipy --help
+```
-- Run with a User ID and Task:
- - `bcipy --user "bci_user" --task "RSVP Calibration"`
-- Run with a User ID and Tasks with a registered Protocol:
- - `bcipy --user "bci_user" --experiment "default"`
-- Run with fake data:
- - `bcipy --fake`
-- Run without visualizations:
- - `bcipy --noviz`
-- Run with alerts after each Task execution:
- - `bcipy --alert`
-- Run with custom parameters:
- - `bcipy --parameters "path/to/valid/parameters.json"`
-
-#### Train a signal model with registered BciPy models
+#### Running Experiments or Tasks via Command Line
-To train a signal model (currently `PCARDAKDE` and `GazeModels`), run the following command after installing BciPy:
+You can invoke an experiment protocol or task directly using the `bcipy` command-line utility. This allows for flexible execution of tasks with various configurations.
-`bcipy-train`
+##### Options
-Use the help flag to see other available input options: `bcipy-train --help`
-
-You can pass it attributes with flags, if desired.
-
-- Run without a window prompting for data session folder:
- - `bcipy-train -d path/to/data`
-- Run with data visualizations (ERPs, etc.):
- - `bcipy-train -v`
-- Run with data visualizations that do not show, but save to file:
- - `bcipy-train -s`
-- Run with balanced accuracy:
- - `bcipy-train --balanced-acc`
-- Run with alerts after each Task execution:
- - `bcipy-train --alert`
-- Run with custom parameters:
- - `bcipy-train -p "path/to/valid/parameters.json"`
+```sh
+# Run with a User ID and Task
+bcipy --user "bci_user" --task "RSVP Calibration"
+
+# Run with a User ID and Experiment Protocol
+bcipy --user "bci_user" --experiment "default"
+
+# Run with Simulated Data
+bcipy --fake
+
+# Run without Visualizations
+bcipy --noviz
+
+# Run with Alerts after Task Execution
+bcipy --alert
+
+# Run with Custom Parameters
+bcipy --parameters "path/to/valid/parameters.json"
+```
+
+These options provide flexibility for running experiments tailored to your specific needs.
-#### Visualize ERP data from a session with Target / Non-Target labels
+#### Train a Signal Model via Command Line
-To generate plots that can be shown or saved after collection of data, run the following command after installing BciPy:
+To train a signal model (e.g., `PCARDAKDE` or `GazeModels`), use the `bcipy-train` command.
-`bcipy-erp-viz`
+##### Basic Commands Signal Model Training
-Use the help flag to see other available input options: `bcipy-erp-viz --help`
+```sh
+# Display help information
+bcipy-train --help
-You can pass it attributes with flags, if desired.
+# Train using data from a specific folder
+bcipy-train -d path/to/data
-- Run without a window prompt for a data session folder:
- - `bcipy-erp-viz -s path/to/data`
-- Run with data visualizations (ERPs, etc.):
- - `bcipy-erp-viz --show`
-- Run with data visualizations that do not show, but save to file:
- - `bcipy-erp-viz --save`
-- Run with custom parameters (default is in bcipy/parameters/parameters.json):
- - `bcipy-erp-viz -p "path/to/valid/parameters.json"`
+# Display data visualizations (e.g., ERPs)
+bcipy-train -v
-#### BciPy Simulator Usage
+# Save visualizations to a file without displaying them
+bcipy-train -s
-The simulator can be run using the command line utility `bcipy-sim`.
+# Train with balanced accuracy metrics
+bcipy-train --balanced-acc
-To run the simulator with a single data folder, a custom parameters file, a trained model, and 5 iterations, use the following command:
+# Receive alerts after each task execution
+bcipy-train --alert
-`bcipy-sim -d my_data_folder/ -p my_parameters.json -m my_models/ -n 5`
+# Use a custom parameters file
+bcipy-train -p path/to/parameters.json
+```
+
+#### Visualize ERP data from a session with Target / Non-Target labels via Command Line
-For more information or to see other available input options, use the help flag: `bcipy-sim --help`. In addition, more information can be found in the simulator module README.
+To visualize ERP data from a session with Target / Non-Target labels, use the `bcipy-erp-viz` command. This command allows you to visualize the data collected during a session and provides options for saving or displaying the visualizations.
-### Package Usage
+##### Basic Commands ERP Viz
-```python
-from bcipy.helpers import system_utils
-system_utils.get_system_info()
+```sh
+# Display help information
+bcipy-erp-viz --help
+
+# Run without a window prompt for a data session folder
+bcipy-erp-viz -s path/to/data
+
+# Run with data visualizations (ERPs, etc.)
+bcipy-erp-viz --show
+
+# Run with data visualizations that do not show, but save to file
+bcipy-erp-viz --save
+
+# Run with custom parameters (default is in bcipy/parameters/parameters.json)
+bcipy-erp-viz -p "path/to/valid/parameters.json"
```
-### GUI Usage
+### BciPy Simulator
-Run the following command in your terminal to start the BciPy GUI:
+The BciPy simulator allows you to run simulations based on previously collected data. This is useful for testing and validating models and algorithms without needing to collect new data.
+
+#### Running the Simulator
+
+The simulator can be executed using the `bcipy-sim` command-line utility.
+
+##### Basic Commands Simulator
```sh
-python bcipy/gui/BCInterface.py
+bcipy-sim --help
```
-Alternately, if Make is installed, you may run the follow command to start the GUI from the BciPy root directory:
+##### Other Options
+
+- `-d`: Path to the data folder.
+- `-p`: Path to the custom parameters file. [optional]
+- `-m`: Path to the directory of trained model pickle files.
+- `-n`: Number of iterations to run.
```sh
-make bci-gui
+bcipy-sim -d path/to/data -p path/to/parameters.json -m path/to/models/ -n 5
```
-## Glossary
+More comprehensive information can be found in the [Simulator Module README](./bcipy/simulator/README.md).
-***Stimuli***: A single letter, tone or image shown (generally in an inquiry). Singular = stimulus, plural = stimuli.
+## Core Modules
-***Trial***: A collection of data after a stimuli is shown. A----
+### Top-Level Modules Overview
-***Inquiry***: The set of stimuli after a fixation cross in a spelling task to gather user intent. A ---- B --- C ----
+Each module includes its own README, demo, and tests. Click on the module name to view its README for more information.
-***Series***: Each series contains at least one inquiry. A letter/icon decision is made after a series in a spelling task.
+#### [`Acquisition`](./bcipy/acquisition/README.md)
-***Session***: Data collected for a task. Comprised of metadata about the task and a list of Series.
+Captures data, returns desired time series, and saves to file at the end of a session.
-***Protocol***: A collection of tasks and actions to be executed in a session. This is defined as within experiments and can be registered using the BciPy GUI.
+#### [`Core`](./bcipy/core/README.md)
-***Task***: An experimental design with stimuli, trials, inquiries and series for use in BCI. For instance, "RSVP Calibration" is a task.
+Core data structures and methods essential for BciPy operation.
-***Mode***: Common design elements between task types. For instance, Calibration and Free Spelling are modes.
+- Includes triggers, parameters, and raw data handling.
-***Paradigm***: Display paradigm with unique properties and modes. Ex. Rapid-Serial Visual Presentation (RSVP), Matrix Speller, Steady-State Visual Evoked Potential (SSVEP).
+#### [`Display`](./bcipy/display/README.md)
-## Core Modules
+Manages the display of stimuli on the screen and records stimuli timing.
+
+#### [`Feedback`](./bcipy/feedback/README.md)
+
+Provides feedback mechanisms for sound and visual stimuli.
+
+#### [`GUI`](./bcipy/gui/README.md)
+
+End-user interface for registered BCI tasks and parameter editing.
+
+- Key files: [BCInterface.py](./bcipy/gui/BCInterface.py) and [ParamsForm](./bcipy/gui/parameters/params_form.py).
+
+#### [`Helpers`](./bcipy/helpers/README.md)
+
+Utility functions for interactions between modules and general-purpose tasks.
+
+#### [`IO`](./bcipy/io/README.md)
+
+Handles data file operations such as loading, saving, and format conversion.
+
+- Supported formats: BIDS, BrainVision, EDF, MNE, CSV, JSON, etc.
+
+#### [`Language`](./bcipy/language/README.md)
+
+Provides symbol probability predictions during typing tasks.
+
+#### [`Signal`](./bcipy/signal/README.md)
+
+Includes EEG signal models, gaze signal models, filters, processing tools, evaluators, and viewers.
+
+#### [`Simulator`](./bcipy/simulator/README.md)
-This a list of the major modules and their functionality. Each module will contain its own README, demo and tests. Please check them out for more information!
-
-- `acquisition`: acquires data, gives back desired time series, saves to file at end of session.
-- `config`: configuration parameters for the application, including paths and data filenames.
-- `core`: core data structures and methods needed for BciPy operation. These include triggers, parameters, and raw data.
-- `display`: handles display of stimuli on screen and passes back stimuli timing.
-- `feedback`: feedback mechanisms for sound and visual stimuli.
-- `gui`: end-user interface into registered bci tasks and parameter editing. See BCInterface.py.
-- `helpers`: helpful functions needed for interactions between modules and general utility.
-- `io`: load, save, and convert data files. Ex. BIDS, BrainVision, EDF, MNE, CSV, JSON, etc.
-- `language`: gives probabilities of next symbols during typing.
-- `main`: executor of experiments. Main entry point into the application
-- `parameters`: location of json parameters. This includes parameters.json (main experiment / app configuration) and device.json (device registry and configuration).
-- `signal`: eeg signal models, gaze signal models, filters, processing, evaluators and viewers.
-- `simulator`: provides support for running simulations based off of previously collected data.
-- `static`: image and sound stimuli, misc manuals, and readable texts for gui.
-- `task`: bcipy implemented user tasks. Main collection of bci modules for use during various experimentation. Ex. RSVP Calibration.
+Supports running simulations based on previously collected data.
+
+#### [`Task`](./bcipy/task/README.md)
+
+Implements user tasks and actions for BCI experiments.
+
+- Examples: RSVP Calibration, InterTaskAction.
+
+### Entry Point and Configuration Modules
+
+#### [`main.py`](./bcipy/main.py)
+
+The main executor of experiments and the primary entry point into the application. See the [Running Experiments](#running-experiments-or-tasks-via-command-line) section for more information.
+
+#### [`parameters/`](./bcipy/parameters/)
+
+Contains JSON configuration files:
+
+- [`parameters.json`](./bcipy/parameters/parameters.json): Main experiment and application configuration.
+- [`device.json`](./bcipy/parameters/device.json): Device registry and configuration.
+- [`experiments.json`](./bcipy/parameters/experiment/experiments.json): Experiment / protocol registry and configuration.
+- [`phrases.json`](./bcipy/parameters/experiment/phrases.json): Phrase registry and configuration. This can be used to define a list of phrases used in the RSVP and Matrix Speller Copy phrase tasks. If not defined in parameters.json, the `task_text` parameter will be used.
+
+#### [`config.py`](./bcipy/config.py)
+
+Holds configuration parameters for BciPy, including paths and default data filenames.
+
+#### [`static/`](./bcipy/static/)
+
+Includes resources such as:
+
+- Image and sound stimuli.
+- Miscellaneous manuals and readable texts for the GUI.
## Paradigms
-See `bcipy/task/README.md` for more information on all supported paradigms, tasks, actions and modes. The following are the supported and validated paradigms:
+See the [Task README](./bcipy/task/README.md) for more information on all supported paradigms, tasks, actions and modes. The major paradigms are listed below.
### RSVPKeyboard
*RSVP KeyboardTM* is an EEG (electroencephalography) based BCI (brain computer interface) typing system. It utilizes a visual presentation technique called rapid serial visual presentation (RSVP). In RSVP, the options are presented rapidly at a single location with a temporal separation. Similarly in RSVP KeyboardTM, the symbols (the letters and additional symbols) are shown at the center of screen. When the subject wants to select a symbol, they await the intended symbol during the presentation and elicit a p300 response to a target symbol.
-References:
-
```text
Orhan, U., Hild, K. E., 2nd, Erdogmus, D., Roark, B., Oken, B., & Fried-Oken, M. (2012). RSVP Keyboard: An EEG Based Typing Interface. Proceedings of the ... IEEE International Conference on Acoustics, Speech, and Signal Processing. ICASSP (Conference), 10.1109/ICASSP.2012.6287966. https://doi.org/10.1109/ICASSP.2012.6287966
```
@@ -212,152 +386,121 @@ Orhan, U., Hild, K. E., 2nd, Erdogmus, D., Roark, B., Oken, B., & Fried-Oken, M.
Matrix Speller is an EEG (electroencephalography) based BCI (brain computer interface) typing system. It utilizes a visual presentation technique called Single Character Presentation (SCP). In matrix speller, the symbols are arranged in a matrix with fixed number of rows and columns. Using SCP, subsets of these symbols are intensified (i.e. highlighted) usually in pseudorandom order to produce an odd ball paradigm to induce p300 responses.
-References:
-
```text
Farwell, L. A., & Donchin, E. (1988). Talking off the top of your head: toward a mental prosthesis utilizing event-related brain potentials. Electroencephalography and clinical Neurophysiology, 70(6), 510-523.
Ahani A, Moghadamfalahi M, Erdogmus D. Language-Model Assisted And Icon-based Communication Through a Brain Computer Interface With Different Presentation Paradigms. IEEE Trans Neural Syst Rehabil Eng. 2018 Jul 25. doi: 10.1109/TNSRE.2018.2859432.
```
-## Demo
-
-All major functions and modules have demo and test files associated with them which may be run locally. This should help orient you to the functionality as well as serve as documentation. *If you add to the repo, you should be adding tests and fixing any test that fail when you change the code.*
-
-For example, you may run the main BciPy demo by:
-
-`python demo/bci_main_demo.py`
+## Offset Determination and Correction
-This demo will load in parameters and execute a demo task defined in the file. There are demo files contained in most modules, excepting gui, signal and parameters. Run them as a python script!
+> [!CAUTION]
+> Static offset determination and correction are critical steps before starting an experiment. BciPy uses LSL to acquire EEG data and PsychoPy to present stimuli. The synchronization between the two systems is crucial for accurate data collection and analysis.
-## Offset Determination and Correction
+### What is a Static Offset?
-Static offset determination and correction are critical steps before starting an experiment. BciPy uses LSL to acquire EEG data and Psychopy to present stimuli. The synchronization between the two systems is crucial for accurate data collection and analysis.
+A static offset is the regular time difference between signals and stimuli presentation. This offset is determined through testing using a photodiode or another triggering mechanism. Once determined, the offset is corrected by shifting the EEG signal using the `static_offset` parameter in devices.json.
-[LSL synchronization documentation](https://labstreaminglayer.readthedocs.io/info/time_synchronization.html)
-[PsychoPy timing documentation](https://www.psychopy.org/general/timing/index.html)
+#### How to Determine the Offset
-A static offset is the regular time difference between our signals and stimuli. This offset is determined through testing via a photodiode or other triggering mechanism. The offset correction is done by shifting the EEG signal by the determined offset using the `static_offset` parameter.
+To determine the static offset, you can run a timing verification task (e.g., `RSVPTimingVerification`) with a photodiode attached to the display and connected to your device. After collecting the data, use the `offset` module to analyze the results and recommend an offset correction value.
-After running a timing verification task (such as, RSVPTimingVerification) with a photodiode attached to the display and connected to a device, the offset can be determined by analyzing the data. Use the `offset` module to recommend an offset correction value and display the results.
+#### Running Offset Determination
-To run the offset determination and print the results, use the following command:
+To calculate the offset and display the results, use the following command:
```bash
python bcipy/helpers/offset.py -r
```
-After running the above command, the recommended offset correction value will be displayed in the terminal and can be passed to determine system stability and display the results.
+This will analyze the data and recommend an offset correction value, which will be displayed in the terminal.
+
+#### Applying the Offset Correction
+
+Once you have the recommended offset value, you can apply it to verify system stability and display the results. For example, if the recommended offset value is `0.1`, run the following command:
```bash
-# Let's say the recommneded offset value is 0.1
python bcipy/helpers/offset.py --offset "0.1" -p
```
-Alternately, if Make is installed, you may run the follow command to run offset determination and display the results:
+#### Using Make for Offset Determination
+
+If `Make` is installed, you can simplify the process by running the following command to determine the offset and display the results:
```sh
make offset-recommend
```
-## Testing
-
-When writing tests, put them in the correct module, in a tests folder, and prefix the file and test itself with `test_` in order for pytest to discover it. See other module tests for examples!
-
-Development requirements must be installed before running: `pip install -r dev_requirements.txt`
-
-To run all tests, in the command line:
+#### Additional Resources
-```python
-py.test
-```
+For more information on synchronization and timing, refer to the following documentation:
-To run a single modules tests (ex. acquisition), in the command line:
+- [LSL Synchronization Documentation](https://labstreaminglayer.readthedocs.io/info/time_synchronization.html)
+- [PsychoPy Timing Documentation](https://www.psychopy.org/general/timing/index.html)
-```python
-py.test acquisition
-```
+## Glossary
-To generate test coverage metrics, in the command line:
+***Stimuli***: A single letter, tone or image shown (generally in an inquiry). Singular = stimulus, plural = stimuli.
-```bash
-coverage run --branch --source=bcipy -m pytest --mpl -k "not slow"
+***Trial***: A collection of data after a stimulus is shown. A----
-#Generate a command line report
-coverage report
+***Inquiry***: The set of stimuli after a fixation cross in a spelling task to gather user intent. A ---- B --- C ----
-# Generate html doc in the bci folder. Navigate to index.html and click.
-coverage html
-```
+***Series***: Each series contains at least one inquiry. A letter/icon decision is made after a series in a spelling task.
-Alternately, if Make is installed, you may run the follow command to run coverage/pytest and generate the html:
+***Session***: Data collected for a task. Comprised of metadata about the task and a list of Series.
-```sh
-make coverage-html
-```
+***Protocol***: A collection of tasks and actions to be executed in a session. This is defined for each experiment and can be registered in experiments.json via the BCI GUI.
-## Linting
+***Experiment***: A protocol with a set of parameters. This is defined within experiments and can be registered in experiments.json via the BCI GUI.
-This project enforces `PEP` style guidelines using [flake8](http://flake8.pycqa.org/en/latest/).
+***Task***: An experimental design with stimuli, trials, inquiries and series for use in BCI. For instance, "RSVP Calibration" is a task.
-To avoid spending unnecessary time on formatting, we recommend using `autopep8`. You can specify a file or directory to auto format. When ready to push your code, you may run the following commands to format your code:
+***Action***: A task without a paradigm. For instance, "RSVP Calibration" is a task, but "InterTaskAction" is an action. These are most often used to define the actions that take place in between tasks.
-```sh
-# autoformat all files in bcipy
-autopep8 --in-place --aggressive -r bcipy
+***Mode***: Common design elements between task types. For instance, Calibration and Free Spelling are modes.
-# autoformat only the processor file
-autopep8 --in-place --aggressive bcipy/acquisition/processor.py
-```
+***Paradigm***: Display paradigm with unique properties and modes. Ex. Rapid-Serial Visual Presentation (RSVP), Matrix Speller, Steady-State Visual Evoked Potential (SSVEP).
-Finally, run the lint check: `flake8 bcipy`.
+## Scientific Publications using BciPy
-Alternately, if Make is installed, you may run the follow command to run autopep8 and flake8:
+### 2025
-```sh
-make lint
-```
+- Memmott, T., Klee, D., Smedemark-Margulies, N., & Oken, B. (2025). Artifact filtering application to increase online parity in a communication BCI: progress toward use in daily-life. Frontiers in Human Neuroscience, 19, 1551214.
+- Peters, B., Celik, B., Gaines, D., Galvin-McLaughlin, D., Imbiriba, T., Kinsella, M., ... & Fried-Oken, M. (2025). RSVP keyboard with inquiry preview: mixed performance and user experience with an adaptive, multimodal typing interface combining EEG and switch input. Journal of neural engineering, 22(1), 016022.
-## Type Checking
+### 2024
-This project enforces `mypy` type checking. The typing project configuration is found in the mypy.ini file. To run type checking, run the following command:
+- Klee, D., Memmott, T., & Oken, B. (2024). The Effect of Jittered Stimulus Onset Interval on Electrophysiological Markers of Attention in a Brain–Computer Interface Rapid Serial Visual Presentation Paradigm. Signals, 5(1), 18-39.
+- Kocanaogullari, D. (2024). Detection and Assessment of Spatial Neglect Using a Novel Augmented Reality-Guided EEG-Based Brain-Computer Interface (Doctoral dissertation, University of Pittsburgh).
+- Smedemark-Margulies, N. (2024). Reducing Calibration Effort for Brain-Computer Interfaces (Doctoral dissertation, Northeastern University).
-```sh
-mypy bcipy
-```
+### 2023
-To generate a report, run the following command:
+- Smedemark-Margulies, N., Celik, B., Imbiriba, T., Kocanaogullari, A., & Erdoğmuş, D. (2023, June). Recursive estimation of user intent from noninvasive electroencephalography using discriminative models. In ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (pp. 1-5). IEEE.
-```sh
-mypy --html-report bcipy
-```
+### 2022
-Alternately, if Make is installed, you may run the follow command to run mypy:
+- Mak, J., Kocanaogullari, D., Huang, X., Kersey, J., Shih, M., Grattan, E. S., ... & Akcakaya, M. (2022). Detection of stroke-induced visual neglect and target response prediction using augmented reality and electroencephalography. IEEE Transactions on Neural Systems and Rehabilitation Engineering, 30, 1840-1850.
+- Galvin-McLaughlin, D., Klee, D., Memmott, T., Peters, B., Wiedrick, J., Fried-Oken, M., ... & Dudy, S. (2022). Methodology and preliminary data on feasibility of a neurofeedback protocol to improve visual attention to letters in mild Alzheimer's disease. Contemporary Clinical Trials Communications, 28, 100950.
+- Klee, D., Memmott, T., Smedemark-Margulies, N., Celik, B., Erdogmus, D., & Oken, B. S. (2022). Target-related alpha attenuation in a brain-computer interface rapid serial visual presentation calibration. Frontiers in Human Neuroscience, 16, 882557.
-```sh
-make type
-```
+### 2021
-### Contributions Welcome
+- Koçanaoğulları, A., Akcakaya, M., & Erdoğmuş, D. (2021). Stopping criterion design for recursive Bayesian classification: analysis and decision geometry. IEEE Transactions on Pattern Analysis and Machine Intelligence, 44(9), 5590-5601.
-If you want to be added to the development team slack or have additional questions, please reach out to us at !
+### 2020
-#### Contribution Guidelines
+- Koçanaogullari, A. (2020). Active Recursive Bayesian Classification (Querying and Stopping) for Event Related Potential Driven Brain Computer Interface Systems (Doctoral dissertation, Northeastern University).
+- Koçanaoğulları, A., Akçakaya, M., Oken, B., & Erdoğmuş, D. (2020, June). Optimal modality selection using information transfer rate for event related potential driven brain computer interfaces. In Proceedings of the 13th ACM International Conference on PErvasive Technologies Related to Assistive Environments (pp. 1-7).
-We follow and will enforce the contributor's covenant to foster a safe and inclusive environment for this open source software, please reference this link for more information:
+## Contributions Welcome
-We welcome all contributions to BciPy! Please follow the guidelines below:
+If you want to be added to the development team Discord or have additional questions, please reach out to us at !
-- All modules require tests and a demo.
-- All tests must pass to merge, even if they are seemingly unrelated to your work.
-- Use Spaces, not Tabs.
-- Use informative names for functions and classes.
-- Document the input and output of your functions / classes in the code. eg in-line commenting and typing.
-- Do not push IDE or other local configuration files.
-- All new modules or major functionality should be documented outside of the code with a README.md.
--- See README.md in repo or go to this site for inspiration: . Always use a Markdown interpreter before pushing.
+### Contribution Guidelines
-See this resource for examples:
+We follow and will enforce the code of conduct outlined [here](CODE_OF_CONDUCT.md). Please read it before contributing.
### Contributors
diff --git a/bcipy/acquisition/datastream/generator.py b/bcipy/acquisition/datastream/generator.py
index cb2dfe04c..79c31ebd3 100644
--- a/bcipy/acquisition/datastream/generator.py
+++ b/bcipy/acquisition/datastream/generator.py
@@ -1,6 +1,6 @@
"""Functions for generating mock data to be used for testing/development."""
-from typing import Callable, Generator, Optional
+from typing import Any, Callable, Generator, List, Optional, TextIO
from past.builtins import range
@@ -8,8 +8,13 @@
from bcipy.signal.generator.generator import gen_random_data
-def advance_to_row(filehandle, rownum):
- """Utility function to advance a file cursor to the given row."""
+def advance_to_row(filehandle: TextIO, rownum: int) -> None:
+ """Utility function to advance a file cursor to the given row.
+
+ Args:
+ filehandle (TextIO): The file handle to advance.
+ rownum (int): The target row number to advance to (1-indexed).
+ """
for _ in range(rownum - 1):
filehandle.readline()
@@ -21,39 +26,52 @@ class _DefaultEncoder:
"""Encodes data by returning the raw data."""
# pylint: disable=no-self-use
- def encode(self, data):
- """Encode the data that will be output by the file_data generator."""
+ def encode(self, data: Any) -> Any:
+ """Encodes the input data.
+
+ Args:
+ data (Any): The data to be encoded.
+
+ Returns:
+ Any: The raw input data, as no actual encoding is performed.
+ """
return data
-def generator_with_args(generator_fn, **generator_args) -> Callable[[], Generator]:
+def generator_with_args(generator_fn: Callable[..., Generator], **generator_args: Any) -> Callable[..., Generator]:
"""Constructs a generator with the given arguments.
- Parameters
- ----------
- generator_fn : Function
- a generator function
-
- Returns
- -------
- Function which creates a generator using the given args.
+
+ Args:
+ generator_fn (Callable[..., Generator]): A generator function.
+ **generator_args (Any): Keyword arguments to be passed to the generator_fn.
+
+ Returns:
+ Callable[..., Generator]: A function that creates a generator using the given args.
"""
- def factory(**args):
+ def factory(**args: Any) -> Generator:
return generator_fn(**{**generator_args, **args})
return factory
-def random_data_generator(encoder=_DefaultEncoder(),
- low=-1000,
- high=1000,
- channel_count=25):
+def random_data_generator(encoder: _DefaultEncoder = _DefaultEncoder(),
+ low: int = -1000,
+ high: int = 1000,
+ channel_count: int = 25) -> Generator[Any, None, None]:
"""Generator that outputs random EEG-like data encoded according to the provided encoder.
- Returns
- -------
- A generator that produces packet of data, which decodes into a list of
- floats in the range low to high with channel_count number of items.
+ Args:
+ encoder (_DefaultEncoder, optional): An encoder object to encode the output data.
+ Defaults to _DefaultEncoder().
+ low (int, optional): The lower bound for random data generation. Defaults to -1000.
+ high (int, optional): The upper bound for random data generation. Defaults to 1000.
+ channel_count (int, optional): The number of channels (items) in the generated data.
+ Defaults to 25.
+
+ Yields:
+ Any: A packet of data, which decodes into a list of floats in the range
+ low to high with `channel_count` number of items, encoded by the encoder.
"""
while True:
@@ -61,23 +79,29 @@ def random_data_generator(encoder=_DefaultEncoder(),
yield encoder.encode(sensor_data)
-def file_data_generator(filename, header_row=3, encoder=_DefaultEncoder(), channel_count: Optional[int] = None):
+def file_data_generator(filename: str,
+ header_row: int = 3,
+ encoder: _DefaultEncoder = _DefaultEncoder(),
+ channel_count: Optional[int] = None) -> Generator[List[float], None, None]:
"""Generates data from a source file and encodes it according to the
provided encoder.
- Parameters
- ----------
- filename: str
- Name of file containing a sample EEG session output. This file will
- be the source of the generated data. File should be a csv file.
- header_row: int, optional
- Row with the header data (channel names); the default is 3,
- assuming the first 2 rows are for metadata.
- encoder : Encoder, optional
- Used to encode the output data.
- channel_count : int, optional
- If provided this is used to truncate the data to the given number
- of channels.
+ Args:
+ filename (str): Name of file containing a sample EEG session output.
+ This file will be the source of the generated data.
+ File should be a csv file.
+ header_row (int, optional): Row with the header data (channel names).
+ The default is 3, assuming the first 2 rows
+ are for metadata.
+ encoder (_DefaultEncoder, optional): Used to encode the output data.
+ Defaults to _DefaultEncoder().
+ channel_count (Optional[int], optional): If provided this is used to truncate
+ the data to the given number of channels.
+ Defaults to None.
+
+ Yields:
+ List[float]: A list of float values representing sensor data from the file,
+ encoded by the encoder.
"""
with open(filename, 'r', encoding=DEFAULT_ENCODING) as infile:
@@ -95,8 +119,17 @@ def file_data_generator(filename, header_row=3, encoder=_DefaultEncoder(), chann
def data_value(value: str) -> float:
- """Convert to a float; some trigger values are strings, rather than
- numbers (ex. indicating the letter); convert these to 1.0."""
+ """Converts a string value to a float.
+
+ Some trigger values might be strings (e.g., indicating a letter),
+ which are converted to 1.0. Empty strings are converted to 0.0.
+
+ Args:
+ value (str): The string value to convert.
+
+ Returns:
+ float: The converted float value.
+ """
if value:
try:
return float(value)
diff --git a/bcipy/acquisition/datastream/lsl_server.py b/bcipy/acquisition/datastream/lsl_server.py
index cee201f9f..71f6f5bee 100644
--- a/bcipy/acquisition/datastream/lsl_server.py
+++ b/bcipy/acquisition/datastream/lsl_server.py
@@ -1,11 +1,11 @@
# mypy: disable-error-code="misc"
-"""Data server that streams EEG data over a LabStreamingLayer StreamOutlet
-using pylsl."""
+"""Data server that streams EEG data over a LabStreamingLayer StreamOutlet using pylsl."""
+
import logging
import time
import uuid
from queue import Empty, Queue
-from typing import Generator, Optional
+from typing import Generator, List, Optional, Tuple
from pylsl import StreamInfo, StreamOutlet
@@ -29,22 +29,21 @@ class LslDataServer(StoppableThread):
fake_data can be set to false and this module can be run standalone in its
own python instance.
- Parameters
- ----------
- device_spec : DeviceSpec
- parameters used to configure the server. Should have at least
- a list of channels and the sample frequency.
- generator : optional Generator (see generator.py for options)
- used to generate the data to be served. Uses random_data_generator
- by default.
- include_meta: bool, optional
- if True, writes metadata to the outlet stream.
- add_markers: bool, optional
- if True, creates a the marker channel and streams data to this
- channel at a fixed frequency.
- marker_stream_name: str, optional
- name of the sample marker stream
- chunk_size: int, optional chunk size.
+ Args:
+ device_spec (DeviceSpec): Parameters used to configure the server.
+ Should have at least a list of channels and the
+ sample frequency.
+ generator (Optional[Generator], optional): Used to generate the data to be
+ served. Uses `random_data_generator`
+ by default. Defaults to None.
+ include_meta (bool, optional): If True, writes metadata to the outlet stream.
+ Defaults to True.
+        add_markers (bool, optional): If True, creates the marker channel and
+ streams data to this channel at a fixed
+ frequency. Defaults to False.
+ marker_stream_name (str, optional): Name of the sample marker stream.
+ Defaults to `MARKER_STREAM_NAME`.
+ chunk_size (int, optional): Chunk size for the LSL stream. Defaults to 0.
"""
def __init__(self,
@@ -57,7 +56,8 @@ def __init__(self,
super(LslDataServer, self).__init__()
self.device_spec = device_spec
- self.generator = generator or random_data_generator(channel_count=device_spec.channel_count)
+ self.generator = generator or random_data_generator(
+ channel_count=device_spec.channel_count)
log.debug("Starting LSL server for device: %s", device_spec.name)
print(f"Serving: {device_spec}")
@@ -96,8 +96,12 @@ def __init__(self,
self.markers_outlet = StreamOutlet(markers_info)
self.started = False
- def stop(self):
- """Stop the thread and cleanup resources."""
+ def stop(self) -> None:
+ """Stops the thread and cleans up resources.
+
+ This method sets the internal flag to stop the thread's execution
+ and closes the LSL `StreamOutlet`s, making them no longer discoverable.
+ """
log.debug("[*] Stopping data server")
super(LslDataServer, self).stop()
@@ -111,10 +115,12 @@ def stop(self):
del self.markers_outlet
self.markers_outlet = None
- def run(self):
- """Main loop of the thread. Continuously streams data to the stream
- outlet at a rate consistent with the sample frequency. May also
- output markers at a different interval."""
+ def run(self) -> None:
+ """Main loop of the thread.
+
+ Continuously streams data to the stream outlet at a rate consistent
+ with the sample frequency. May also output markers at a different interval.
+ """
sample_counter = 0
self.started = True
@@ -138,8 +144,16 @@ def run(self):
log.debug("[*] No longer pushing data")
-def _settings(filename):
- """Read the daq settings from the given data file"""
+def _settings(filename: str) -> Tuple[str, int, List[str]]:
+ """Reads the DAQ settings from the given data file.
+
+ Args:
+ filename (str): The path to the data file containing DAQ settings.
+
+ Returns:
+ Tuple[str, int, List[str]]: A tuple containing the DAQ type (str),
+ sample rate (int), and a list of channel names (List[str]).
+ """
with open(filename, 'r', encoding=DEFAULT_ENCODING) as datafile:
daq_type = datafile.readline().strip().split(',')[1]
@@ -148,15 +162,20 @@ def _settings(filename):
return (daq_type, sample_hz, channels)
-def await_start(dataserver: LslDataServer, max_wait: float = 2.0):
- """Blocks until server is started. Raises if max_wait is exceeded before
- server is started.
+def await_start(dataserver: LslDataServer, max_wait: float = 2.0) -> None:
+ """Blocks until the LSL data server is started.
+
+ Raises an exception if `max_wait` is exceeded before the server starts.
+
+ Args:
+ dataserver (LslDataServer): An instantiated (unstarted) server on which to wait.
+ max_wait (float, optional): The maximum number of seconds to wait.
+ Defaults to 2.0. After this period, if the
+ server has not successfully started, an
+ exception is thrown.
- Parameters
- ----------
- dataserver - instantiated (unstarted) server on which to wait.
- max_wait - the max number of seconds to wait. After this period if the
- server has not succesfully started an exception is thrown.
+ Raises:
+ Exception: If the server cannot start up within the `max_wait` period.
"""
dataserver.start()
@@ -170,8 +189,14 @@ def await_start(dataserver: LslDataServer, max_wait: float = 2.0):
raise Exception("Server couldn't start up in time.")
-def main():
- """Initialize and start the server."""
+def main() -> None:
+ """Initializes and starts the LSL data server.
+
+ This function parses command-line arguments to configure the server
+ (e.g., source filename, marker inclusion, device name, chunk size).
+ It then creates and starts an `LslDataServer` instance, running until
+ a `KeyboardInterrupt` is received.
+ """
import argparse
from bcipy.acquisition.datastream.generator import file_data_generator
@@ -182,8 +207,10 @@ def main():
help="file containing data to be streamed; "
"if missing, random data will be served.")
parser.add_argument('-m', '--markers', action="store_true", default=False)
- parser.add_argument('-n', '--name', default='DSI-24', help='Name of the device spec to mock.')
- parser.add_argument('-c', '--chunk_size', default=0, type=int, help='Chunk size')
+ parser.add_argument('-n', '--name', default='DSI-24',
+ help='Name of the device spec to mock.')
+ parser.add_argument('-c', '--chunk_size', default=0,
+ type=int, help='Chunk size')
args = parser.parse_args()
if args.filename:
diff --git a/bcipy/acquisition/datastream/mock/eye_tracker_server.py b/bcipy/acquisition/datastream/mock/eye_tracker_server.py
index 9cc3c3e73..948d5f6c5 100644
--- a/bcipy/acquisition/datastream/mock/eye_tracker_server.py
+++ b/bcipy/acquisition/datastream/mock/eye_tracker_server.py
@@ -3,6 +3,7 @@
import logging
import math
import time
+from typing import Generator, List
from numpy.random import uniform
@@ -14,7 +15,11 @@
def eye_tracker_device() -> DeviceSpec:
- """Mock DeviceSpec for an eye tracker."""
+ """Mock DeviceSpec for an eye tracker.
+
+ Returns:
+ DeviceSpec: A DeviceSpec object configured for an eye tracker.
+ """
return DeviceSpec(name='EyeTracker',
channels=[
'leftEyeX', 'leftEyeY', 'rightEyeX', 'rightEyeY',
@@ -25,15 +30,24 @@ def eye_tracker_device() -> DeviceSpec:
content_type='Gaze')
-def eye_tracker_data_generator(display_x=1920, display_y=1080):
+def eye_tracker_data_generator(display_x: int = 1920, display_y: int = 1080) -> Generator[List[float], None, None]:
"""Generates sample eye tracker data.
TODO: determine appropriate values for pixelsPerDegree fields.
TODO: look info alternatives; maybe PyGaze.
http://www.pygaze.org/about/
+
+ Args:
+ display_x (int): The width of the display in pixels. Defaults to 1920.
+ display_y (int): The height of the display in pixels. Defaults to 1080.
+
+ Yields:
+ List[float]: A list of float values representing eye tracker data, including
+ left eye X/Y, right eye X/Y, left pupil area, right pupil area,
+ pixels per degree X, and pixels per degree Y.
"""
- def area(diameter):
+ def area(diameter: float) -> float:
return math.pi * (diameter / 2.0)**2
while True:
@@ -50,14 +64,23 @@ def area(diameter):
def eye_tracker_server() -> LslDataServer:
- """Create a demo lsl_server that serves eye tracking data."""
+ """Create a demo lsl_server that serves eye tracking data.
+
+ Returns:
+ LslDataServer: An LslDataServer instance configured for eye tracking data.
+ """
return LslDataServer(device_spec=eye_tracker_device(),
generator=eye_tracker_data_generator())
def main():
- """Create an run an lsl_server"""
+ """Create and run an lsl_server.
+
+ This function initializes and starts an LSL data server for eye tracking.
+ It runs indefinitely until a KeyboardInterrupt is received, at which point
+ the server is stopped.
+ """
try:
server = eye_tracker_server()
log.info("New server created")
diff --git a/bcipy/acquisition/datastream/mock/switch.py b/bcipy/acquisition/datastream/mock/switch.py
index d01d9d11f..94b4ffda3 100644
--- a/bcipy/acquisition/datastream/mock/switch.py
+++ b/bcipy/acquisition/datastream/mock/switch.py
@@ -13,7 +13,11 @@
def switch_device() -> DeviceSpec:
- """Mock DeviceSpec for a switch"""
+ """Mock DeviceSpec for a switch.
+
+ Returns:
+ DeviceSpec: A DeviceSpec object configured for a switch.
+ """
device = preconfigured_device('Switch', strict=False)
if device:
return device
@@ -27,21 +31,27 @@ class Switch:
"""Mock switch which streams data over LSL at an irregular interval."""
def __init__(self):
+ """Initializes the Switch with a DeviceSpec and LSL StreamOutlet."""
super().__init__()
- self.device = switch_device()
- self.lsl_id = 'bci_demo_switch'
+ self.device: DeviceSpec = switch_device()
+ self.lsl_id: str = 'bci_demo_switch'
info = StreamInfo(self.device.name, self.device.content_type,
self.device.channel_count, self.device.sample_rate,
self.device.data_type, self.lsl_id)
- self.outlet = StreamOutlet(info)
+ self.outlet: StreamOutlet = StreamOutlet(info)
+
+ def click(self, _position: list) -> None:
+ """Pushes a sample to the LSL stream when a click event occurs.
- def click(self, _position):
- """Click event that pushes a sample"""
+ Args:
+ _position (list): The position of the click event. This parameter is
+ currently unused.
+ """
log.debug("Click!")
self.outlet.push_sample([1.0])
- def quit(self):
- """Quit and cleanup"""
+ def quit(self) -> None:
+ """Cleans up and releases the LSL StreamOutlet."""
del self.outlet
self.outlet = None
@@ -49,13 +59,19 @@ def quit(self):
class SwitchGui(BCIGui): # pragma: no cover
"""GUI to emulate a switch."""
- def __init__(self, switch: Switch, *args, **kwargs):
+ def __init__(self, switch: 'Switch', *args, **kwargs):
+ """Initializes the SwitchGui with a Switch object.
+
+ Args:
+ switch (Switch): The Switch object to control.
+ *args: Variable length argument list for the base class.
+ **kwargs: Arbitrary keyword arguments for the base class.
+ """
super().__init__(*args, **kwargs)
- self.switch = switch
+ self.switch: Switch = switch
def build_buttons(self) -> None:
- """Build all buttons necessary for the UI.
- """
+ """Build all buttons necessary for the UI."""
self.add_button(message='Click!',
position=[100, 75],
size=[50, 50],
@@ -75,9 +91,15 @@ def build_text(self) -> None:
font_size=16)
-def main(switch: Switch): # pragma: no cover
- """Creates a PyQt5 GUI with a single button in the middle. Performs the
- switch action when clicked."""
+def main(switch: Switch) -> None: # pragma: no cover
+ """Creates a PyQt6 GUI with a single button in the middle and runs it.
+
+ Performs the switch action when the button is clicked and cleans up
+ the switch resources upon exit.
+
+ Args:
+ switch (Switch): The Switch object to associate with the GUI.
+ """
gui = app(sys.argv)
ex = SwitchGui(switch=switch,
title='Demo Switch!',
diff --git a/bcipy/acquisition/datastream/producer.py b/bcipy/acquisition/datastream/producer.py
index ffdc310c9..82e698f66 100644
--- a/bcipy/acquisition/datastream/producer.py
+++ b/bcipy/acquisition/datastream/producer.py
@@ -1,11 +1,12 @@
-"""Code for mocking an EEG data stream. Code in this module produces data
-at a specified frequency."""
+"""Code for mocking an EEG data stream. Code in this module produces data at a specified frequency."""
+
import logging
import random
import threading
import time
from builtins import next
from queue import Queue
+from typing import Any, Iterator, Optional
from bcipy.acquisition.datastream.generator import random_data_generator
from bcipy.config import SESSION_LOG_FILENAME
@@ -16,76 +17,106 @@
class Producer(threading.Thread):
"""Produces generated data at a specified frequency.
- Parameters
- ----------
- queue : Queue
- Generated data will be written to the queue.
- freq : float, optional
- Data will be generated at the given frequency.
- generator : object, optional
- python generator for creating data.
- maxiters : int, optional
- if provided, stops generating data after the given number of iters.
+ This class extends `threading.Thread` to run data generation in a separate
+ thread, pushing generated samples into a queue at a specified frequency.
+ It can also act as a context manager to automatically start and stop the
+ thread.
+
+ Args:
+ queue (Queue): The queue where generated data will be written.
+ freq (float, optional): The frequency (in Hz) at which data will be generated.
+ Defaults to 1/100 (0.01 Hz).
+ generator (Optional[Any], optional): A Python generator for creating data.
+ Defaults to `random_data_generator()`.
+ maxiters (Optional[int], optional): If provided, stops generating data
+ after this many iterations. Defaults to None.
"""
def __init__(self,
- queue,
- freq=1 / 100,
- generator=None,
- maxiters=None):
+ queue: Queue,
+ freq: float = 1 / 100,
+ generator: Optional[Any] = None,
+ maxiters: Optional[int] = None):
super(Producer, self).__init__()
- self.daemon = True
- self._running = True
+ self.daemon: bool = True
+ self._running: bool = True
+
+ self.freq: float = freq
+ self.generator: Any = generator or random_data_generator()
+ self.maxiters: Optional[int] = maxiters
+ self.queue: Queue = queue
+
+ def __enter__(self) -> 'Producer':
+ """Enters the runtime context related to this object.
- self.freq = freq
- self.generator = generator or random_data_generator()
- self.maxiters = maxiters
- self.queue = queue
+ Starts the producer thread when entering the context.
- # @override to make this class a context manager
- def __enter__(self):
+ Returns:
+ Producer: The Producer instance.
+ """
self.start()
return self
- # @override to make this class a context manager
- def __exit__(self, _exc_type, _exc_value, _traceback):
+ def __exit__(self, _exc_type: Any, _exc_value: Any, _traceback: Any) -> None:
+ """Exits the runtime context related to this object.
+
+ Stops the producer thread when exiting the context.
+
+ Args:
+ _exc_type (Any): The exception type, if an exception was raised.
+ _exc_value (Any): The exception value, if an exception was raised.
+ _traceback (Any): The traceback, if an exception was raised.
+ """
self.stop()
- def _genitem(self):
- """Generates the data item to be added to the queue."""
+ def _genitem(self) -> Any:
+ """Generates the data item to be added to the queue.
+ Returns:
+ Any: The next data item from the configured generator.
+ """
return next(self.generator)
- def _additem(self):
- """Adds the data item to the queue."""
+ def _additem(self) -> None:
+ """Adds the data item to the queue.
+ This method fetches an item using `_genitem` and places it into the internal queue.
+ """
self.queue.put(self._genitem())
- def stop(self):
- """Stop the thread; stopped threads cannot be restarted."""
+ def stop(self) -> None:
+ """Stops the thread.
+
+ Sets an internal flag to stop the thread's execution and waits for the thread to finish.
+ Stopped threads cannot be restarted.
+ """
self._running = False
self.join()
- def run(self):
+ def run(self) -> None:
"""Provides a control loop, adding a data item to the queue at the
configured frequency.
- @overrides the Thread run method
+ This method overrides the `threading.Thread.run` method.
"""
- def tick():
+ def tick() -> Iterator[float]:
"""Corrects the time interval if the time of the work to add the
- item causes drift."""
- current_time = time.time()
- count = 0
+ item causes drift.
+
+ Yields:
+ float: The calculated sleep duration to maintain the target frequency.
+ """
+ current_time: float = time.time()
+ count: int = 0
while True:
count += 1
yield max(current_time + count * self.freq - time.time(), 0)
- sleep_len = tick()
- times = 0
+ sleep_len: Iterator[float] = tick()
+ times: int = 0
while self._running and (self.maxiters is None or
times < self.maxiters):
times += 1
@@ -94,25 +125,39 @@ def tick():
class _ConsumerThread(threading.Thread):
- """Consumer used to test the Producer by consuming generated items."""
+ """Consumer used to test the Producer by consuming generated items.
+
+ Args:
+ queue (Queue): The queue from which to consume items.
+ name (Optional[str], optional): The name of the consumer thread.
+ Defaults to None.
+ """
- def __init__(self, queue, name=None):
- super(_ConsumerThread, self).__init__()
- self.daemon = True
- self.name = name
- self._q = queue
+ def __init__(self, queue: Queue, name: Optional[str] = None):
+ super(_ConsumerThread, self).__init__(name=name)
+ self.daemon: bool = True
+ self._q: Queue = queue
- def run(self):
+ def run(self) -> None:
+ """Main loop for the consumer thread.
+
+ Continuously checks the queue for items and processes them, logging
+ the item and the queue size.
+ """
while True:
if not self._q.empty():
- item = self._q.get()
+ item: Any = self._q.get()
log.info('Getting %s: %s items in queue',
str(item), str(self._q.qsize()))
time.sleep(random.random())
-def main():
- """Main method"""
+def main() -> None:
+ """Main method to demonstrate the Producer and Consumer threads.
+
+ Initializes a Producer and a Consumer thread, starts them, and lets them run
+ for a short period (5 seconds) before the program exits.
+ """
data_queue: Queue = Queue()
producer: Producer = Producer(data_queue)
consumer: _ConsumerThread = _ConsumerThread(data_queue)
diff --git a/bcipy/acquisition/devices.py b/bcipy/acquisition/devices.py
index 77308a5a1..bf39917bf 100644
--- a/bcipy/acquisition/devices.py
+++ b/bcipy/acquisition/devices.py
@@ -1,10 +1,10 @@
-"""Functionality for loading and querying configuration for supported hardware
-devices."""
+"""Functionality for loading and querying configuration for supported hardware devices."""
+
import json
import logging
from enum import Enum, auto
from pathlib import Path
-from typing import Dict, List, NamedTuple, Optional, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union
from bcipy.config import (DEFAULT_ENCODING, DEVICE_SPEC_PATH,
SESSION_LOG_FILENAME)
@@ -23,26 +23,41 @@
class ChannelSpec(NamedTuple):
- """Represents metadata about a channel."""
- name: str # Label in the LSL metadata
- label: str # Label used within BciPy (raw_data, etc.)
- type: str = None
- units: str = None
+ """Represents metadata about a channel.
+
+ Attributes:
+ name (str): Label in the LSL metadata.
+ label (str): Label used within BciPy (raw_data, etc.).
+ type (Optional[str]): Type of the channel (e.g., 'EEG', 'Pupil'). Defaults to None.
+ units (Optional[str]): Units of measurement for the channel data (e.g., 'microvolts').
+ Defaults to None.
+ """
+ name: str
+ label: str
+ type: Optional[str] = None
+ units: Optional[str] = None
- def __repr__(self):
+ def __repr__(self) -> str:
+ """Returns a string representation of the ChannelSpec object."""
fields = ['name', 'label', 'type', 'units']
items = [(field, self.__getattribute__(field)) for field in fields]
props = [f"{key}='{val}'" for key, val in items if val]
- return f"ChannelSpec({', '.join(props)})"
+        return f"ChannelSpec({', '.join(props)})"
+
+def channel_spec(channel: Union[str, Dict[str, Any], ChannelSpec]) -> ChannelSpec:
+ """Creates a ChannelSpec from the given channel information.
-def channel_spec(channel: Union[str, dict, ChannelSpec]) -> ChannelSpec:
- """Creates a ChannelSpec from the given channel.
+ Args:
+ channel (Union[str, Dict[str, Any], ChannelSpec]): Acquisition channel
+ information, specified as either just the label (str), a dictionary
+ with channel properties, or an existing ChannelSpec object.
- Parameters
- ----------
- channel - acquisition channel information specified as either just the
- label or with additional data represented by a dict or ChannelSpec.
+ Returns:
+ ChannelSpec: A `ChannelSpec` object.
+
+ Raises:
+ Exception: If an unexpected channel type is provided.
"""
if isinstance(channel, str):
return ChannelSpec(name=channel, label=channel)
@@ -59,40 +74,49 @@ class DeviceStatus(Enum):
PASSIVE = auto()
def __str__(self) -> str:
- """String representation"""
+ """Returns the lowercase string representation of the DeviceStatus enum member."""
return self.name.lower()
@classmethod
def from_str(cls, name: str) -> 'DeviceStatus':
- """Returns the DeviceStatus associated with the given string
- representation."""
+ """Returns the `DeviceStatus` enum member associated with the given string
+ representation.
+
+ Args:
+ name (str): The string representation of the device status (e.g., 'active', 'passive').
+
+ Returns:
+ DeviceStatus: The corresponding `DeviceStatus` enum member.
+ """
return cls[name.upper()]
class DeviceSpec:
"""Specification for a hardware device used in data acquisition.
- Parameters
- ----------
- name - device short name; ex. DSI-24
- channels - list of data collection channels; devices must have at least
- one channel. Channels may be provided as a list of names or list of
- ChannelSpecs.
- sample_rate - sample frequency in Hz.
- content_type - type of device; likely one of ['EEG', 'MoCap', 'Gaze',
- 'Audio', 'Markers']; see https://github.com/sccn/xdf/wiki/Meta-Data.
- description - device description
- ex. 'Wearable Sensing DSI-24 dry electrode EEG headset'
- data_type - data format of a channel; all channels must have the same type;
- see https://labstreaminglayer.readthedocs.io/projects/liblsl/ref/enums.html
- excluded_from_analysis - list of channels (label) to exclude from analysis.
- status - recording status
- static_offset - Specifies the static trigger offset (in seconds) used to align
- triggers properly with EEG data from LSL. The system includes built-in
- offset correction, but there is still a hardware-limited offset between EEG
- and trigger timing values for which the system does not account. The correct
- value may be different for each computer, and must be determined on a
- case-by-case basis. Default: 0.1",
+ Args:
+ name (str): Device short name (e.g., 'DSI-24').
+ channels (Union[List[str], List[ChannelSpec], List[dict]]): A list of
+ data collection channels. Devices must have at least one channel.
+ Channels may be provided as a list of names (str), a list of
+ `ChannelSpec` objects, or a list of dictionaries representing channel properties.
+ sample_rate (int): Sample frequency in Hz.
+ content_type (str, optional): Type of device (e.g., 'EEG', 'MoCap', 'Gaze',
+ 'Audio', 'Markers'). Defaults to `DEFAULT_DEVICE_TYPE` ('EEG').
+ See https://github.com/sccn/xdf/wiki/Meta-Data.
+ description (Optional[str], optional): Device description (e.g.,
+ 'Wearable Sensing DSI-24 dry electrode EEG headset').
+ Defaults to `name` if not provided.
+ excluded_from_analysis (Optional[List[str]], optional): A list of channel labels
+ to exclude from analysis.
+ Defaults to an empty list.
+ data_type (str, optional): Data format of a channel. All channels must have the same type.
+ Defaults to 'float32'. See https://labstreaminglayer.readthedocs.io/projects/liblsl/ref/enums.html.
+ status (DeviceStatus, optional): Recording status of the device.
+ Defaults to `DeviceStatus.ACTIVE`.
+ static_offset (float, optional): Specifies the static trigger offset (in seconds)
+ used to align triggers properly with EEG data from LSL.
+ Defaults to `DEFAULT_STATIC_OFFSET` (0.1).
"""
def __init__(self,
@@ -109,38 +133,40 @@ def __init__(self,
assert sample_rate >= 0, "Sample rate can't be negative."
assert data_type in SUPPORTED_DATA_TYPES
- self.name = name
- self.channel_specs = [channel_spec(ch) for ch in channels]
- self.sample_rate = int(sample_rate)
- self.content_type = content_type
- self.description = description or name
- self.data_type = data_type
- self.excluded_from_analysis = excluded_from_analysis or []
+ self.name: str = name
+ self.channel_specs: List[ChannelSpec] = [
+ channel_spec(ch) for ch in channels]
+ self.sample_rate: int = int(sample_rate)
+ self.content_type: str = content_type
+ self.description: str = description or name
+ self.data_type: str = data_type
+ self.excluded_from_analysis: List[str] = excluded_from_analysis or []
self._validate_excluded_channels()
- self.status = status
- self.static_offset = static_offset
+ self.status: DeviceStatus = status
+ self.static_offset: float = static_offset
@property
def channel_count(self) -> int:
- """Number of channels"""
+ """Returns the number of channels for the device."""
return len(self.channel_specs)
@property
def channels(self) -> List[str]:
- """List of channel labels. These may be customized for BciPy."""
+ """Returns a list of channel labels, which may be customized for BciPy."""
return [ch.label for ch in self.channel_specs]
@property
def channel_names(self) -> List[str]:
- """List of channel names from the device."""
+ """Returns a list of channel names as reported by the device."""
return [ch.name for ch in self.channel_specs]
@property
def analysis_channels(self) -> List[str]:
- """List of channels used for analysis by the signal module.
- Parameters:
- -----------
- exclude_trg - indicates whether or not to exclude a TRG channel if present.
+ """Returns a list of channels used for analysis by the signal module.
+
+ Returns:
+ List[str]: A list of channel labels to be used for analysis, excluding
+ any channels specified in `excluded_from_analysis`.
"""
return list(
@@ -149,12 +175,19 @@ def analysis_channels(self) -> List[str]:
@property
def is_active(self) -> bool:
- """Returns a boolean indicating if the device is currently active
- (recording status set to DeviceStatus.ACTIVE)."""
+ """Checks if the device is currently active (recording status set to `DeviceStatus.ACTIVE`).
+
+ Returns:
+ bool: True if the device is active, False otherwise.
+ """
return self.status == DeviceStatus.ACTIVE
- def to_dict(self) -> dict:
- """Converts the DeviceSpec to a dict."""
+ def to_dict(self) -> Dict[str, Any]:
+ """Converts the `DeviceSpec` object to a dictionary representation.
+
+ Returns:
+ Dict[str, Any]: A dictionary containing the device's properties.
+ """
return {
'name': self.name,
'content_type': self.content_type,
@@ -166,21 +199,25 @@ def to_dict(self) -> dict:
'static_offset': self.static_offset
}
- def __str__(self):
- """Custom str representation."""
+ def __str__(self) -> str:
+ """Returns a custom string representation of the `DeviceSpec` object."""
names = [
'name', 'content_type', 'channels', 'sample_rate', 'description'
]
- def quoted_value(name):
+ def quoted_value(name: str) -> Union[str, Any]:
value = self.__getattribute__(name)
return f"'{value}'" if isinstance(value, str) else value
props = [f"{name}={quoted_value(name)}" for name in names]
- return f"DeviceSpec({', '.join(props)})"
+        return f"DeviceSpec({', '.join(props)})"
- def _validate_excluded_channels(self):
- """Warn if excluded channels are not in the list of channels"""
+ def _validate_excluded_channels(self) -> None:
+ """Warns if any excluded channels are not found in the device's channel list.
+
+ This method logs a warning for each channel in `excluded_from_analysis`
+ that does not exist in `self.channels`.
+ """
for channel in self.excluded_from_analysis:
if channel not in self.channels:
logger.warning(
@@ -188,9 +225,18 @@ def _validate_excluded_channels(self):
)
-def make_device_spec(config: dict) -> DeviceSpec:
- """Constructs a DeviceSpec from a dict. Throws a KeyError if any fields
- are missing."""
+def make_device_spec(config: Dict[str, Any]) -> DeviceSpec:
+ """Constructs a `DeviceSpec` object from a dictionary configuration.
+
+ Args:
+ config (Dict[str, Any]): A dictionary containing device configuration parameters.
+
+ Returns:
+ DeviceSpec: A `DeviceSpec` object initialized with the provided configuration.
+
+ Raises:
+ KeyError: If any required fields are missing in the `config` dictionary.
+ """
default_status = str(DeviceStatus.ACTIVE)
return DeviceSpec(name=config['name'],
content_type=config['content_type'],
@@ -199,19 +245,23 @@ def make_device_spec(config: dict) -> DeviceSpec:
description=config['description'],
excluded_from_analysis=config.get(
'excluded_from_analysis', []),
- status=DeviceStatus.from_str(config.get('status', default_status)),
+ status=DeviceStatus.from_str(
+ config.get('status', default_status)),
static_offset=config.get('static_offset', DEFAULT_STATIC_OFFSET))
-def load(config_path: Path = Path(DEFAULT_CONFIG), replace: bool = False) -> Dict[str, DeviceSpec]:
- """Load the list of supported hardware for data acquisition from the given
+def load(config_path: Path = Path(DEFAULT_CONFIG), replace: bool = False) -> Dict[str, 'DeviceSpec']:
+ """Loads the list of supported hardware devices for data acquisition from a
configuration file.
- Parameters
- ----------
- config_path - path to the devices json file
- replace - optional; if true, existing devices are replaced; if false,
- values will be overwritten or appended.
+ Args:
+ config_path (Path, optional): Path to the devices JSON file.
+ Defaults to `Path(DEFAULT_CONFIG)`.
+ replace (bool, optional): If True, existing devices are replaced; if False,
+ values will be overwritten or appended. Defaults to False.
+
+ Returns:
+ Dict[str, DeviceSpec]: A dictionary of loaded `DeviceSpec` objects, keyed by device name.
"""
global _SUPPORTED_DEVICES
@@ -225,9 +275,14 @@ def load(config_path: Path = Path(DEFAULT_CONFIG), replace: bool = False) -> Dic
return _SUPPORTED_DEVICES
-def preconfigured_devices() -> Dict[str, DeviceSpec]:
- """Returns the preconfigured devices keyed by name. If no devices have yet
- been configured, loads and returns the DEFAULT_CONFIG."""
+def preconfigured_devices() -> Dict[str, 'DeviceSpec']:
+ """Returns the preconfigured devices, keyed by name.
+
+ If no devices have yet been configured, it loads and returns the `DEFAULT_CONFIG`.
+
+ Returns:
+ Dict[str, DeviceSpec]: A dictionary of preconfigured `DeviceSpec` objects.
+ """
global _SUPPORTED_DEVICES
if not _SUPPORTED_DEVICES:
load()
@@ -235,33 +290,55 @@ def preconfigured_devices() -> Dict[str, DeviceSpec]:
def preconfigured_device(name: str, strict: bool = True) -> DeviceSpec:
- """Retrieve the DeviceSpec with the given name. An exception is raised
- if the device is not found."""
+ """Retrieves the `DeviceSpec` with the given name.
+
+ Args:
+ name (str): The name of the device to retrieve.
+ strict (bool, optional): If True, raises an exception if the device is not found.
+ Defaults to True.
+
+ Returns:
+ DeviceSpec: The `DeviceSpec` object for the specified device.
+
+ Raises:
+ ValueError: If `strict` is True and the device is not found.
+ """
device = preconfigured_devices().get(name, None)
if strict and not device:
current = ', '.join(
- [f"'{key}'" for key, _ in preconfigured_devices().items()])
+            [f"'{key}'" for key, _ in preconfigured_devices().items()])
msg = (
- f"Device not found: {name}."
- "\n\n"
- f"The current list of devices includes the following: {current}."
- "\n"
+ f"Device not found: {name}.\n\n"
+ f"The current list of devices includes the following: {current}.\n"
"You may register new devices using the device module `register` function or in bulk"
" using `load`.")
logger.error(msg)
raise ValueError(msg)
- return device
+ return device # type: ignore
def with_content_type(content_type: str) -> List[DeviceSpec]:
- """Retrieve the list of DeviceSpecs with the given content_type."""
+ """Retrieves a list of `DeviceSpec` objects with the given content type.
+
+ Args:
+ content_type (str): The content type to filter devices by.
+
+ Returns:
+ List[DeviceSpec]: A list of `DeviceSpec` objects matching the specified content type.
+ """
return [
spec for spec in preconfigured_devices().values()
if spec.content_type == content_type
]
-def register(device_spec: DeviceSpec):
- """Register the given DeviceSpec."""
+def register(device_spec: DeviceSpec) -> None:
+ """Registers the given `DeviceSpec`.
+
+ Adds the provided `DeviceSpec` to the collection of preconfigured devices.
+
+ Args:
+ device_spec (DeviceSpec): The `DeviceSpec` object to register.
+ """
config = preconfigured_devices()
config[device_spec.name] = device_spec
diff --git a/bcipy/acquisition/exceptions.py b/bcipy/acquisition/exceptions.py
index 7caaf1f54..de13064e4 100644
--- a/bcipy/acquisition/exceptions.py
+++ b/bcipy/acquisition/exceptions.py
@@ -1,12 +1,19 @@
-
class InvalidClockError(Exception):
- def __init__(self, msg):
- Exception.__init__(self, msg)
+ """Exception raised for invalid clock operations in acquisition."""
+
+ def __init__(self, msg: str):
+ """Initializes the InvalidClockError with a message.
+
+ Args:
+ msg (str): The error message.
+ """
+ super().__init__(msg)
class UnsupportedContentType(Exception):
"""Error that occurs when attempting to collect data from a device with a
- content type that is not yet supported by BciPy."""
+ content type that is not yet supported by BciPy.
+ """
class InsufficientDataException(Exception):
diff --git a/bcipy/acquisition/marker_writer.py b/bcipy/acquisition/marker_writer.py
index fd6216014..d45616de4 100644
--- a/bcipy/acquisition/marker_writer.py
+++ b/bcipy/acquisition/marker_writer.py
@@ -1,6 +1,6 @@
"""Defines classes that can write markers to LabStreamingLayer StreamOutlet."""
import logging
-from typing import Any
+from typing import Any, Optional
import pylsl
@@ -9,47 +9,68 @@
log = logging.getLogger(SESSION_LOG_FILENAME)
-class MarkerWriter():
+class MarkerWriter:
"""Abstract base class for an object that can be used to handle stimulus
markers.
"""
- def push_marker(self, marker: Any):
- """Push the given stimulus marker for processing.
+ def push_marker(self, marker: Any) -> None:
+ """Pushes the given stimulus marker for processing.
- Parameters
- ----------
- - marker : any object that can be converted to a str
+ This method must be implemented by subclasses.
+
+ Args:
+ marker (Any): Any object that can be converted to a string, representing
+ the stimulus marker.
+
+ Raises:
+ NotImplementedError: If the method is not implemented by a subclass.
"""
raise NotImplementedError()
- def cleanup(self):
- """Performs any necessary cleanup"""
+ def cleanup(self) -> None:
+ """Performs any necessary cleanup.
+
+ This method must be implemented by subclasses.
+
+ Raises:
+ NotImplementedError: If the method is not implemented by a subclass.
+ """
raise NotImplementedError()
class LslMarkerWriter(MarkerWriter):
- """Writes stimulus markers to a LabStreamingLayer StreamOutlet
- using pylsl. To consume this data, the client code would need to create a
- pylsl.StreamInlet. See https://github.com/sccn/labstreaminglayer/wiki.
+ """Writes stimulus markers to a LabStreamingLayer StreamOutlet using pylsl.
+
+ To consume this data, client code would typically create a `pylsl.StreamInlet`.
+ See https://github.com/sccn/labstreaminglayer/wiki for more information.
+
+ Args:
+ stream_name (str, optional): The name of the LSL stream. Defaults to
+ "BCI_Stimulus_Markers".
+ stream_id (str, optional): The unique ID of the LSL stream. Defaults to
+ "bci_stim_markers".
"""
def __init__(self,
stream_name: str = "BCI_Stimulus_Markers",
stream_id: str = "bci_stim_markers"):
- super(LslMarkerWriter, self).__init__()
- self.stream_name = stream_name
+ super().__init__()
+ self.stream_name: str = stream_name
markers_info = pylsl.StreamInfo(stream_name, "Markers", 1, 0, 'string',
stream_id)
- self.markers_outlet = pylsl.StreamOutlet(markers_info)
- self.first_marker_stamp: float = None
+ self.markers_outlet: pylsl.StreamOutlet = pylsl.StreamOutlet(markers_info)
+ self.first_marker_stamp: Optional[float] = None
- def push_marker(self, marker: Any):
- """Push the given stimulus marker for processing.
+ def push_marker(self, marker: Any) -> None:
+ """Pushes the given stimulus marker to the LSL stream.
- Parameters
- ----------
- - marker : any object that can be converted to a str
+ The marker is converted to a string and sent along with a local timestamp.
+ The `first_marker_stamp` is set upon the first marker push.
+
+ Args:
+ marker (Any): Any object that can be converted to a string, representing
+ the stimulus marker.
"""
stamp = pylsl.local_clock()
log.info(f'Pushing marker {str(marker)} at {stamp}')
@@ -57,27 +78,28 @@ def push_marker(self, marker: Any):
if not self.first_marker_stamp:
self.first_marker_stamp = stamp
- def cleanup(self):
- """Cleans up the StreamOutlet."""
+ def cleanup(self) -> None:
+ """Cleans up and releases the LSL StreamOutlet."""
del self.markers_outlet
class NullMarkerWriter(MarkerWriter):
"""MarkerWriter which doesn't write anything.
- A NullMarkerWriter can be passed in to the calling object in scenarios
- where marker handling occurs indirectly (ex. through a trigger box). By
- using a NullMarkerWriter rather than a None value, the calling
- object does not have to do additional null checks and a separation
- of concerns is maintained regarding how triggers are written for different
- devices.
+ A `NullMarkerWriter` can be passed to calling objects in scenarios where
+ marker handling occurs indirectly (e.g., through a trigger box). By using
+ a `NullMarkerWriter` instead of `None`, the calling object avoids additional
+ null checks, maintaining a separation of concerns regarding how triggers
+ are written for different devices.
See the Null Object Design Pattern:
https://en.wikipedia.org/wiki/Null_object_pattern
"""
- def push_marker(self, marker: Any):
+ def push_marker(self, marker: Any) -> None:
+ """Overrides the abstract method to do nothing."""
pass
def cleanup(self):
+ """Overrides the abstract method to do nothing."""
pass
diff --git a/bcipy/acquisition/multimodal.py b/bcipy/acquisition/multimodal.py
index dba7a7423..ad4ad2a5e 100644
--- a/bcipy/acquisition/multimodal.py
+++ b/bcipy/acquisition/multimodal.py
@@ -15,13 +15,15 @@
class ContentType(AutoNumberEnum):
- """Enum of supported acquisition device (LSL) content types. Allows for
- case-insensitive matching, as well as synonyms for some types.
+ """Enum of supported acquisition device (LSL) content types.
- >>> ContentType(1) == ContentType.EEG
- True
- >>> ContentType('Eeg') == ContentType.EEG
- True
+ Allows for case-insensitive matching, as well as synonyms for some types.
+
+ Examples:
+ >>> ContentType(1) == ContentType.EEG
+ True
+ >>> ContentType('Eeg') == ContentType.EEG
+ True
"""
def __init__(self, synonyms: List[str]):
@@ -33,34 +35,40 @@ def __init__(self, synonyms: List[str]):
@classmethod
def _missing_(cls, value: Any) -> 'ContentType':
- """Lookup function used when a value is not found."""
+ """Lookup function used when a value is not found.
+
+ This method enables case-insensitive matching and allows for synonyms.
+
+ Args:
+ value (Any): The value to lookup, which will be converted to a string and lowercased.
+
+ Returns:
+ ContentType: The matching ContentType enum member.
+
+ Raises:
+ UnsupportedContentType: If no matching content type is found.
+ """
value = str(value).lower()
- for member in cls:
- if member.name.lower() == value or value in member.synonyms:
+ for member in cls: # type: ignore
+ if member.name.lower() == value or value in member.synonyms: # type: ignore
return member
raise UnsupportedContentType(f"ContentType not supported: {value}")
class ClientManager():
- """Manages multiple acquisition clients. This class can also act as an
- acquisition client. If used in this way, it dispatches to the managed
- client with the default_client_type.
-
- >>> from bcipy.acquisition import LslAcquisitionClient
- >>> from bcipy.acquisition.devices import DeviceSpec
- >>> spec = DeviceSpec('Testing', ['ch1', 'ch2', 'ch3'], 60.0, 'EEG')
- >>> manager = ClientManager()
- >>> eeg_client = LslAcquisitionClient(device_spec=spec)
- >>> manager.add_client(eeg_client)
- >>> manager.device_spec == spec
- True
-
- Parameters
- ----------
- default_content_type - used for dispatching calls to an LslClient.
+ """Manages multiple acquisition clients.
+
+ This class can also act as an acquisition client. If used in this way,
+ it dispatches to the managed client with the `default_client_type`.
+
+ Args:
+ default_content_type (ContentType, optional): The default content type
+ to use for dispatching calls to an `LslClient`. Defaults to `ContentType.EEG`.
"""
- def __init__(self, default_content_type: ContentType = ContentType.EEG) -> None:
+ inlet: Any = None # To satisfy mypy, as ClientManager can act as an LslAcquisitionClient
+
+ def __init__(self, default_content_type: ContentType = ContentType.EEG) -> None: # type: ignore
self._clients: Dict[ContentType, LslAcquisitionClient] = {}
self.default_content_type = default_content_type
@@ -71,27 +79,25 @@ def clients(self) -> List[LslAcquisitionClient]:
@property
def clients_by_type(self) -> Dict[ContentType, LslAcquisitionClient]:
- """Returns a dict of clients keyed by their content type"""
+ """Returns a dictionary of clients keyed by their content type."""
return self._clients
@property
def device_specs(self) -> List[DeviceSpec]:
- """Returns a list of DeviceSpecs for all the clients."""
- return [client.device_spec for client in self.clients]
+ """Returns a list of `DeviceSpec` objects for all the clients."""
+ return [client.device_spec for client in self.clients if client.device_spec] # type: ignore
@property
def device_content_types(self) -> List[ContentType]:
- """Returns a list of ContentTypes provided by the configured devices.
- """
+ """Returns a list of `ContentType` enums provided by the configured devices."""
return list(self._clients.keys())
@property
def active_device_content_types(self) -> List[ContentType]:
- """Returns a list of ContentTypes provided by the active configured
- devices."""
+ """Returns a list of `ContentType` enums provided by the active configured devices."""
return [
content_type for content_type, client in self._clients.items()
- if client.device_spec.is_active
+ if client.device_spec and client.device_spec.is_active # type: ignore
]
@property
@@ -99,24 +105,38 @@ def default_client(self) -> Optional[LslAcquisitionClient]:
"""Returns the default client."""
return self.get_client(self.default_content_type)
- def add_client(self, client: LslAcquisitionClient):
- """Add the given client to the manager."""
- content_type = ContentType(client.device_spec.content_type)
+ def add_client(self, client: LslAcquisitionClient) -> None: # type: ignore
+ """Adds the given client to the manager.
+
+ Args:
+ client (LslAcquisitionClient): The client instance to add.
+ """
+ content_type = ContentType(
+ client.device_spec.content_type) # type: ignore
self._clients[content_type] = client
def get_client(
self, content_type: ContentType) -> Optional[LslAcquisitionClient]:
- """Get client by content type"""
+ """Retrieves a client by its content type.
+
+ Args:
+ content_type (ContentType): The content type of the client to retrieve.
+
+ Returns:
+ Optional[LslAcquisitionClient]: The `LslAcquisitionClient` instance if found,
+ otherwise None.
+ """
return self._clients.get(content_type, None)
- def start_acquisition(self):
- """Start acquiring data for all clients"""
+ def start_acquisition(self) -> None: # type: ignore
+ """Starts data acquisition for all clients."""
for client in self.clients:
+            # NOTE(review): device_spec is assumed non-None for managed clients
logger.info(f"Connecting to {client.device_spec.name}...")
client.start_acquisition()
- def stop_acquisition(self):
- """Stop acquiring data for all clients"""
+ def stop_acquisition(self) -> None: # type: ignore
+ """Stops data acquisition for all clients."""
logger.info("Stopping acquisition...")
for client in self.clients:
client.stop_acquisition()
@@ -128,54 +148,83 @@ def get_data_by_device(
content_types: Optional[List[ContentType]] = None,
strict: bool = True
) -> Dict[ContentType, List[Record]]:
- """Get data for one or more devices. The number of samples for each
- device depends on the sample rate and may be different for item.
-
- Parameters
- ----------
- start - start time (acquisition clock) of data window; NOTE: the
- actual start time will be adjusted to by the static_offset
- configured for each device.
- seconds - duration of data to return for each device
- content_types - specifies which devices to include; if not
- unspecified, data for all types is returned.
- strict - if True, raises an exception if the returned rows is
- less than the requested number of records.
+ """Retrieves data for one or more devices within a specified time window.
+
+ The number of samples for each device depends on the sample rate and may
+ differ. The actual start time will be adjusted by the `static_offset`
+ configured for each device.
+
+ Args:
+ start (Optional[float]): Start time (acquisition clock) of the data window.
+ seconds (Optional[float]): Duration of data to return for each device.
+ content_types (Optional[List[ContentType]], optional): Specifies which
+ devices to include. If None, data for all types is returned.
+ Defaults to None.
+ strict (bool, optional): If True, raises an `InsufficientDataException`
+ if the number of returned records is less than the requested number.
+ Defaults to True.
+
+ Returns:
+ Dict[ContentType, List[Record]]: A dictionary where keys are `ContentType`
+ and values are lists of `Record` objects.
+
+ Raises:
+ InsufficientDataException: If `strict` is True and the returned data count
+ is less than the requested count for a device.
"""
- output = {}
+ output: Dict[ContentType, List[Record]] = {}
if not content_types:
content_types = self.device_content_types
for content_type in content_types:
name = content_type.name
client = self.get_client(content_type)
- adjusted_start = start + client.device_spec.static_offset
- if client.device_spec.sample_rate > 0:
- count = round(seconds * client.device_spec.sample_rate)
- logger.info(f'Need {count} records for processing {name} data')
- output[content_type] = client.get_data(start=adjusted_start,
- limit=count)
- data_count = len(output[content_type])
- if strict and data_count < count:
- msg = f'Needed {count} {name} records but received {data_count}'
- logger.error(msg)
- raise InsufficientDataException(msg)
+ if client and client.device_spec:
+ adjusted_start = start + client.device_spec.static_offset
+ if client.device_spec.sample_rate > 0:
+ count = round(seconds * client.device_spec.sample_rate)
+ logger.info(
+ f'Need {count} records for processing {name} data')
+ output[content_type] = client.get_data(start=adjusted_start,
+ limit=count)
+ data_count = len(output[content_type])
+ if strict and data_count < count:
+ msg = f'Needed {count} {name} records but received {data_count}'
+ logger.error(msg)
+ raise InsufficientDataException(msg)
+ else:
+ # Markers have an IRREGULAR_RATE.
+ logger.info(f'Querying {name} data')
+ output[content_type] = client.get_data(start=adjusted_start,
+ end=adjusted_start + seconds)
+ logger.info(
+ f"Received {len(output[content_type])} records.")
else:
- # Markers have an IRREGULAR_RATE.
- logger.info(f'Querying {name} data')
- output[content_type] = client.get_data(start=adjusted_start,
- end=adjusted_start + seconds)
- logger.info(f"Received {len(output[content_type])} records.")
+ logger.error(
+ f"No client and device spec found for content type: {content_type.name}")
return output
- def cleanup(self):
- """Perform any cleanup tasks"""
+ def cleanup(self) -> None: # type: ignore
+ """Performs any necessary cleanup tasks for all managed clients."""
for client in self.clients:
client.cleanup()
def __getattr__(self, name: str) -> Any:
- """Dispatch unknown properties and methods to the client with the
- default content type."""
+ """Dispatches unknown properties and methods to the client with the
+ default content type.
+
+ This allows `ClientManager` to act as a proxy for the default client.
+
+ Args:
+ name (str): The name of the attribute being accessed.
+
+ Returns:
+ Any: The attribute from the default client.
+
+ Raises:
+ AttributeError: If the default client is not set or the attribute
+ does not exist on the default client.
+ """
client = self.default_client
if client:
return client.__getattribute__(name)
diff --git a/bcipy/acquisition/protocols/lsl/connect.py b/bcipy/acquisition/protocols/lsl/connect.py
index 96e49ba35..097a76fc0 100644
--- a/bcipy/acquisition/protocols/lsl/connect.py
+++ b/bcipy/acquisition/protocols/lsl/connect.py
@@ -9,7 +9,24 @@
def resolve_device_stream(
device_spec: Optional[DeviceSpec] = None) -> StreamInfo:
- """Get the LSL stream for the given device."""
+ """Resolves and returns the LSL stream for the given device.
+
+ This function searches for an LSL stream based on the `content_type` of the
+ provided `DeviceSpec`. If no `DeviceSpec` is provided, it defaults to
+ `DEFAULT_DEVICE_TYPE`.
+
+ Args:
+ device_spec (Optional[DeviceSpec], optional): The DeviceSpec object
+ containing the content type
+ of the stream to resolve.
+ Defaults to None.
+
+ Returns:
+ StreamInfo: The resolved LSL `StreamInfo` object.
+
+ Raises:
+ Exception: If an LSL stream with the specified content type is not found.
+ """
content_type = device_spec.content_type if device_spec else DEFAULT_DEVICE_TYPE
streams = resolve_stream('type', content_type)
if not streams:
@@ -19,7 +36,14 @@ def resolve_device_stream(
def device_from_metadata(metadata: StreamInfo) -> DeviceSpec:
- """Create a device_spec from the data stream metadata."""
+ """Creates a `DeviceSpec` object from LSL stream metadata.
+
+ Args:
+ metadata (StreamInfo): The LSL `StreamInfo` object containing device metadata.
+
+ Returns:
+ DeviceSpec: A `DeviceSpec` object populated with information from the metadata.
+ """
return DeviceSpec(name=metadata.name(),
channels=channel_names(metadata),
sample_rate=metadata.nominal_srate(),
diff --git a/bcipy/acquisition/protocols/lsl/lsl_client.py b/bcipy/acquisition/protocols/lsl/lsl_client.py
index 4732b7a12..d0a14de81 100644
--- a/bcipy/acquisition/protocols/lsl/lsl_client.py
+++ b/bcipy/acquisition/protocols/lsl/lsl_client.py
@@ -1,7 +1,7 @@
"""DataAcquisitionClient for LabStreamingLayer data sources."""
import logging
from multiprocessing import Queue
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
import pandas as pd
from pylsl import StreamInlet, local_clock, resolve_byprop
@@ -25,7 +25,19 @@
def time_range(stamps: List[float],
precision: int = 3,
sep: str = " to ") -> str:
- """Utility for printing a range of timestamps"""
+ """Utility for formatting a range of timestamps into a string.
+
+ Args:
+ stamps (List[float]): A list of timestamps.
+ precision (int, optional): The number of decimal places for rounding timestamps.
+ Defaults to 3.
+ sep (str, optional): The separator string between the start and end timestamps.
+ Defaults to " to ".
+
+ Returns:
+ str: A string representing the range of timestamps (e.g., "1.234 to 5.678"),
+ or an empty string if `stamps` is empty.
+ """
if stamps:
return "".join([
str(round(stamps[0], precision)), sep,
@@ -35,8 +47,17 @@ def time_range(stamps: List[float],
def request_desc(start: Optional[float], end: Optional[float],
- limit: Optional[int]):
- """Returns a description of the request which can be logged."""
+ limit: Optional[int]) -> str:
+ """Returns a descriptive string of a data request for logging purposes.
+
+ Args:
+ start (Optional[float]): The starting timestamp of the request.
+ end (Optional[float]): The ending timestamp of the request.
+ limit (Optional[int]): The maximum number of records requested.
+
+ Returns:
+ str: A formatted string describing the data request.
+ """
start_str = round(start, 3) if start else "None"
end_str = round(end, 3) if end else "None"
return f"Requesting data from: {start_str} to: {end_str} limit: {limit}"
@@ -44,28 +65,35 @@ def request_desc(start: Optional[float], end: Optional[float],
class LslAcquisitionClient:
"""Data Acquisition Client for devices streaming data using Lab Streaming
- Layer. Its primary use is dynamically querying streaming data in realtime,
- however, if the save_directory and filename parameters are provided it uses
- a LslRecordingThread to persist the data.
-
- Parameters
- ----------
- max_buffer_len: the maximum length of data to be queried. For continuously
- streaming data this is the number of seconds of data to retain. For
- irregular data, specify the number of samples. When using the RSVP
- paradigm, the max_buffer_len should be large enough to store data for
- the entire inquiry.
- device_spec: spec for the device from which to query data; if missing,
- this class will attempt to find the first EEG stream.
- save_directory: if present, persists the data to the given location.
- raw_data_file_name: if present, uses this name for the data file.
+ Layer.
+
+ Its primary use is dynamically querying streaming data in realtime.
+ If `save_directory` and `raw_data_file_name` parameters are provided, it uses a
+ `LslRecordingThread` to persist the data.
+
+ Args:
+ max_buffer_len (float, optional): The maximum length of data to be queried.
+ For continuously streaming data, this is the
+ number of seconds of data to retain. For
+ irregular data, it specifies the number of samples.
+ When using the RSVP paradigm, `max_buffer_len`
+ should be large enough to store data for the
+ entire inquiry. Defaults to 1.
+ device_spec (Optional[DeviceSpec], optional): The `DeviceSpec` for the device
+ from which to query data.
+ If missing, this class will attempt
+ to find the first EEG stream. Defaults to None.
+ save_directory (Optional[str], optional): If present, persists the data to
+ the given location. Defaults to None.
+ raw_data_file_name (Optional[str], optional): If present, uses this name
+ for the data file. Defaults to None.
"""
- inlet: StreamInlet = None
- recorder: LslRecordingThread = None
- buffer: RingBuffer = None
- _first_sample_time: float = None
- experiment_clock: Clock = None
+ inlet: Optional[StreamInlet] = None
+ recorder: Optional[LslRecordingThread] = None
+ buffer: Optional[RingBuffer] = None
+ _first_sample_time: Optional[float] = None
+ experiment_clock: Optional[Clock] = None
def __init__(self,
max_buffer_len: float = 1,
@@ -74,42 +102,59 @@ def __init__(self,
raw_data_file_name: Optional[str] = None):
super().__init__()
- self.device_spec = device_spec
- self.max_buffer_len = max_buffer_len
- self.save_directory = save_directory
- self.raw_data_file_name = raw_data_file_name
- self._max_samples = None
+ self.device_spec: Optional[DeviceSpec] = device_spec
+ self.max_buffer_len: float = max_buffer_len
+ self.save_directory: Optional[str] = save_directory
+ self.raw_data_file_name: Optional[str] = raw_data_file_name
+ self._max_samples: Optional[int] = None
@property
def has_irregular_rate(self) -> bool:
- """Returns true for sampling devices with an irregular rate,
- such as markers.
+ """Checks if the device has an irregular sampling rate.
+
+ Returns:
+ bool: True if the device's sample rate is `IRREGULAR_RATE`,
+ False otherwise.
"""
- return self.device_spec.sample_rate == IRREGULAR_RATE
+ return self.device_spec.sample_rate == IRREGULAR_RATE # type: ignore
@property
- def first_sample_time(self) -> float:
- """Timestamp returned by the first sample. If the data is being
- recorded this value reflects the timestamp of the first recorded sample"""
+ def first_sample_time(self) -> Optional[float]:
+ """Returns the timestamp of the first sample.
+
+ If data is being recorded, this value reflects the timestamp of the first
+ recorded sample.
+
+ Returns:
+ Optional[float]: The timestamp of the first sample, or None if not set.
+ """
return self._first_sample_time
@property
def max_samples(self) -> int:
- """Maximum number of samples available at any given time."""
+ """Calculates the maximum number of samples available at any given time.
+
+ This depends on `max_buffer_len` and the device's sample rate.
+
+ Returns:
+ int: The maximum number of samples.
+ """
if self._max_samples is None:
if self.has_irregular_rate:
self._max_samples = int(self.max_buffer_len)
else:
self._max_samples = int(self.max_buffer_len *
- self.device_spec.sample_rate)
+ self.device_spec.sample_rate) # type: ignore
return self._max_samples
def start_acquisition(self) -> bool:
- """Connect to the datasource and start acquiring data.
+ """Connects to the data source and begins acquiring data.
+
+ Initializes the LSL `StreamInlet` and optionally an `LslRecordingThread`
+ if a `save_directory` is provided.
- Returns
- -------
- bool : False if acquisition is already in progress, otherwise True.
+ Returns:
+ bool: False if acquisition is already in progress, otherwise True.
"""
if self.inlet:
return False
@@ -128,7 +173,7 @@ def start_acquisition(self) -> bool:
self.device_spec = device_from_metadata(self.inlet.info())
if self.save_directory:
- msg_queue = Queue()
+ msg_queue: Queue[float] = Queue()
self.recorder = LslRecordingThread(
directory=self.save_directory,
filename=self.raw_data_file_name,
@@ -151,9 +196,15 @@ def start_acquisition(self) -> bool:
return True
def stop_acquisition(self) -> None:
- """Disconnect from the data source."""
- logger.info(f"Stopping Acquisition from {self.device_spec.name} ...")
+ """Disconnects from the data source and cleans up resources.
+
+ Stops the `LslRecordingThread` if active, closes the LSL `StreamInlet`,
+ and clears the internal buffer.
+ """
+ logger.info(
+ f"Stopping Acquisition from {self.device_spec.name} ...") # type: ignore
if self.recorder:
+ # type: ignore
logger.info(f"Closing {self.device_spec.name} data recorder")
self.recorder.stop()
self.recorder.join()
@@ -165,24 +216,44 @@ def stop_acquisition(self) -> None:
self.buffer = None
- def __enter__(self):
- """Context manager enter method that starts data acquisition."""
+ def __enter__(self) -> 'LslAcquisitionClient':
+ """Context manager enter method that starts data acquisition.
+
+ Returns:
+ LslAcquisitionClient: The instance of the acquisition client.
+ """
self.start_acquisition()
return self
- def __exit__(self, _exc_type, _exc_value, _traceback):
- """Context manager exit method to clean up resources."""
+ def __exit__(self, _exc_type: Any, _exc_value: Any, _traceback: Any) -> None:
+ """Context manager exit method to clean up resources.
+
+ Args:
+ _exc_type (Any): The exception type, if an exception was raised.
+ _exc_value (Any): The exception value, if an exception was raised.
+ _traceback (Any): The traceback, if an exception was raised.
+ """
self.stop_acquisition()
def _data_stats(self, data: List[Record]) -> Dict[str, float]:
- """Summarize a list of records for logging and inspection."""
+ """Summarizes a list of records for logging and inspection.
+
+ Args:
+ data (List[Record]): A list of `Record` objects.
+
+ Returns:
+ Dict[str, float]: A dictionary containing statistics such as count,
+ total seconds, start/end timestamps, expected difference,
+ mean difference, and max difference between samples.
+ Returns an empty dict if `data` is empty.
+ """
if data:
diffs = pd.DataFrame(data)['timestamp'].diff()
data_start = data[0].timestamp
data_end = data[-1].timestamp
precision = 3
expected_diff = 0.0 if self.has_irregular_rate else round(
- 1 / self.device_spec.sample_rate, precision)
+ 1 / self.device_spec.sample_rate, precision) # type: ignore
return {
'count': len(data),
'seconds': round(data_end - data_start, precision),
@@ -198,20 +269,28 @@ def get_data(self,
start: Optional[float] = None,
end: Optional[float] = None,
limit: Optional[int] = None) -> List[Record]:
- """Get data in time range.
+ """Retrieves data within a specified time range from the current buffer.
- Only data in the current buffer is available to query;
- requests for data outside of this will fail.
+ Only data currently in the buffer is available for querying.
+ Requests for data outside of this range will fail.
- Parameters
- ----------
- start : starting timestamp (acquisition clock).
- end : end timestamp (in acquisition clock).
- limit: the max number of records that should be returned.
+ Args:
+ start (Optional[float]): The starting timestamp (in acquisition clock).
+ Defaults to None, which means the beginning
+ of available data.
+ end (Optional[float]): The end timestamp (in acquisition clock).
+ Defaults to None, which means the end of
+ available data.
+ limit (Optional[int]): The maximum number of records to return.
+ Defaults to None, which means no limit.
- Returns
- -------
- List of Records
+ Returns:
+ List[Record]: A list of `Record` objects within the specified range.
+ Returns an empty list if no records are available.
+
+ Raises:
+ AssertionError: If `start` or `end` times are out of the available data range
+ for regular rate devices.
"""
logger.info(request_desc(start, end, limit))
@@ -230,21 +309,26 @@ def get_data(self,
end = data_end
if not self.has_irregular_rate:
- assert start >= data_start, 'Start time out of range'
- assert end <= data_end, 'End time out of range'
+ assert start is not None and start >= data_start, 'Start time out of range'
+ assert end is not None and end <= data_end, 'End time out of range'
data_slice = [
- record for record in data if start <= record.timestamp <= end
+ record for record in data if start <= record.timestamp <= end # type: ignore
][0:limit]
logger.info(f"Filtered records: {self._data_stats(data_slice)}")
return data_slice
def get_latest_data(self) -> List[Record]:
- """Add all available samples in the inlet to the buffer.
+ """Adds all currently available samples from the inlet to the buffer.
The number of items returned depends on the size of the configured
- max_buffer_len and the amount of data available in the inlet."""
+ `max_buffer_len` and the amount of data available in the inlet.
+
+ Returns:
+ List[Record]: A list of the latest `Record` objects from the buffer.
+ Returns an empty list if no buffer is initialized.
+ """
if not self.buffer:
return []
@@ -256,76 +340,101 @@ def get_latest_data(self) -> List[Record]:
return self.buffer.get()
def _pull_chunk(self) -> int:
- """Pull a chunk of samples from LSL and record in the buffer.
- Returns the count of samples pulled.
+ """Pulls a chunk of samples from LSL and records them in the buffer.
+
+ Returns:
+ int: The count of samples pulled in this operation.
"""
logger.debug(f"\tPulling chunk (max_samples: {self.max_samples})")
# A timeout of 0.0 gets currently available samples without blocking.
- samples, timestamps = self.inlet.pull_chunk(
+ samples, timestamps = self.inlet.pull_chunk( # type: ignore
timeout=0.0, max_samples=self.max_samples)
count = len(samples)
- logger.debug(f"\t-> received {count} samples: {time_range(timestamps)}")
+ logger.debug(
+ f"\t-> received {count} samples: {time_range(timestamps)}")
for sample, stamp in zip(samples, timestamps):
- self.buffer.append(Record(sample, stamp))
+ self.buffer.append(Record(sample, stamp)) # type: ignore
return count
def convert_time(self, experiment_clock: Clock, timestamp: float) -> float:
- """
- Convert a timestamp from the experiment clock to the acquisition clock.
+ """Converts a timestamp from the experiment clock to the acquisition clock.
+
Used for querying the acquisition data for a time slice.
- Parameters:
- ----------
- - experiment_clock : clock used to generate the timestamp
- - timestamp : timestamp from the experiment clock
+ Args:
+ experiment_clock (Clock): The clock used to generate the timestamp.
+ timestamp (float): The timestamp from the experiment clock.
Returns:
- --------
- corresponding timestamp for the acquistion clock
+ float: The corresponding timestamp for the acquisition clock.
"""
# experiment_time = pylsl.local_clock() - offset
return timestamp + self.clock_offset(experiment_clock)
def get_data_seconds(self, seconds: int) -> List[Record]:
- """Returns the last n second of data"""
+ """Returns the last 'n' seconds of data available in the buffer.
+
+ Args:
+ seconds (int): The number of seconds of data to retrieve.
+
+ Returns:
+ List[Record]: A list of `Record` objects covering the last `seconds`.
+
+ Raises:
+ AssertionError: If `seconds` exceeds `max_buffer_len`.
+ """
assert seconds <= self.max_buffer_len, f"Seconds can't exceed {self.max_buffer_len}"
- sample_count = seconds * self.device_spec.sample_rate
+ sample_count = seconds * self.device_spec.sample_rate # type: ignore
records = self.get_latest_data()
start_index = 0 if len(
records) > sample_count else len(records) - sample_count
- return records[start_index:]
+ return records[int(start_index):]
@property
- def is_calibrated(self):
- """Returns boolean indicating whether or not acquisition has been
- calibrated (an offset calculated based on a trigger)."""
+ def is_calibrated(self) -> bool:
+ """Checks whether acquisition has been calibrated (an offset calculated based on a trigger).
+
+ Returns:
+ bool: True, as this property is currently hardcoded to return True.
+ """
return True
@is_calibrated.setter
- def is_calibrated(self, bool_val):
- """Setter for the is_calibrated property that allows the user to
- override the calculated value and use a 0 offset.
-
- Parameters
- ----------
- - bool_val : boolean
- if True, uses a 0 offset; if False forces the calculation.
+ def is_calibrated(self, bool_val: bool) -> None:
+ """Setter for the `is_calibrated` property.
+
+ Allows the user to override the calculated value and use a 0 offset.
+
+ Args:
+ bool_val (bool): If True, forces a 0 offset; if False, forces the calculation.
+ Note: Current implementation always returns True for getter.
"""
+ # This setter currently has no effect on the getter's return value.
+ pass
def clock_offset(self, experiment_clock: Optional[Clock] = None) -> float:
- """
- Offset in seconds from the experiment clock to the acquisition local clock.
+ """Calculates the offset in seconds from the experiment clock to the acquisition local clock.
+
+ The experiment clock should be monotonic from experiment start time. The
+ acquisition clock (`pylsl.local_clock()`) is monotonic from local machine
+ start time (or since 1970-01-01 00:00). Therefore the acquisition clock
+ should always be greater than experiment clock.
+
+ Args:
+ experiment_clock (Optional[Clock], optional): The experiment clock object.
+ Defaults to None, in which
+ case `self.experiment_clock` is used.
- The experiment clock should be monotonic from experiment start time.
- The acquisition clock (pylsl.local_clock()) is monotonic from local
- machine start time (or since 1970-01-01 00:00). Therefore the acquisition
- clock should always be greater than experiment clock. An exception is
- raised if this doesn't hold.
+ Returns:
+ float: The offset in seconds.
- See https://labstreaminglayer.readthedocs.io/info/faqs.html#lsl-local-clock
+ Raises:
+ AssertionError: If an experiment clock is not provided or available.
+ InvalidClockError: If the acquisition clock is not greater than the
+ experiment clock.
"""
clock = experiment_clock or self.experiment_clock
assert clock, "An experiment clock must be provided"
@@ -338,17 +447,14 @@ def clock_offset(self, experiment_clock: Optional[Clock] = None) -> float:
return diff
def event_offset(self, event_clock: Clock, event_time: float) -> float:
- """Compute number of seconds that recording started prior to the given
- event.
+ """Computes the number of seconds that recording started prior to the given event.
- Parameters
- ----------
- - event_clock : monotonic clock used to record the event time.
- - event_time : timestamp of the event of interest.
+ Args:
+ event_clock (Clock): Monotonic clock used to record the event time.
+ event_time (float): Timestamp of the event of interest.
- Returns
- -------
- Seconds between acquisition start and the event.
+ Returns:
+ float: Seconds between acquisition start and the event, or 0.0 if `first_sample_time` is not set.
"""
if self.first_sample_time:
lsl_event_time = self.convert_time(event_clock, event_time)
@@ -356,19 +462,18 @@ def event_offset(self, event_clock: Clock, event_time: float) -> float:
return 0.0
def offset(self, first_stim_time: float) -> float:
- """Offset in seconds from the start of acquisition to the given stim
- time.
+ """Calculates the offset in seconds from the start of acquisition to the given stimulus time.
- Parameters
- ----------
- - first_stim_time : LSL local clock timestamp of the first stimulus.
+ Args:
+ first_stim_time (float): LSL local clock timestamp of the first stimulus.
- Returns
- -------
- The number of seconds between acquisition start and the calibration
- event, or 0.0 .
- """
+ Returns:
+ float: The number of seconds between acquisition start and the calibration
+ event, or 0.0 if `first_stim_time` is zero or `has_irregular_rate` is True.
+ Raises:
+ AssertionError: If `first_sample_time` is not set and `has_irregular_rate` is False.
+ """
if not first_stim_time or self.has_irregular_rate:
return 0.0
assert self.first_sample_time, "Acquisition was not started."
@@ -376,13 +481,27 @@ def offset(self, first_stim_time: float) -> float:
logger.info(f"Acquisition offset: {offset_from_stim}")
return offset_from_stim
- def cleanup(self):
- """Perform any necessary cleanup."""
+ def cleanup(self) -> None:
+ """Performs any necessary cleanup tasks.
+
+ Currently, this method is a placeholder and does not perform any specific actions.
+ """
+ pass
def discover_device_spec(content_type: str) -> DeviceSpec:
"""Finds the first LSL stream with the given content type and creates a
- device spec from the stream's metadata."""
+ device spec from the stream's metadata.
+
+ Args:
+ content_type (str): The content type of the LSL stream to discover (e.g., "EEG", "Markers").
+
+ Returns:
+ DeviceSpec: A `DeviceSpec` object created from the metadata of the discovered stream.
+
+ Raises:
+ Exception: If an LSL stream with the specified content type is not found within `LSL_TIMEOUT`.
+ """
logger.info(f"Waiting for {content_type} data to be streamed over LSL.")
streams = resolve_byprop('type', content_type, timeout=LSL_TIMEOUT)
if not streams:
diff --git a/bcipy/acquisition/protocols/lsl/lsl_connector.py b/bcipy/acquisition/protocols/lsl/lsl_connector.py
index 31363db38..4c2002e8a 100644
--- a/bcipy/acquisition/protocols/lsl/lsl_connector.py
+++ b/bcipy/acquisition/protocols/lsl/lsl_connector.py
@@ -1,8 +1,7 @@
-# pylint: disable=fixme
-"""Defines the driver for the Device for communicating with
-LabStreamingLayer (LSL)."""
+"""Defines the driver for the Device for communicating with LabStreamingLayer (LSL)."""
+
import logging
-from typing import Dict, List
+from typing import Any, Dict, List, Optional, Tuple
import pylsl
@@ -15,43 +14,84 @@
LSL_TIMEOUT_SECONDS = 5.0
-class Marker():
- """Data class which wraps a LSL marker; data pulled from a marker stream is
- a tuple where the first item is a list of channels and second item is the
- timestamp. Assumes that marker inlet only has a single channel."""
+class Marker:
+ """Data class which wraps an LSL marker.
+
+ Data pulled from a marker stream is a tuple where the first item is a list
+ of channels (typically one) and the second item is the timestamp.
+ Assumes that the marker inlet only has a single channel.
- def __init__(self, data=(None, None)):
- super(Marker, self).__init__()
- self.channels, self.timestamp = data
+ Args:
+ data (Tuple[Optional[List[Any]], Optional[float]], optional): A tuple
+ containing the channel data (list) and its timestamp (float).
+ Defaults to (None, None).
+ """
+
+ def __init__(self, data: Tuple[Optional[List[Any]], Optional[float]] = (None, None)):
+ super().__init__()
+ self.channels: Optional[List[Any]] = data[0]
+ self.timestamp: Optional[float] = data[1]
@classmethod
- def empty(cls):
- """Creates an empty Marker."""
+ def empty(cls) -> 'Marker':
+ """Creates an empty Marker instance.
+
+ Returns:
+ Marker: An empty Marker object.
+ """
return Marker()
- def __repr__(self):
+ def __repr__(self) -> str:
+ """Returns a string representation of the Marker object."""
return f"<Marker channels: {self.channels} timestamp: {self.timestamp}>"
@property
- def is_empty(self):
- """Test to see if the current marker is empty."""
+ def is_empty(self) -> bool:
+ """Checks if the current marker is empty.
+
+ Returns:
+ bool: True if either channels or timestamp is None, False otherwise.
+ """
return self.channels is None or self.timestamp is None
@property
- def trg(self):
- """Get the trigger."""
+ def trg(self) -> Optional[Any]:
+ """Gets the trigger value from the marker's channels.
+
+ Assumes the trigger is the first element in the channels list.
+
+ Returns:
+ Optional[Any]: The trigger value, or None if channels is empty or None.
+ """
# pylint: disable=unsubscriptable-object
return self.channels[0] if self.channels else None
-def inlet_name(inlet) -> str:
- """Returns the name of a pylsl streamInlet."""
+def inlet_name(inlet: pylsl.StreamInlet) -> str:
+ """Returns a sanitized name of a pylsl `StreamInlet`.
+
+ Converts the stream info name by replacing spaces and hyphens with underscores.
+
+ Args:
+ inlet (pylsl.StreamInlet): The LSL StreamInlet object.
+
+ Returns:
+ str: The sanitized name of the inlet.
+ """
name = '_'.join(inlet.info().name().split())
return name.replace('-', '_')
def channel_names(stream_info: pylsl.StreamInfo) -> List[str]:
- """Extracts the channel names from the LSL Stream metadata."""
+ """Extracts the channel names from the LSL Stream metadata.
+
+ Args:
+ stream_info (pylsl.StreamInfo): The LSL `StreamInfo` object.
+
+ Returns:
+ List[str]: A list of channel names. If the stream type is 'Markers',
+ it returns `['Marker']`.
+ """
channels: List[str] = []
if stream_info.type() == 'Markers':
return ['Marker']
@@ -67,9 +107,19 @@ def channel_names(stream_info: pylsl.StreamInfo) -> List[str]:
return channels
-def check_device(device_spec: DeviceSpec, metadata: pylsl.StreamInfo):
- """Confirm that the properties of the given device_spec match the metadata
- acquired from the device."""
+def check_device(device_spec: DeviceSpec, metadata: pylsl.StreamInfo) -> None:
+ """Confirms that the properties of the given `DeviceSpec` match the metadata
+ acquired from the LSL stream.
+
+ Args:
+ device_spec (DeviceSpec): The expected `DeviceSpec` for the device.
+ metadata (pylsl.StreamInfo): The LSL `StreamInfo` object containing the
+ actual device metadata.
+
+ Raises:
+ Exception: If channel names, channel count, or sample rate do not match
+ between `device_spec` and `metadata`.
+ """
channels = channel_names(metadata)
# Confirm that provided channels match metadata, or meta is empty.
if channels and device_spec.channel_names != channels:
@@ -77,8 +127,7 @@ def check_device(device_spec: DeviceSpec, metadata: pylsl.StreamInfo):
print(device_spec.channel_names)
raise Exception("Channels read from the device do not match "
"the provided parameters.")
- assert device_spec.channel_count == metadata.channel_count(
- ), "Channel count error"
+ assert device_spec.channel_count == metadata.channel_count(), "Channel count error"
if device_spec.sample_rate != metadata.nominal_srate():
raise Exception("Sample frequency read from device does not match "
@@ -86,11 +135,14 @@ def check_device(device_spec: DeviceSpec, metadata: pylsl.StreamInfo):
def rename_items(items: List[str], rules: Dict[str, str]) -> None:
- """Renames items based on the provided rules.
- Parameters
- ----------
- items - list of items ; values will be mutated
- rules - change key -> value
+ """Renames items in a list based on a provided mapping of rules.
+
+ The list of items is modified in place.
+
+ Args:
+ items (List[str]): A list of strings whose values may be mutated.
+ rules (Dict[str, str]): A dictionary where keys are original item names
+ and values are their new names.
"""
for key, val in rules.items():
if key in items:
diff --git a/bcipy/acquisition/protocols/lsl/lsl_recorder.py b/bcipy/acquisition/protocols/lsl/lsl_recorder.py
index d80d2bade..f4e633f77 100644
--- a/bcipy/acquisition/protocols/lsl/lsl_recorder.py
+++ b/bcipy/acquisition/protocols/lsl/lsl_recorder.py
@@ -3,7 +3,7 @@
import time
from multiprocessing import Queue
from pathlib import Path
-from typing import List, Optional
+from typing import Any, List, Optional
from pylsl import StreamInfo, StreamInlet, resolve_streams
@@ -23,22 +23,30 @@
class LslRecorder:
"""Records LSL data to a datastore. Resolves streams when started.
- Parameters:
- -----------
- - path : location to store the recordings
- - filenames : optional dict mapping device type to its raw data filename.
- Devices without an entry will use a naming convention.
+ Args:
+ path (str): Location to store the recordings.
+ filenames (Optional[dict], optional): Optional dictionary mapping device
+ type to its raw data filename.
+ Devices without an entry will use a
+ naming convention. Defaults to None.
"""
- streams: List['LslRecordingThread'] = None
+ streams: Optional[List['LslRecordingThread']] = None
def __init__(self, path: str, filenames: Optional[dict] = None) -> None:
super().__init__()
- self.path = path
- self.filenames = filenames or {}
+ self.path: str = path
+ self.filenames: dict = filenames or {}
def start(self) -> None:
- """Start recording all streams currently on the network."""
+ """Starts recording all LSL streams currently on the network.
+
+ This method creates an `LslRecordingThread` for each discovered stream
+ and starts them. It also validates that stream names are unique.
+
+ Raises:
+ Exception: If data stream names are not unique.
+ """
if not self.streams:
log.info("Recording data")
@@ -58,32 +66,40 @@ def start(self) -> None:
stream.start()
def stop(self, wait: bool = False) -> None:
- """Stop recording.
+ """Stops recording for all active streams.
- Parameters
- ----------
- - wait : if True waits for all threads to stop before returning.
+ Args:
+ wait (bool, optional): If True, waits for all recording threads
+ to stop before returning. Defaults to False.
"""
- for stream in self.streams:
- stream.stop()
- if wait:
- stream.join()
- self.streams = None
+ if self.streams:
+ for stream in self.streams:
+ stream.stop()
+ if wait:
+ stream.join()
+ self.streams = None
class LslRecordingThread(StoppableProcess):
"""Records data for the given LabStreamingLayer (LSL) data stream.
- Parameters:
- ----------
- - device_spec : DeviceSpec ; specifies the device from which to record.
- - directory : location to store the recording
- - filename : optional, name of the data file.
- - queue : optional multiprocessing queue; if provided the first_sample_time
- will be written here when available.
+ This class extends `StoppableProcess` to run recording in a separate process.
+
+ Args:
+ device_spec (DeviceSpec): Specifies the device from which to record.
+ directory (Optional[str], optional): Location to store the recording.
+ Defaults to '.'.
+ filename (Optional[str], optional): Optional name of the data file.
+ If None, a default filename based
+ on device properties will be used.
+ Defaults to None.
+ queue (Optional[Queue], optional): Optional multiprocessing queue.
+ If provided, the `first_sample_time`
+ will be written to this queue when available.
+ Defaults to None.
"""
- writer: RawDataWriter = None
+ writer: Optional[RawDataWriter] = None
def __init__(self,
device_spec: DeviceSpec,
@@ -92,30 +108,41 @@ def __init__(self,
queue: Optional[Queue] = None) -> None:
super().__init__()
- self.directory = directory
- self.device_spec = device_spec
- self.queue = queue
+ self.directory: Optional[str] = directory
+ self.device_spec: DeviceSpec = device_spec
+ self.queue: Optional[Queue] = queue
- self.sample_count = 0
+ self.sample_count: int = 0
# see: https://labstreaminglayer.readthedocs.io/info/faqs.html#chunk-sizes
- self.max_chunk_size = 1024
+ self.max_chunk_size: int = 1024
# seconds to sleep between data pulls from LSL
- self.sleep_seconds = 0.2
+ self.sleep_seconds: float = 0.2
- self.filename = filename if filename else self.default_filename()
- self.first_sample_time = None
- self.last_sample_time = None
+ self.filename: str = filename if filename else self.default_filename()
+ self.first_sample_time: Optional[float] = None
+ self.last_sample_time: Optional[float] = None
+
+ def default_filename(self) -> str:
+ """Generates a default filename to use if a name is not provided.
+
+ The filename is based on the device's content type and name.
- def default_filename(self):
- """Default filename to use if a name is not provided."""
+ Returns:
+ str: The generated default filename (e.g., "eeg_data_dsi-24.csv").
+ """
content_type = '_'.join(self.device_spec.content_type.split()).lower()
name = '_'.join(self.device_spec.name.split()).lower()
return f"{content_type}_data_{name}.csv"
@property
def recorded_seconds(self) -> float:
- """Total seconds of data recorded."""
+ """Calculates the total seconds of data recorded.
+
+ Returns:
+ float: The duration of recorded data in seconds, or 0.0 if recording
+ hasn't started or completed.
+ """
if self.first_sample_time and self.last_sample_time:
return self.last_sample_time - self.first_sample_time
return 0.0
@@ -123,9 +150,11 @@ def recorded_seconds(self) -> float:
def _init_data_writer(self, stream_info: StreamInfo) -> None:
"""Initializes the raw data writer.
- Parameters:
- ----------
- - metadata : metadata about the data stream.
+ Args:
+ stream_info (StreamInfo): Metadata about the data stream.
+
+ Raises:
+ AssertionError: If the data writer has already been initialized.
"""
assert self.writer is None, "Data store has already been initialized."
@@ -135,7 +164,7 @@ def _init_data_writer(self, stream_info: StreamInfo) -> None:
check_device(self.device_spec, stream_info)
channels = self.device_spec.channels
- path = str(Path(self.directory, self.filename))
+ path = str(Path(self.directory, self.filename)) # type: ignore
log.info(f"Writing data to {path}")
self.writer = RawDataWriter(
path,
@@ -145,37 +174,38 @@ def _init_data_writer(self, stream_info: StreamInfo) -> None:
self.writer.__enter__()
def _cleanup(self) -> None:
- """Performs cleanup tasks."""
+ """Performs cleanup tasks for the data writer.
+
+ Closes the `RawDataWriter` if it was initialized.
+ """
if self.writer:
self.writer.__exit__()
self.writer = None
- def _write_chunk(self, data: List, timestamps: List) -> None:
- """Persists the data resulting from pulling a chunk from the inlet.
+ def _write_chunk(self, data: List[List[Any]], timestamps: List[float]) -> None:
+ """Persists a chunk of data pulled from the LSL inlet.
- Parameters
- ----------
- data : list of samples
- timestamps : list of timestamps
+ Args:
+ data (List[List[Any]]): A list of samples, where each sample is a list of channel values.
+ timestamps (List[float]): A list of timestamps corresponding to each sample.
"""
assert self.writer, "Writer not initialized"
- chunk = []
+ chunk: List[List[Any]] = []
for i, sample in enumerate(data):
self.sample_count += 1
chunk.append([self.sample_count] + sample + [timestamps[i]])
self.writer.writerows(chunk)
def _pull_chunk(self, inlet: StreamInlet) -> int:
- """Pull a chunk of data and persist. Updates first_sample_time,
- last_sample_time, and sample_count.
+ """Pulls a chunk of data from the `StreamInlet` and persists it.
+
+ Updates `first_sample_time`, `last_sample_time`, and `sample_count`.
- Parameters
- ----------
- inlet : stream inlet from which to pull
+ Args:
+ inlet (StreamInlet): The LSL `StreamInlet` from which to pull data.
- Returns
- -------
- number of samples pulled
+ Returns:
+ int: The number of samples pulled in this operation.
"""
# A timeout of 0.0 does not block and only gets samples immediately
# available.
@@ -191,14 +221,20 @@ def _pull_chunk(self, inlet: StreamInlet) -> int:
return len(timestamps)
def _reset(self) -> None:
- """Reset state"""
+ """Resets the internal state of the recorder.
+
+ This includes resetting the sample count and clearing the first and last
+ sample timestamps.
+ """
self.sample_count = 0
self.first_sample_time = None
self.last_sample_time = None
# @override
- def run(self):
- """Process startup. Connects to the device, reads chunks of data at the
+ def run(self) -> None:
+ """Process startup and main recording loop.
+
+ Connects to the device, continuously reads chunks of data at the
given interval, and persists the results. This happens continuously
until the `stop()` method is called.
"""
@@ -236,9 +272,19 @@ def run(self):
self._cleanup()
-def main(path: str, seconds: int = 5, debug: bool = False):
- """Function to demo the LslRecorder. Expects LSL data streams to be already
- running."""
+def main(path: str, seconds: int = 5, debug: bool = False) -> None:
+ """Demonstrates the `LslRecorder` functionality.
+
+ This function initializes an `LslRecorder` and records data for a specified
+ duration. It expects LSL data streams to be already running.
+
+ Args:
+ path (str): The directory path to save the recorded data.
+ seconds (int, optional): The duration in seconds to record data.
+ Defaults to 5.
+ debug (bool, optional): If True, enables logging to stdout for debugging.
+ Defaults to False.
+ """
if debug:
log_to_stdout()
recorder = LslRecorder(path)
diff --git a/bcipy/acquisition/record.py b/bcipy/acquisition/record.py
index 8d93a5a9c..14b331264 100644
--- a/bcipy/acquisition/record.py
+++ b/bcipy/acquisition/record.py
@@ -4,9 +4,17 @@
class Record(NamedTuple):
- """Domain object used for storing data and timestamp
- information, where data is a single reading from a device and is a list
- of channel information (float)."""
+ """Domain object used for storing data and timestamp information.
+
+ The `data` attribute represents a single reading from a device and is a list
+ of channel information (typically float values).
+
+ Attributes:
+ data (List[Any]): A list of values representing channel information from a device.
+ timestamp (float): The timestamp associated with the data recording.
+ rownum (Optional[int], optional): The row number of the record, if applicable.
+ Defaults to None.
+ """
data: List[Any]
timestamp: float
rownum: Optional[int] = None
diff --git a/bcipy/acquisition/tests/datastream/test_generator.py b/bcipy/acquisition/tests/datastream/test_generator.py
index 82db8cc14..035e0a9f2 100644
--- a/bcipy/acquisition/tests/datastream/test_generator.py
+++ b/bcipy/acquisition/tests/datastream/test_generator.py
@@ -94,7 +94,8 @@ def test_file_generator_channel_count(self):
with patch('bcipy.acquisition.datastream.generator.open',
mock_open(read_data=test_data), create=True):
- gen = file_data_generator(filename='foo', header_row=1, channel_count=2)
+ gen = file_data_generator(
+ filename='foo', header_row=1, channel_count=2)
generated_data = [next(gen) for _ in range(row_count)]
for i, row in enumerate(generated_data):
@@ -163,7 +164,8 @@ def count_generator(low=0, high=10, step=1):
self.assertEqual(1, next(gen2))
self.assertEqual(3, next(gen1))
- new_rand_gen = generator_with_args(random_data_generator, channel_count=10)
+ new_rand_gen = generator_with_args(
+ random_data_generator, channel_count=10)
gen3 = new_rand_gen()
data = next(gen3)
self.assertEqual(10, len(data))
diff --git a/bcipy/acquisition/tests/test_devices.py b/bcipy/acquisition/tests/test_devices.py
index dc86980cd..0b63a0976 100644
--- a/bcipy/acquisition/tests/test_devices.py
+++ b/bcipy/acquisition/tests/test_devices.py
@@ -26,7 +26,7 @@ def test_default_supported_devices(self):
dsi = supported['DSI-24']
self.assertEqual('EEG', dsi.content_type)
- self.assertEqual(len(devices.with_content_type('EEG')), 4)
+ self.assertEqual(len(devices.with_content_type('EEG')), 5)
def test_load_from_config(self):
"""Should be able to load a list of supported devices from a
@@ -142,7 +142,7 @@ def test_device_spec_defaults(self):
"""DeviceSpec should require minimal information with default values."""
spec = devices.DeviceSpec(name='TestDevice',
channels=['C1', 'C2', 'C3'],
- sample_rate=256.0)
+ sample_rate=256)
self.assertEqual(3, spec.channel_count)
self.assertEqual('EEG', spec.content_type)
self.assertEqual(devices.DeviceStatus.ACTIVE, spec.status)
@@ -151,7 +151,7 @@ def test_device_spec_analysis_channels(self):
"""DeviceSpec should have a list of channels used for analysis."""
spec = devices.DeviceSpec(name='TestDevice',
channels=['C1', 'C2', 'C3', 'TRG'],
- sample_rate=256.0,
+ sample_rate=256,
excluded_from_analysis=['TRG'])
self.assertEqual(['C1', 'C2', 'C3'], spec.analysis_channels)
@@ -160,7 +160,7 @@ def test_device_spec_analysis_channels(self):
spec2 = devices.DeviceSpec(name='Device2',
channels=['C1', 'C2', 'C3', 'TRG'],
- sample_rate=256.0,
+ sample_rate=256,
excluded_from_analysis=['C1', 'TRG'])
self.assertEqual(['C2', 'C3'], spec2.analysis_channels)
@@ -178,7 +178,7 @@ def test_device_spec_analysis_channels(self):
'name': 'C4',
'label': 'TRG'
}],
- sample_rate=256.0,
+ sample_rate=256,
excluded_from_analysis=['ch1', 'TRG'])
self.assertEqual(['ch2', 'ch3'], spec3.analysis_channels)
@@ -192,7 +192,7 @@ def test_irregular_sample_rate(self):
with self.assertRaises(AssertionError):
devices.DeviceSpec(name='Mouse',
channels=['Btn1', 'Btn2'],
- sample_rate=-100.0,
+ sample_rate=-100,
content_type='Markers')
def test_data_type(self):
@@ -239,7 +239,7 @@ def test_device_spec_to_dict(self):
'type': None,
'units': None
}]
- sample_rate = 256.0
+ sample_rate = 256
content_type = 'EEG'
spec = devices.DeviceSpec(name=device_name,
channels=channels,
@@ -290,7 +290,7 @@ def test_load_static_offset(self):
content_type="EEG",
description="My Device",
channels=["a", "b", "c"],
- sample_rate=100.0,
+ sample_rate=100,
status=str(devices.DeviceStatus.PASSIVE),
static_offset=offset)
]
diff --git a/bcipy/config.py b/bcipy/config.py
index 9325e2447..ec31fd991 100644
--- a/bcipy/config.py
+++ b/bcipy/config.py
@@ -6,7 +6,8 @@
from pathlib import Path
DEFAULT_ENCODING = 'utf-8'
-DEFAULT_EVIDENCE_PRECISION = 5 # number of decimal places to round evidence to by default
+# number of decimal places to round evidence to by default
+DEFAULT_EVIDENCE_PRECISION = 5
MARKER_STREAM_NAME = 'TRG_device_stream'
DEFAULT_TRIGGER_CHANNEL_NAME = 'TRG'
DIODE_TRIGGER = '\u25A0'
@@ -25,7 +26,7 @@
DEFAULT_EXPERIMENT_PATH = f'{BCIPY_ROOT}/parameters/experiment'
DEFAULT_FIELD_PATH = f'{BCIPY_ROOT}/parameters/field'
DEFAULT_USER_ID = 'test_user'
-TASK_SEPERATOR = '->'
+TASK_SEPARATOR = '->'
DEFAULT_PARAMETERS_FILENAME = 'parameters.json'
DEFAULT_DEVICES_PATH = f"{BCIPY_ROOT}/parameters"
diff --git a/bcipy/core/demo/demo_report.py b/bcipy/core/demo/demo_report.py
index c2d3bb4af..b4dfe46dc 100644
--- a/bcipy/core/demo/demo_report.py
+++ b/bcipy/core/demo/demo_report.py
@@ -40,14 +40,16 @@
# loop through the sessions, pausing after each one to allow for manual stopping
if session.is_dir():
print(f'Processing {session}')
- prompt = input('Hit enter to continue or type "skip" to skip processing: ')
+ prompt = input(
+ 'Hit enter to continue or type "skip" to skip processing: ')
if prompt != 'skip':
# load the parameters from the data directory
parameters = load_json_parameters(
f'{session}/{DEFAULT_PARAMETERS_FILENAME}', value_cast=True)
# load the raw data from the data directory
- raw_data = load_raw_data(Path(session, f'{RAW_DATA_FILENAME}.csv'))
+ raw_data = load_raw_data(
+ Path(session, f'{RAW_DATA_FILENAME}.csv'))
type_amp = raw_data.daq_type
channels = raw_data.channels
sample_rate = raw_data.sample_rate
@@ -72,7 +74,8 @@
trigger_type, trigger_timing, trigger_label = trigger_decoder(
offset=parameters.get('static_trigger_offset'),
trigger_path=f"{session}/{TRIGGER_FILENAME}",
- exclusion=[TriggerType.PREVIEW, TriggerType.EVENT, TriggerType.FIXATION],
+ exclusion=[TriggerType.PREVIEW,
+ TriggerType.EVENT, TriggerType.FIXATION],
)
triggers = (trigger_type, trigger_timing, trigger_label)
else:
diff --git a/bcipy/core/demo/demo_session_tools.py b/bcipy/core/demo/demo_session_tools.py
index 97b9bbcdd..daa09f337 100644
--- a/bcipy/core/demo/demo_session_tools.py
+++ b/bcipy/core/demo/demo_session_tools.py
@@ -48,6 +48,7 @@ def main(data_dir: str):
if args.csv:
session_csv(session, csv_file=str(Path(path, "session.csv")))
if args.charts:
- session_excel(session, excel_file=str(Path(path, SESSION_SUMMARY_FILENAME)))
+ session_excel(session, excel_file=str(
+ Path(path, SESSION_SUMMARY_FILENAME)))
else:
main(path)
diff --git a/bcipy/core/list.py b/bcipy/core/list.py
index 64c643617..8bfeae51a 100644
--- a/bcipy/core/list.py
+++ b/bcipy/core/list.py
@@ -1,16 +1,20 @@
"""Utility functions for list processing."""
from itertools import zip_longest
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
def destutter(items: List[Any], key: Callable = lambda x: x) -> List:
"""Removes sequential duplicates from a list. Retains the last item in the
- sequence. Equality is determined using the provided key function.
+ sequence.
- Parameters
- ----------
- items - list of items with sequential duplicates
- key - equality function
+ Equality is determined using the provided key function.
+
+ Args:
+ items (List[Any]): List of items with sequential duplicates.
+ key (Callable, optional): Equality function. Defaults to `lambda x: x`.
+
+ Returns:
+ List: A new list with sequential duplicates removed.
"""
deduped: List[Any] = []
for item in items:
@@ -21,10 +25,29 @@ def destutter(items: List[Any], key: Callable = lambda x: x) -> List:
return deduped
-def grouper(iterable, chunk_size, incomplete="fill", fillvalue=None):
- "Collect data into non-overlapping fixed-length chunks or blocks"
- # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
- # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
+def grouper(iterable: Any, chunk_size: int, incomplete: str = "fill",
+ fillvalue: Optional[Any] = None) -> Union[Iterator[Tuple], Iterator[Any]]:
+ """Collect data into non-overlapping fixed-length chunks or blocks.
+
+ Args:
+ iterable (Any): The iterable to group.
+ chunk_size (int): The size of each chunk.
+ incomplete (str, optional): Strategy for incomplete chunks. Can be "fill" or "ignore".
+ Defaults to "fill".
+ fillvalue (Optional[Any], optional): Value to fill incomplete chunks with if `incomplete` is "fill".
+ Defaults to None.
+
+ Returns:
+ Union[Iterator[Tuple], Iterator[Any]]: An iterator yielding chunks.
+
+ Raises:
+ ValueError: If `fillvalue` is not defined when `incomplete` is "fill", or if `incomplete`
+ is neither "fill" nor "ignore".
+
+ Examples:
+ grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
+ grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
+ """
chunks = [iter(iterable)] * chunk_size
if incomplete == "fill":
if fillvalue:
@@ -39,7 +62,16 @@ def grouper(iterable, chunk_size, incomplete="fill", fillvalue=None):
def find_index(iterable: List,
match_item: Union[Any, Callable],
key: Callable = lambda x: x) -> Optional[int]:
- """Find the index of the first item in the iterable which matches."""
+ """Find the index of the first item in the iterable which matches.
+
+ Args:
+ iterable (List): The list to search through.
+ match_item (Union[Any, Callable]): The item to match or a callable to apply to each item.
+ key (Callable, optional): A function to apply to each item before comparison. Defaults to `lambda x: x`.
+
+ Returns:
+ Optional[int]: The index of the first matching item, or None if no match is found.
+ """
for i, value in enumerate(iterable):
if callable(match_item):
result = match_item(value)
@@ -51,8 +83,16 @@ def find_index(iterable: List,
def swapped(lst: List[Any], index1: int, index2: int) -> List[Any]:
- """Creates a copy of the provided list with elements at the given indices
- swapped."""
+ """Creates a copy of the provided list with elements at the given indices swapped.
+
+ Args:
+ lst (List[Any]): The original list.
+ index1 (int): The index of the first element to swap.
+ index2 (int): The index of the second element to swap.
+
+ Returns:
+ List[Any]: A new list with the elements at `index1` and `index2` swapped.
+ """
replacements = {index1: lst[index2], index2: lst[index1]}
return [replacements.get(i, val) for i, val in enumerate(lst)]
@@ -60,18 +100,22 @@ def swapped(lst: List[Any], index1: int, index2: int) -> List[Any]:
def expanded(lst: List[Any],
length: int,
fill: Union[Any, Callable] = lambda x: x[-1]) -> List[Any]:
- """Creates a copy of the provided list expanded to the given length. By
- default the last item is used as the fill item.
-
- Parameters
- ----------
- lst - list of items to copy
- length - expands list to this length
- fill - optional; used to determine which element to use for
+ """Creates a copy of the provided list expanded to the given length.
+
+ By default, the last item is used as the fill item.
+
+ Args:
+ lst (List[Any]): List of items to copy.
+ length (int): The target length to expand the list to.
+ fill (Union[Any, Callable], optional): Used to determine which element to use for
the fill, given the list. Defaults to the last element.
- >>> expand([1,2,3], length=5)
- [1,2,3,3,3]
+ Returns:
+ List[Any]: The expanded list.
+
+ Examples:
+ >>> expanded([1,2,3], length=5)
+ [1,2,3,3,3]
"""
times = length - len(lst)
if lst and times > 0:
@@ -80,10 +124,18 @@ def expanded(lst: List[Any],
return lst
-def pairwise(iterable):
- """
- pairwise('ABCDEFG') → AB BC CD DE EF FG
- https://docs.python.org/3/library/itertools.html#itertools.pairwise
+def pairwise(iterable: Any) -> Iterator[Tuple]:
+ """Returns an iterator over overlapping pairs from the input iterable.
+
+ Args:
+ iterable (Any): The iterable to process.
+
+ Yields:
+ Tuple: A tuple containing two consecutive elements from the iterable.
+
+ Examples:
+ pairwise('ABCDEFG') → AB BC CD DE EF FG
+ https://docs.python.org/3/library/itertools.html#itertools.pairwise
"""
iterator = iter(iterable)
a = next(iterator, None)
diff --git a/bcipy/core/parameters.py b/bcipy/core/parameters.py
index dbcf03009..230b66960 100644
--- a/bcipy/core/parameters.py
+++ b/bcipy/core/parameters.py
@@ -4,13 +4,23 @@
from json import dump, load
from pathlib import Path
from re import fullmatch
-from typing import Any, Dict, NamedTuple, Optional, Tuple
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
from bcipy.config import DEFAULT_ENCODING, DEFAULT_PARAMETERS_PATH
class Parameter(NamedTuple):
- """Represents a single parameter"""
+ """Represents a single parameter.
+
+ Attributes:
+ value (Any): The value of the parameter.
+ section (str): The section the parameter belongs to.
+ name (str): The display name of the parameter.
+ helpTip (str): A helpful tip or description for the parameter.
+ recommended (list): Recommended values for the parameter.
+ editable (bool): Whether the parameter is editable.
+ type (str): The data type of the parameter (e.g., 'int', 'float', 'bool', 'str', 'range').
+ """
value: Any
section: str
name: str
@@ -21,41 +31,62 @@ class Parameter(NamedTuple):
class ParameterChange(NamedTuple):
- """Represents a Parameter that has been modified from a different value."""
- parameter: Parameter
+ """Represents a Parameter that has been modified from a different value.
+
+ Attributes:
+ parameter (Parameter): The modified parameter.
+ original_value (Any): The original value of the parameter before modification.
+ """
+ parameter: Union[Parameter, dict]
original_value: Any
-def parse_range(range_str: str) -> Tuple:
- """Parse the range description into a tuple of (low, high).
+def parse_range(range_str: str) -> Tuple[Union[int, float], Union[int, float]]:
+ """Parses the range description into a tuple of (low, high).
+
+ If either value can be parsed as a float, the resulting tuple will have
+ float values; otherwise, they will be integers.
- If either value can be parsed as a float the resulting tuple will have
- float values, otherwise they will be ints.
+ Args:
+ range_str (str): Range description formatted as 'low:high'.
- Parameters
- ----------
- range_str - range description formatted 'low:high'
+ Returns:
+ Tuple[Union[int, float], Union[int, float]]: A tuple containing the low and high values of the range.
- >>> parse_range("1:10")
- (1, 10)
+ Raises:
+ AssertionError: If the `range_str` is not in the format 'low:high' or if the low value is not less than the high value.
+
+ Examples:
+ >>> parse_range("1:10")
+ (1, 10)
"""
assert ':' in range_str, "Invalid range format; values must be separated by ':'"
- low, high = range_str.split(':')
+ low_str, high_str = range_str.split(':')
int_pattern = "-?\\d+"
- if fullmatch(int_pattern, low) and fullmatch(int_pattern, high):
- low = int(low)
- high = int(high)
+ if fullmatch(int_pattern, low_str) and fullmatch(int_pattern, high_str):
+ low: Union[int, float] = int(low_str)
+ high: Union[int, float] = int(high_str)
else:
- low = float(low)
- high = float(high)
+ low = float(low_str)
+ high = float(high_str)
assert low < high, "Low value must be less that the high value"
return (low, high)
def serialize_value(value_type: str, value: Any) -> str:
- """Serialize the given value to a string. Serialized values should be able
- to be cast using the conversions."""
+ """Serializes the given value to a string.
+
+ Serialized values should be able to be cast using the `conversions` dictionary
+ defined in the `Parameters` class.
+
+ Args:
+ value_type (str): The declared type of the value (e.g., 'bool', 'range').
+ value (Any): The value to serialize.
+
+ Returns:
+ str: The serialized string representation of the value.
+ """
if value_type == 'bool':
return str(value).lower()
if value_type == 'range':
@@ -67,12 +98,16 @@ def serialize_value(value_type: str, value: Any) -> str:
class Parameters(dict):
"""Configuration parameters for BciPy.
- source: str - optional path to a JSON file. If file exists, data will be
- loaded from here. Raises an exception unless the entries are dicts with
- the required_keys.
+ This class extends `dict` to provide type-casting and validation for
+ configuration parameters, typically loaded from a JSON file.
- cast_values: bool - if True cast values to specified type; default is False.
- """
+ Args:
+ source (Optional[str], optional): Optional path to a JSON file. If the file exists,
+ data will be loaded from here. Raises an exception unless the entries are
+ dictionaries with the required keys. Defaults to None.
+ cast_values (bool, optional): If True, values will be cast to their specified type
+ when accessed. Defaults to False.
+ """
def __init__(self, source: Optional[str] = None, cast_values: bool = False):
super().__init__()
@@ -92,11 +127,19 @@ def __init__(self, source: Optional[str] = None, cast_values: bool = False):
self.load_from_source()
@classmethod
- def from_cast_values(cls, **kwargs):
- """Create a new Parameters object from cast values. This is useful
- primarily for testing
+ def from_cast_values(cls, **kwargs: Any) -> 'Parameters':
+ """Creates a new `Parameters` object from cast values.
+
+ This is useful primarily for testing.
+
+ Args:
+ **kwargs (Any): Keyword arguments representing parameter names and their values.
+
+ Returns:
+ Parameters: A new `Parameters` instance with values cast.
- >>> Parameters.from_cast_values(time_prompt=1.0, fake_data=True)
+ Examples:
+ >>> Parameters.from_cast_values(time_prompt=1.0, fake_data=True)
"""
params = Parameters(source=None, cast_values=True)
for key, val in kwargs.items():
@@ -118,29 +161,59 @@ def from_cast_values(cls, **kwargs):
return params
@property
- def supported_types(self) -> set:
- """Supported types for casting values"""
- return self.conversions.keys()
+ def supported_types(self) -> list:
+ """Returns the list of supported types for casting values."""
+ return list(self.conversions.keys())
- def cast_value(self, entry: dict) -> Any:
- """Takes an entry with a desired type and attempts to cast it to that type."""
+ def cast_value(self, entry: Dict[str, Any]) -> Any:
+ """Takes an entry with a desired type and attempts to cast it to that type.
+
+ Args:
+ entry (Dict[str, Any]): A dictionary representing a parameter entry with a 'type' key.
+
+ Returns:
+ Any: The value cast to the specified type.
+ """
cast = self.conversions[entry['type']]
- return cast(entry['value'])
+ return cast(entry['value']) # type: ignore
+
+ def serialized_value(self, value: Any, entry_type: str) -> str:
+ """Converts a value back into its serialized string form.
- def serialized_value(self, value, entry_type) -> str:
- """Convert a value back into its serialized form"""
+ Args:
+ value (Any): The value to serialize.
+ entry_type (str): The declared type of the value.
+
+ Returns:
+ str: The serialized string representation of the value.
+ """
serialized = str(value)
return serialized.lower() if entry_type == 'bool' else serialized
- def __getitem__(self, key) -> Any:
- """Override to handle cast values"""
+ def __getitem__(self, key: str) -> Any:
+ """Overrides dictionary item access to handle cast values.
+
+ Args:
+ key (str): The key of the parameter to retrieve.
+
+ Returns:
+ Any: The cast value of the parameter if `cast_values` is True, otherwise the raw entry dictionary.
+ """
entry = self.get_entry(key)
if self.cast_values:
return self.cast_value(entry)
return entry
- def __setitem__(self, key, value) -> None:
- """Override to handle cast values"""
+ def __setitem__(self, key: str, value: Any) -> None:
+ """Overrides dictionary item assignment to handle cast values and validate entries.
+
+ If `cast_values` is True, it attempts to set the serialized value for an existing entry.
+ Otherwise, it adds a new entry after validation.
+
+ Args:
+ key (str): The key of the parameter to set.
+ value (Any): The value to set for the parameter.
+ """
if self.cast_values:
# Can only set values for existing entries when cast.
entry = self.get_entry(key)
@@ -148,72 +221,119 @@ def __setitem__(self, key, value) -> None:
else:
self.add_entry(key, value)
- def add_entry(self, key, value) -> None:
- """Adds a configuration parameter."""
+ def add_entry(self, key: str, value: Dict[str, Any]) -> None:
+ """Adds a configuration parameter after validating its format.
+
+ Args:
+ key (str): The name of the configuration parameter.
+ value (Dict[str, Any]): A dictionary containing the parameter properties.
+ """
self.check_valid_entry(key, value)
super().__setitem__(key, value)
- def get_entry(self, key) -> dict:
- """Get the non-cast entry associated with the given key."""
+ def get_entry(self, key: str) -> Dict[str, Any]:
+ """Gets the non-cast entry associated with the given key.
+
+ Args:
+ key (str): The key of the parameter entry to retrieve.
+
+ Returns:
+ Dict[str, Any]: The raw dictionary entry for the parameter.
+ """
return super().__getitem__(key)
- def get(self, key, d=None) -> Any:
- """Override to handle cast values"""
+ def get(self, key: str, d: Optional[Any] = None) -> Any:
+ """Overrides dictionary `get` method to handle cast values.
+
+ Args:
+ key (str): The key of the parameter to retrieve.
+ d (Optional[Any], optional): Default value to return if the key is not found.
+ Defaults to None.
+
+ Returns:
+ Any: The cast value of the parameter if `cast_values` is True and the key is found,
+ otherwise the raw entry or the default value.
+ """
entry = super().get(key, d)
if self.cast_values and entry != d:
return self.cast_value(entry)
return entry
- def entries(self) -> list:
- """Uncast items"""
- return super().items()
+ def entries(self) -> List[Tuple[str, Dict[str, Any]]]:
+ """Returns the uncast items (key-value pairs) of the parameters.
- def items(self) -> list:
- """Override to handle cast values"""
+ Returns:
+ List[Tuple[str, Dict[str, Any]]]: A list of key-value tuples, where values are raw entry dictionaries.
+ """
+ return list(super().items()) # type: ignore
+
+ def items(self) -> List[Tuple[str, Any]]: # type: ignore
+ """Overrides dictionary `items` method to handle cast values.
+
+ Returns:
+ List[Tuple[str, Any]]: A list of key-value tuples, where values are cast if `cast_values` is True.
+ """
if self.cast_values:
return [(key, self.cast_value(entry))
for key, entry in self.entries()]
return self.entries()
- def values(self) -> list:
- """Override to handle cast values"""
+ def values(self) -> List[Any]: # type: ignore
+ """Override to handle cast values.
+
+ Returns:
+ List[Any]: A list of parameter values, cast if `cast_values` is True.
+ """
vals = super().values()
if self.cast_values:
return [self.cast_value(entry) for entry in vals]
- return vals
+ return list(vals) # type: ignore
- def update(self, *args, **kwargs) -> None:
- """Override to ensure update uses __setitem___"""
+ def update(self, *args: Any, **kwargs: Any) -> None:
+ """Overrides dictionary `update` method to ensure `__setitem__` is used for each item.
+
+ Args:
+ *args (Any): Positional arguments for dictionary update.
+ **kwargs (Any): Keyword arguments for dictionary update.
+ """
for key, value in dict(*args, **kwargs).items():
self[key] = value
def copy(self) -> 'Parameters':
- """Override
+ """Creates a shallow copy of the `Parameters` object."
+
+ Returns:
+ Parameters: A new `Parameters` instance with the same parameters.
"""
params = Parameters(source=None, cast_values=self.cast_values)
- params.load(super().copy())
+ params.load(super().copy()) # type: ignore
return params
- def load(self, data: dict) -> None:
- """Load values from a dict, validating entries (see check_valid_entry) and raising
- an exception for invalid values.
+ def load(self, data: Dict[str, Dict[str, Any]]) -> None:
+ """Loads values from a dictionary, validating entries and raising an exception for invalid values."
- data: dict of configuration parameters.
+ Args:
+ data (Dict[str, Dict[str, Any]]): A dictionary of configuration parameters.
"""
for name, entry in data.items():
self.add_entry(name, entry)
def load_from_source(self) -> None:
- """Load data from the configured JSON file."""
+ """Loads data from the configured JSON file."
+
+ If `self.source` is set, it attempts to open and load the JSON data from that path.
+ """
if self.source:
             with codecs_open(self.source, 'r',
encoding=DEFAULT_ENCODING) as json_file:
data = load(json_file)
- self.load(data)
+ self.load(data) # type: ignore
- def check_valid_entry(self, entry_name: str, entry: dict) -> None:
- """Checks if the given entry is valid. Raises an exception unless the entry is formatted:
+ def check_valid_entry(self, entry_name: str, entry: Dict[str, Any]) -> None:
+ """Checks if the given entry is valid. Raises an exception unless the entry is formatted as expected."
+ Expected format:
+ ```json
"fake_data": {
"value": "true",
"section": "bci_config",
@@ -223,9 +343,16 @@ def check_valid_entry(self, entry_name: str, entry: dict) -> None:
"editable": true,
"type": "bool"
}
+ ```
+
+ Args:
+ entry_name (str): Name of the configuration parameter.
+ entry (Dict[str, Any]): Parameter properties.
- entry_name : str - name of the configuration parameter
- entry : dict - parameter properties
+ Raises:
+ AttributeError: If `entry` is not a dictionary.
+ Exception: If `entry` does not contain required keys, if the 'type' is not supported,
+ or if the 'value' for a 'bool' type is invalid.
"""
if not isinstance(entry, abc.Mapping):
raise AttributeError(f"'{entry_name}' value must be a dict")
@@ -242,10 +369,11 @@ def check_valid_entry(self, entry_name: str, entry: dict) -> None:
f"Invalid value for key: {entry_name}. Must be either 'true' or 'false'"
)
- def source_location(self) -> Tuple[Path, str]:
- """Location of the source json data if source was provided.
+ def source_location(self) -> Tuple[Optional[Path], Optional[str]]:
+ """Returns the location of the source JSON data if a source was provided."
- Returns Tuple(Path, filename: str)
+ Returns:
+ Tuple[Optional[Path], Optional[str]]: A tuple containing the parent directory path and the filename.
"""
if self.source:
path = Path(self.source)
@@ -253,12 +381,17 @@ def source_location(self) -> Tuple[Path, str]:
return (None, None)
def save(self, directory: Optional[str] = None, name: Optional[str] = None) -> str:
- """Save parameters to the given location
+ """Saves parameters to the given location."
+
+ Args:
+ directory (Optional[str], optional): Optional location to save the file. Defaults to the source directory.
+ name (Optional[str], optional): Optional name of the new parameters file. Defaults to the source filename.
- directory: str - optional location to save; default is the source_directory.
- name: str - optional name of new parameters file; default is the source filename.
+ Returns:
+ str: The path of the saved file.
- Returns the path of the saved file.
+ Raises:
+            AttributeError: If `directory` or `name` is not provided and no `source` path is set.
"""
if (not directory or not name) and not self.source:
raise AttributeError('name and directory parameters are required')
@@ -266,18 +399,22 @@ def save(self, directory: Optional[str] = None, name: Optional[str] = None) -> s
source_directory, source_name = self.source_location()
location = directory if directory else source_directory
filename = name if name else source_name
- path = Path(location, filename)
+ path = Path(location, filename) # type: ignore
with open(path, 'w', encoding=DEFAULT_ENCODING) as json_file:
- dump(dict(self.entries()), json_file, ensure_ascii=False, indent=2)
+ dump(dict(self.entries()), json_file,
+ ensure_ascii=False, indent=2) # type: ignore
return str(path)
- def add_missing_items(self, parameters) -> bool:
- """Given another Parameters instance, add any items that are not already
- present. Existing items will not be updated.
+ def add_missing_items(self, parameters: 'Parameters') -> bool:
+ """Given another `Parameters` instance, adds any items that are not already present.
- parameters: Parameters - object from which to add parameters.
+ Existing items will not be updated.
- Returns bool indicating whether or not any new items were added.
+ Args:
+ parameters (Parameters): Object from which to add parameters.
+
+ Returns:
+ bool: True if any new items were added, False otherwise.
"""
updated = False
existing_keys = self.keys()
@@ -287,17 +424,20 @@ def add_missing_items(self, parameters) -> bool:
updated = True
return updated
- def diff(self, parameters) -> Dict[str, ParameterChange]:
+ def diff(self, parameters: 'Parameters') -> Dict[str, ParameterChange]:
"""Lists the differences between this and another set of parameters.
A None original_value indicates a new parameter.
- Parameters
- ----------
- parameters : Parameters - set of parameters for comparison; these
+ Args:
+ parameters (Parameters): Set of parameters for comparison; these
are considered the original values and the current set the
changed values.
+
+ Returns:
+ Dict[str, ParameterChange]: A dictionary where keys are parameter names
+ and values are ParameterChange objects.
"""
- diffs = {}
+ diffs: Dict[str, ParameterChange] = {}
for key, param in self.entries():
if key in parameters.keys():
@@ -310,9 +450,15 @@ def diff(self, parameters) -> Dict[str, ParameterChange]:
original_value=None)
return diffs
- def instantiate(self, named_tuple_class: NamedTuple) -> NamedTuple:
- """Instantiate a namedtuple whose fields represent a subset of the
- parameters."""
+ def instantiate(self, named_tuple_class: type[NamedTuple]) -> NamedTuple:
+ """Instantiates a `NamedTuple` whose fields represent a subset of the parameters."
+
+ Args:
+ named_tuple_class (type[NamedTuple]): The `NamedTuple` class to instantiate.
+
+ Returns:
+ NamedTuple: An instance of the provided `NamedTuple` class with values populated from parameters.
+ """
vals = [
self.cast_value(self.get_entry(key))
for key in named_tuple_class._fields
@@ -321,12 +467,14 @@ def instantiate(self, named_tuple_class: NamedTuple) -> NamedTuple:
def changes_from_default(source: str) -> Dict[str, ParameterChange]:
- """Determines which parameters have changed from the default params.
+ """Determines which parameters have changed from the default parameters.
- Parameters
- ----------
- source - path to the parameters json file that will be compared with
+ Args:
+ source (str): Path to the parameters JSON file that will be compared with
the default parameters.
+
+ Returns:
+ Dict[str, ParameterChange]: A dictionary of `ParameterChange` objects representing the differences.
"""
default = Parameters(source=DEFAULT_PARAMETERS_PATH, cast_values=True)
params = Parameters(source=source, cast_values=True)
diff --git a/bcipy/core/raw_data.py b/bcipy/core/raw_data.py
index 60f04d728..6a05e2da6 100644
--- a/bcipy/core/raw_data.py
+++ b/bcipy/core/raw_data.py
@@ -15,10 +15,25 @@
class RawData:
- """Represents the raw data format used by BciPy. Used primarily for loading
- a raw data file into memory."""
+ """Represents the raw data format used by BciPy.
+
+ This class is used primarily for loading a raw data file into memory. It provides
+ methods to access and manipulate the data in various formats.
+
+ Attributes:
+ daq_type (str): Type of data acquisition device.
+ sample_rate (int): Sample rate in Hz.
+ columns (List[str]): List of column names in the data.
+ """
def __init__(self, daq_type: str, sample_rate: int, columns: List[str]) -> None:
+ """Initialize RawData.
+
+ Args:
+ daq_type (str): Type of data acquisition device.
+ sample_rate (int): Sample rate in Hz.
+ columns (List[str]): List of column names in the data.
+ """
self.daq_type = daq_type
self.sample_rate = sample_rate
self.columns = columns
@@ -27,44 +42,70 @@ def __init__(self, daq_type: str, sample_rate: int, columns: List[str]) -> None:
self._rows: List[Any] = []
@classmethod
- def load(cls, filename: str):
+ def load(cls, filename: str) -> 'RawData':
"""Constructs a RawData object by deserializing the given file.
+
All data will be read into memory. If you want to lazily read data one
record at a time, use a RawDataReader.
- Parameters
- ----------
- - filename : path to the csv file to read
+ Args:
+ filename (str): Path to the csv file to read.
+
+ Returns:
+ RawData: A new RawData instance containing the loaded data.
"""
return load(filename)
@property
def rows(self) -> List[List]:
- """Returns the data rows"""
+ """Returns the data rows.
+
+ Returns:
+ List[List]: List of data rows.
+ """
return self._rows
@rows.setter
def rows(self, value: Any) -> None:
+ """Sets the data rows and invalidates the cached dataframe.
+
+ Args:
+ value (Any): New data rows to set.
+ """
self._rows = value
self._dataframe = None
@property
def channels(self) -> List[str]:
- """Compute the list of channels. Channels are the numeric columns
- excluding the timestamp column."""
+ """Compute the list of channels.
+
+ Channels are the numeric columns excluding the timestamp column.
+ Returns:
+ List[str]: List of channel names.
+ """
# Start data slice at 1 to remove the timestamp column.
return list(self.numeric_data.columns[1:])
@property
def numeric_data(self) -> pd.DataFrame:
- """Data for columns with numeric data. This is usually comprised of the
- timestamp column and device channels, excluding string triggers."""
+ """Data for columns with numeric data.
+
+ This is usually comprised of the timestamp column and device channels,
+ excluding string triggers.
+
+ Returns:
+ pd.DataFrame: DataFrame containing only numeric columns.
+ """
return self.dataframe.select_dtypes(exclude=['object'])
@property
def channel_data(self) -> np.ndarray:
- """Data for columns with numeric data, excluding the timestamp column."""
+ """Data for columns with numeric data, excluding the timestamp column.
+
+ Returns:
+ np.ndarray: Array of channel data with shape (channels, samples).
+ """
numeric_data = self.numeric_data
numeric_vals = numeric_data.values
@@ -79,12 +120,15 @@ def by_channel(self, transform: Optional[Composition] = None) -> Tuple[np.ndarra
This will apply n tranformations to the data before returning. For an example Composition with
EEG preprocessing, see bcipy.signal.get_default_transform().
- Returns
- ----------
- data: C x N numpy array with samples where C is the number of channels and N
- is number of time samples
- fs: resulting sample rate if any transformations applied"""
+ Args:
+ transform (Optional[Composition]): Optional transformation to apply to the data.
+ Returns:
+ Tuple[np.ndarray, int]: A tuple containing:
+ - data: C x N numpy array with samples where C is the number of channels and N
+ is number of time samples
+ - fs: resulting sample rate if any transformations applied
+ """
data = self.channel_data
fs = self.sample_rate
@@ -97,21 +141,29 @@ def by_channel_map(
self,
channel_map: List[int],
transform: Optional[Composition] = None) -> Tuple[np.ndarray, List[str], int]:
- """By Channel Map.
+ """Returns channels with columns removed if index in list (channel_map) is zero.
- Returns channels with columns removed if index in list (channel_map) is zero. The channel map must align
- with the numeric channels read in as self.channels. We assume most trigger or other string columns are
- removed, however some are numeric trigger columns from devices that will require filtering before returning
+ The channel map must align with the numeric channels read in as self.channels.
+ We assume most trigger or other string columns are removed, however some are
+ numeric trigger columns from devices that will require filtering before returning
data. Other cases could be dropping bad channels before running further analyses.
- Optionally, it can apply a BciPy Composition to the data before returning using the transform arg.
- This will apply n tranformations to the data before returning. For an example Composition with
- EEG preprocessing, see bcipy.signal.get_default_transform().
+ Args:
+ channel_map (List[int]): List of 1s and 0s indicating which channels to keep.
+ transform (Optional[Composition]): Optional transformation to apply to the data.
+
+ Returns:
+ Tuple[np.ndarray, List[str], int]: A tuple containing:
+ - data: Array of channel data with shape (channels, samples)
+ - channels: List of channel names
+ - fs: Sample rate
"""
data, fs = self.by_channel(transform)
- channels_to_remove = [idx for idx, value in enumerate(channel_map) if value == 0]
+ channels_to_remove = [idx for idx,
+ value in enumerate(channel_map) if value == 0]
data = np.delete(data, channels_to_remove, axis=0)
- channels: List[str] = np.delete(self.channels, channels_to_remove, axis=0).tolist()
+ channels: List[str] = np.delete(
+ self.channels, channels_to_remove, axis=0).tolist()
return data, channels, fs
@@ -119,13 +171,26 @@ def apply_transform(self, data: np.ndarray, transform: Composition) -> Tuple[np.
"""Apply Transform.
Using data provided as an np.ndarray, call the Composition with self.sample_rate to apply
- transformations to the data. This will return the transformed data and resulting sample rate.
+ transformations to the data. This will return the transformed data and resulting sample rate.
+
+ Args:
+ data (np.ndarray): Input data to transform.
+ transform (Composition): Transformation to apply.
+
+ Returns:
+ Tuple[np.ndarray, int]: A tuple containing:
+ - data: Transformed data
+ - fs: Resulting sample rate
"""
return transform(data, self.sample_rate)
@property
def dataframe(self) -> pd.DataFrame:
- """Returns a dataframe of the row data."""
+ """Returns a dataframe of the row data.
+
+ Returns:
+ pd.DataFrame: DataFrame containing all data.
+ """
if self._dataframe is None:
self._dataframe = pd.DataFrame(data=self.rows,
columns=self.columns)
@@ -133,15 +198,24 @@ def dataframe(self) -> pd.DataFrame:
@property
def total_seconds(self) -> float:
- """Total recorded seconds, defined as the diff between the first and
- last timestamp."""
+ """Total recorded seconds.
+
+ Defined as the diff between the first and last timestamp.
+
+ Returns:
+ float: Total duration in seconds.
+ """
frame = self.dataframe
col = 'lsl_timestamp'
return frame.iloc[-1][col] - frame.iloc[0][col]
@property
def total_samples(self) -> int:
- """Number of samples recorded."""
+ """Number of samples recorded.
+
+ Returns:
+ int: Total number of samples.
+ """
return int(self.dataframe.iloc[-1]['timestamp'])
def query(self,
@@ -150,15 +224,13 @@ def query(self,
column: str = 'lsl_timestamp') -> pd.DataFrame:
"""Query for a subset of data.
- Pararameters
- ------------
- start - find records greater than or equal to this value
- stop - find records less than or equal to this value
- column - column to compare to the given start and stop values
+ Args:
+ start (Optional[float]): Find records greater than or equal to this value.
+ stop (Optional[float]): Find records less than or equal to this value.
+ column (str): Column to compare to the given start and stop values.
- Returns
- -------
- Dataframe for the given slice of data
+ Returns:
+ pd.DataFrame: DataFrame for the given slice of data.
"""
dataframe = self.dataframe
start = start or dataframe.iloc[0][column]
@@ -169,24 +241,41 @@ def query(self,
def append(self, row: list) -> None:
"""Append the given row of data.
- Parameters
- ----------
- - row : row of data
+ Args:
+ row (list): Row of data to append.
"""
assert len(row) == len(self.columns), "Wrong number of columns"
self._rows.append(row)
self._dataframe = None
def __str__(self) -> str:
+ """String representation of RawData.
+
+ Returns:
+ str: String representation.
+ """
return f"RawData({self.daq_type})"
def __repr__(self) -> str:
+ """String representation of RawData.
+
+ Returns:
+ str: String representation.
+ """
return f"RawData({self.daq_type})"
def maybe_float(val: Any) -> Union[float, Any]:
- """Attempt to convert the given value to float. If conversion fails return
- as is."""
+ """Attempt to convert the given value to float.
+
+ If conversion fails return the value as is.
+
+ Args:
+ val (Any): Value to convert to float.
+
+ Returns:
+ Union[float, Any]: Float if conversion succeeds, original value otherwise.
+ """
try:
return float(val)
except ValueError:
@@ -194,47 +283,74 @@ def maybe_float(val: Any) -> Union[float, Any]:
class RawDataReader:
- """Lazily reads raw data from a file. Intended to be used as a ContextManager
- using Python's `with` keyword.
+ """Lazily reads raw data from a file.
- Example usage:
+ Intended to be used as a ContextManager using Python's `with` keyword.
- ```
- with RawDataReader(path) as reader:
- print(f"Data from ${reader.daq_type}")
- print(reader.columns)
- for row in reader:
- print(row)
- ```
-
- Parameters
- ----------
- - file_path : path to the csv file
- - convert_data : if True attempts to convert data values to floats;
- default is False
+ Example usage:
+ ```
+ with RawDataReader(path) as reader:
+ print(f"Data from ${reader.daq_type}")
+ print(reader.columns)
+ for row in reader:
+ print(row)
+ ```
+
+ Attributes:
+ file_path (str): Path to the csv file.
+ convert_data (bool): If True attempts to convert data values to floats.
+ daq_type (str): Type of data acquisition device.
+ sample_rate (int): Sample rate in Hz.
+ columns (List[str]): List of column names.
"""
_file_obj: TextIOWrapper
def __init__(self, file_path: str, convert_data: bool = False):
+ """Initialize RawDataReader.
+
+ Args:
+ file_path (str): Path to the csv file.
+ convert_data (bool, optional): If True attempts to convert data values to floats.
+ Defaults to False.
+ """
self.file_path = file_path
self.convert_data = convert_data
- def __enter__(self):
- self._file_obj = open(self.file_path, mode="r", encoding=DEFAULT_ENCODING)
+ def __enter__(self) -> 'RawDataReader':
+ """Enter the context manager.
+
+ Returns:
+ RawDataReader: Self.
+ """
+ self._file_obj = open(self.file_path, mode="r",
+ encoding=DEFAULT_ENCODING)
self.daq_type, self.sample_rate = read_metadata(self._file_obj)
self._reader = csv.reader(self._file_obj)
self.columns = next(self._reader)
return self
- def __exit__(self, *args, **kwargs):
- """Exit the context manager. Close resources"""
+ def __exit__(self, *args, **kwargs) -> None:
+ """Exit the context manager. Close resources."""
if self._file_obj:
self._file_obj.close()
- def __iter__(self):
+ def __iter__(self) -> 'RawDataReader':
+ """Return self as iterator.
+
+ Returns:
+ RawDataReader: Self.
+ """
return self
- def __next__(self):
+ def __next__(self) -> List[Any]:
+ """Get next row of data.
+
+ Returns:
+ List[Any]: Next row of data.
+
+ Raises:
+ AssertionError: If reader is not initialized.
+ """
assert self._reader, "Reader must be initialized"
row = next(self._reader)
if self.convert_data:
@@ -243,36 +359,46 @@ def __next__(self):
class RawDataWriter:
- """Writes a raw data file one row at a time without storing the records
- in memory. Intended to be used as a ContextManager using Python's `with`
- keyword.
+ """Writes a raw data file one row at a time without storing the records in memory.
- Example usage:
+ Intended to be used as a ContextManager using Python's `with` keyword.
- ```
- with RawDataWriter(path, daq_type, sample_rate, columns) as writer:
- for row in data:
- writer.writerow(row)
- ```
-
- Parameters
- ----------
- - file_path : path to the csv file
- - daq_type : name of device
- - sample_rate : sample frequency in Hz
- - columns : list of column names
+ Example usage:
+ ```
+ with RawDataWriter(path, daq_type, sample_rate, columns) as writer:
+ for row in data:
+ writer.writerow(row)
+ ```
+
+ Attributes:
+ file_path (str): Path to the csv file.
+ daq_type (str): Name of device.
+ sample_rate (float): Sample frequency in Hz.
+ columns (List[str]): List of column names.
"""
_file_obj: TextIOWrapper
def __init__(self, file_path: str, daq_type: str, sample_rate: float,
columns: List[str]) -> None:
+ """Initialize RawDataWriter.
+
+ Args:
+ file_path (str): Path to the csv file.
+ daq_type (str): Name of device.
+ sample_rate (float): Sample frequency in Hz.
+ columns (List[str]): List of column names.
+ """
self.file_path = file_path
self.daq_type = daq_type
self.sample_rate = sample_rate
self.columns = columns
- def __enter__(self):
- """Enter the context manager. Initializes the underlying data file."""
+ def __enter__(self) -> 'RawDataWriter':
+ """Enter the context manager. Initializes the underlying data file.
+
+ Returns:
+ RawDataWriter: Self.
+ """
self._file_obj = open(self.file_path,
mode='w',
encoding=DEFAULT_ENCODING,
@@ -289,28 +415,47 @@ def __enter__(self):
return self
- def __exit__(self, *args, **kwargs):
- """Exit the context manager. Close resources"""
+ def __exit__(self, *args, **kwargs) -> None:
+ """Exit the context manager. Close resources."""
self._file_obj.close()
def writerow(self, row: List) -> None:
+ """Write a single row of data.
+
+ Args:
+ row (List): Row of data to write.
+
+ Raises:
+ AssertionError: If writer is not initialized or row has wrong number of columns.
+ """
assert self._csv_writer, "Writer must be initialized"
assert len(row) == len(self.columns), "Wrong number of columns"
self._csv_writer.writerow(row)
def writerows(self, rows: List[List]) -> None:
+ """Write multiple rows of data.
+
+ Args:
+ rows (List[List]): Rows of data to write.
+ """
for row in rows:
self.writerow(row)
def load(filename: str) -> RawData:
"""Reads the file at the given path and initializes a RawData object.
+
All data will be read into memory. If you want to lazily read data one
record at a time, use a RawDataReader.
- Parameters
- ----------
- - filename : path to the csv file to read
+ Args:
+ filename (str): Path to the csv file to read.
+
+ Returns:
+ RawData: A new RawData instance containing the loaded data.
+
+ Raises:
+ BciPyCoreException: If the file is not found.
"""
# Loading all data from a csv is faster using pandas than using the
# RawDataReader.
@@ -322,33 +467,32 @@ def load(filename: str) -> RawData:
data.rows = dataframe.values.tolist()
return data
except FileNotFoundError:
-        raise BciPyCoreException(f"\nError loading BciPy RawData. Valid data not found at: {filename}")
+        raise BciPyCoreException(
+            f"\nError loading BciPy RawData. Valid data not found at: {filename}")
def read_metadata(file_obj: TextIO) -> Tuple[str, int]:
- """Reads the metadata from an open raw data file and retuns the result as
- a tuple. Increments the reader.
+ """Reads the metadata from an open raw data file and returns the result as a tuple.
+
+ Increments the reader.
- Parameters
- ----------
- - file_obj : open TextIO object
+ Args:
+ file_obj (TextIO): Open TextIO object.
- Returns
- -------
- tuple of daq_type, sample_rate
+ Returns:
+ Tuple[str, int]: Tuple of (daq_type, sample_rate).
"""
daq_type = next(file_obj).strip().split(',')[1]
sample_rate = int(float(next(file_obj).strip().split(",")[1]))
return daq_type, sample_rate
-def write(data: RawData, filename: str):
+def write(data: RawData, filename: str) -> None:
"""Write the given raw data file.
- Parameters
- ----------
- - data : raw data object to write
- - filename : path to destination file.
+ Args:
+ data (RawData): Raw data object to write.
+ filename (str): Path to destination file.
"""
with RawDataWriter(filename, data.daq_type, data.sample_rate,
data.columns) as writer:
@@ -357,17 +501,14 @@ def write(data: RawData, filename: str):
def settings(filename: str) -> Tuple[str, float, List[str]]:
- """Read the daq settings from the given data file
+ """Read the daq settings from the given data file.
- Parameters
- ----------
- - filename : path to the raw data file (csv)
+ Args:
+ filename (str): Path to the raw data file (csv).
- Returns
- -------
- tuple of (acquisition type, sample_rate, columns)
+ Returns:
+ Tuple[str, float, List[str]]: Tuple of (acquisition type, sample_rate, columns).
"""
-
with RawDataReader(filename) as reader:
return reader.daq_type, reader.sample_rate, reader.columns
@@ -377,15 +518,21 @@ def sample_data(rows: int = 1000,
daq_type: str = 'SampleDevice',
sample_rate: int = 256,
triggers: List[Tuple[float, str]] = []) -> RawData:
- """Creates sample data to be written as a raw_data.csv file. The resulting data has
- a column for the timestamp, one for each channel, and a TRG column.
-
- - rows : number of sample rows to generate
- - ch_names : list of channel names
- - daq_type : metadata for the device name
- - sample_rate : device sample rate in hz
- - triggers : List of (timestamp, trigger_value) tuples to be inserted
- in the data.
+ """Creates sample data to be written as a raw_data.csv file.
+
+ The resulting data has a column for the timestamp, one for each channel,
+ and a TRG column.
+
+ Args:
+ rows (int, optional): Number of sample rows to generate. Defaults to 1000.
+ ch_names (List[str], optional): List of channel names. Defaults to ['ch1', 'ch2', 'ch3'].
+ daq_type (str, optional): Metadata for the device name. Defaults to 'SampleDevice'.
+ sample_rate (int, optional): Device sample rate in hz. Defaults to 256.
+ triggers (List[Tuple[float, str]], optional): List of (timestamp, trigger_value) tuples
+ to be inserted in the data. Defaults to [].
+
+ Returns:
+ RawData: A new RawData instance containing the sample data.
"""
channels = [name for name in ch_names if name != 'TRG']
columns = [TIMESTAMP_COLUMN] + channels + ['TRG']
@@ -407,12 +554,12 @@ def sample_data(rows: int = 1000,
def get_1020_channels() -> List[str]:
"""Returns the standard 10-20 channel names.
- Note: The 10-20 system is a standard for EEG electrode placement. The following is not a complete list of all
- possible channels, but the most common ones used in BCI research. This excludes the reference and ground channels.
+ Note: The 10-20 system is a standard for EEG electrode placement. The following
+ is not a complete list of all possible channels, but the most common ones used
+ in BCI research. This excludes the reference and ground channels.
- Returns
- -------
- list of channel names
+ Returns:
+ List[str]: List of channel names.
"""
return [
'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T3', 'C3', 'Cz', 'C4',
@@ -424,13 +571,11 @@ def get_1020_channels() -> List[str]:
def get_1020_channel_map(channels_name: List[str]) -> List[int]:
"""Returns a list of 1s and 0s indicating if the channel name is in the 10-20 system.
- Parameters
- ----------
- channels_name : list of channel names
+ Args:
+ channels_name (List[str]): List of channel names.
- Returns
- -------
- list of 1s and 0s indicating if the channel name is in the 10-20 system
+ Returns:
+ List[int]: List of 1s and 0s indicating if the channel name is in the 10-20 system.
"""
valid_channels = get_1020_channels()
return [1 if name in valid_channels else 0 for name in channels_name]
diff --git a/bcipy/core/report.py b/bcipy/core/report.py
index 1ce5504be..25e9e7c4f 100644
--- a/bcipy/core/report.py
+++ b/bcipy/core/report.py
@@ -1,7 +1,7 @@
# mypy: disable-error-code="union-attr"
import io
from abc import ABC
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
@@ -16,35 +16,58 @@
class ReportSection(ABC):
- """Report Section.
+ """Abstract base class for report sections in BciPy.
- An abstract class to handle the creation of a section in a BciPy Report.
+ This class defines the interface for creating sections in a BciPy Report.
+ All report sections must implement the `compile` method to generate their content.
"""
def compile(self) -> Flowable:
- """Compile.
+ """Compile the section into a flowable for the report.
- This method must be implemented by the child class.
- It is intented to be called on final Report build,
- as opposed to immediatley after class initiatlization,
- to compile the section into a usuable flowable for a Report.
+ This method must be implemented by child classes. It is intended to be called
+ during final report build, not immediately after class initialization.
+
+ Returns:
+ Flowable: A reportlab flowable object containing the section's content.
"""
...
def _create_header(self) -> Flowable:
+ """Create the header for the section.
+
+ Returns:
+ Flowable: A reportlab flowable object containing the section header.
+ """
...
class SignalReportSection(ReportSection):
- """Signal Report Section.
+ """Section for displaying signal analysis results in a BciPy Report.
- A class to handle the creation of a Signal Report section in a BciPy Report.
+ This section can include signal figures and artifact detection results.
+
+ Attributes:
+ figures (List[Figure]): List of matplotlib figures to include in the report.
+ report_flowables (List[Flowable]): List of reportlab flowables for the section.
+ artifact (Optional[ArtifactDetection]): Optional artifact detection results.
+ style: Reportlab style sheet for formatting.
"""
def __init__(
self,
figures: List[Figure],
artifact: Optional[ArtifactDetection] = None) -> None:
+ """Initialize SignalReportSection.
+
+ Args:
+ figures (List[Figure]): List of matplotlib figures to include.
+ artifact (Optional[ArtifactDetection], optional): Artifact detection results.
+ Defaults to None.
+
+ Raises:
+ AssertionError: If artifact is provided but analysis is not complete.
+ """
self.figures = figures
self.report_flowables: List[Flowable] = []
self.artifact = artifact
@@ -55,9 +78,10 @@ def __init__(
self.style = getSampleStyleSheet()
def compile(self) -> Flowable:
- """Compile.
+ """Compile the signal report section into a flowable.
- Compiles the Signal Report sections into a flowable that can be used to generate a Report.
+ Returns:
+ Flowable: A reportlab flowable containing the compiled section.
"""
self.report_flowables.append(self._create_header())
if self.artifact:
@@ -67,9 +91,10 @@ def compile(self) -> Flowable:
return KeepTogether(self.report_flowables)
def _create_artifact_section(self) -> Flowable:
- """Create Artifact Section.
+ """Create a section displaying artifact detection results.
- Creates a paragraph with the artifact information. This is only included if an artifact detection is provided.
+ Returns:
+ Flowable: A reportlab flowable containing artifact information and visualizations.
"""
artifact_report = []
artifacts_detected = self.artifact.dropped
@@ -91,7 +116,8 @@ def _create_artifact_section(self) -> Flowable:
if self.artifact.voltage_annotations:
voltage_artifacts = f'Voltage Artifacts: {len(self.artifact.voltage_annotations)}'
- voltage_section = Paragraph(voltage_artifacts, self.style['BodyText'])
+ voltage_section = Paragraph(
+ voltage_artifacts, self.style['BodyText'])
artifact_report.append(voltage_section)
# create a heatmap with the onset values of the voltage artifacts
@@ -104,9 +130,15 @@ def _create_artifact_section(self) -> Flowable:
return KeepTogether(artifact_report)
def _create_heatmap(self, onsets: List[float], range: Tuple[float, float], type: str) -> Image:
- """Create Heatmap.
+ """Create a heatmap visualization of artifact onsets.
- Creates a heatmap image with the onset values of the voltage artifacts.
+ Args:
+ onsets (List[float]): List of artifact onset times.
+ range (Tuple[float, float]): Time range for the heatmap.
+ type (str): Type of artifact being visualized.
+
+ Returns:
+ Image: A reportlab Image containing the heatmap.
"""
# create a heatmap with the onset values
fig, ax = plt.subplots()
@@ -120,19 +152,23 @@ def _create_heatmap(self, onsets: List[float], range: Tuple[float, float], type:
return heatmap
def _create_epochs_section(self) -> List[Image]:
- """Create Epochs Section.
+ """Create a section containing all signal figures.
- Creates a flowable image for each figure in the Signal Report.
+ Returns:
+ List[Image]: List of reportlab Images containing the signal figures.
"""
# create a flowable for each figure
flowables = [self.convert_figure_to_image(fig) for fig in self.figures]
return flowables
def convert_figure_to_image(self, fig: Figure) -> Image:
- """Convert Figure to Image.
+ """Convert a matplotlib figure to a reportlab Image.
+
+ Args:
+ fig (Figure): Matplotlib figure to convert.
- Converts a matplotlib figure to a reportlab Image.
- retrieved from: https://nicd.org.uk/knowledge-hub/creating-pdf-reports-with-reportlab-and-pandas
+ Returns:
+ Image: A reportlab Image containing the figure.
"""
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=300)
@@ -141,42 +177,54 @@ def convert_figure_to_image(self, fig: Figure) -> Image:
return Image(buf, x * inch, y * inch)
def _create_header(self) -> Paragraph:
- """Create Header.
+ """Create the header for the signal report section.
- Creates a header for the Signal Report section.
+ Returns:
+ Paragraph: A reportlab Paragraph containing the section header.
"""
header = Paragraph('Signal Report', self.style['Heading3'])
return header
class SessionReportSection(ReportSection):
- """Session Report Section.
+ """Section for displaying session summary information in a BciPy Report.
- A class to handle the creation of a Session Report section in a BciPy Report using a summary dictionary.
+ Attributes:
+ summary (dict): Dictionary containing session summary information.
+ session_name (str): Name of the session or default name if not specified.
+ style: Reportlab style sheet for formatting.
+ summary_table (Optional[Flowable]): The compiled summary table.
"""
- def __init__(self, summary: dict) -> None:
+ def __init__(self, summary: Dict[str, Any]) -> None:
+ """Initialize SessionReportSection.
+
+ Args:
+ summary (Dict[str, Any]): Dictionary containing session summary information.
+ """
self.summary = summary
if 'task' in self.summary:
self.session_name = self.summary['task']
else:
self.session_name = 'Session Summary'
self.style = getSampleStyleSheet()
- self.summary_table = None
+ self.summary_table: Optional[Flowable] = None
def compile(self) -> Flowable:
- """Compile.
+ """Compile the session report section into a flowable.
- Compiles the Session Report sections into a flowable that can be used to generate a Report.
+ Returns:
+ Flowable: A reportlab flowable containing the compiled section.
"""
summary_table = self._create_summary_flowable()
self.summary_table = summary_table
return summary_table
def _create_summary_flowable(self) -> Flowable:
- """Create Summary Flowable.
+ """Create a flowable containing the session summary.
- Creates a flowable table with the summary dictionary.
+ Returns:
+ Flowable: A reportlab flowable containing the formatted summary.
"""
if self.summary:
# split the summary keys and values into a list
@@ -189,10 +237,15 @@ def _create_summary_flowable(self) -> Flowable:
summary_list = self._create_summary_text(keys, values)
return KeepTogether(summary_list)
- def _create_summary_text(self, keys: list, values: list) -> List[Paragraph]:
- """Create Summary Text.
+ def _create_summary_text(self, keys: List[str], values: List[Any]) -> List[Paragraph]:
+ """Create formatted text for the summary.
+
+ Args:
+ keys (List[str]): List of summary keys.
+ values (List[Any]): List of summary values.
- Creates a list of paragraphs with the keys and values from the provided summary.
+ Returns:
+ List[Paragraph]: List of reportlab Paragraphs containing the formatted summary.
"""
# create a table with the keys and values, adding a header
table = [self._create_header()]
@@ -202,18 +255,31 @@ def _create_summary_text(self, keys: list, values: list) -> List[Paragraph]:
return table
def _create_header(self) -> Paragraph:
- """Create Header.
+ """Create the header for the session report section.
- Creates a header for the Session Report section.
+ Returns:
+ Paragraph: A reportlab Paragraph containing the section header.
"""
- header = Paragraph(f'{self.session_name}', self.style['Heading3'])
+ header = Paragraph(
+ f'{self.session_name}', self.style['Heading3'])
return header
class Report:
- """Report.
-
- A class to handle compiling and saving a BciPy Report after at least one session.
+ """Class for compiling and saving BciPy Reports.
+
+ This class handles the creation of PDF reports containing multiple sections
+ of signal analysis and session information.
+
+ Attributes:
+ DEFAULT_NAME (str): Default name for the report file.
+ sections (List[ReportSection]): List of report sections to include.
+ elements (List[Flowable]): List of reportlab flowables for the report.
+ name (str): Name of the report file.
+ path (str): Full path where the report will be saved.
+ document (SimpleDocTemplate): Reportlab document template.
+ styles: Reportlab style sheet for formatting.
+ header (Optional[Flowable]): Report header containing logo and title.
"""
DEFAULT_NAME: str = 'BciPyReport.pdf'
@@ -223,6 +289,20 @@ def __init__(self,
name: Optional[str] = None,
sections: Optional[List[ReportSection]] = None,
autocompile: bool = False):
+ """Initialize Report.
+
+ Args:
+ save_path (str): Directory where the report will be saved.
+ name (Optional[str], optional): Name of the report file. Defaults to None.
+ sections (Optional[List[ReportSection]], optional): List of report sections.
+ Defaults to None.
+ autocompile (bool, optional): Whether to compile the report immediately.
+ Defaults to False.
+
+ Raises:
+ AssertionError: If sections is not a list or contains invalid section types.
+ AssertionError: If name does not end with .pdf.
+ """
if sections:
assert isinstance(sections, list), "Sections should be a list."
assert all(isinstance(section, ReportSection)
@@ -244,17 +324,15 @@ def __init__(self,
self.compile()
def add(self, section: ReportSection) -> None:
- """Add.
+ """Add a section to the report.
- Adds a ReportSection to the Report.
+ Args:
+ section (ReportSection): The section to add to the report.
"""
self.sections.append(section)
def compile(self) -> None:
- """Compile.
-
- Compiles the Report by adding the header and all sections to the elements list.
- """
+ """Compile the report by adding the header and all sections."""
if self.header is None:
self._construct_report_header()
header_group = KeepTogether(self.header)
@@ -263,19 +341,17 @@ def compile(self) -> None:
self.elements.append(section.compile())
def save(self) -> None:
- """Save.
-
- Exports the Report to a PDF file.
- """
+ """Save the report as a PDF file."""
self.document.build(self.elements)
def _construct_report_header(self) -> None:
- """Construct Report Header.
+ """Construct the report header with logo and title.
- Constructs the header for the Report. This should be called before adding any other elements.
- The header should consist of the CAMBI logo and a report title.
+ Raises:
+ AssertionError: If called after other elements have been added.
"""
- assert len(self.elements) < 1, "The report header should be constructed before other elements"
+ assert len(
+ self.elements) < 1, "The report header should be constructed before other elements"
report_title = Paragraph('BciPy Report', self.styles['Title'])
logo = Image(BCIPY_FULL_LOGO_PATH, hAlign='LEFT', width=170, height=50)
report_title.hAlign = 'CENTER'
diff --git a/bcipy/core/session.py b/bcipy/core/session.py
index 5fcce9267..7fa121251 100644
--- a/bcipy/core/session.py
+++ b/bcipy/core/session.py
@@ -3,38 +3,66 @@
import csv
import itertools
import json
+import logging
import os
import sqlite3
from dataclasses import dataclass, fields
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterator, List, Optional, Union
import openpyxl
from openpyxl.chart import BarChart, Reference
from openpyxl.styles import PatternFill
from openpyxl.styles.borders import BORDER_THIN, Border, Side
from openpyxl.styles.colors import COLOR_INDEX
+from openpyxl.worksheet.worksheet import Worksheet
from bcipy.config import (DEFAULT_ENCODING, DEFAULT_PARAMETERS_FILENAME,
- SESSION_DATA_FILENAME, SESSION_SUMMARY_FILENAME)
+ SESSION_DATA_FILENAME, SESSION_LOG_FILENAME,
+ SESSION_SUMMARY_FILENAME)
from bcipy.io.load import load_json_parameters
from bcipy.task.data import Session
+# Configure logging
+logger = logging.getLogger(SESSION_LOG_FILENAME)
+
BLACK = COLOR_INDEX[0]
WHITE = COLOR_INDEX[1]
YELLOW = COLOR_INDEX[5]
def read_session(file_name: str = SESSION_DATA_FILENAME) -> Session:
- """Read the session data from the given file."""
+ """Read session data from a JSON file.
+
+ Args:
+ file_name (str, optional): Path to the session data file.
+ Defaults to SESSION_DATA_FILENAME.
+
+ Returns:
+ Session: A Session object containing the parsed data.
+
+ Raises:
+ FileNotFoundError: If the specified file does not exist.
+ json.JSONDecodeError: If the file contains invalid JSON.
+ """
with open(file_name, 'r', encoding=DEFAULT_ENCODING) as json_file:
return Session.from_dict(json.load(json_file))
-def session_data(data_dir: str) -> Dict:
- """Returns a dict of session data transformed to map the alphabet letter
- to the likelihood when presenting the evidence. Also removes attributes
- not useful for debugging."""
+def session_data(data_dir: str) -> Dict[str, Any]:
+ """Transform session data to map alphabet letters to likelihood values.
+
+ Args:
+ data_dir (str): Directory containing the session data and parameters.
+
+ Returns:
+ Dict[str, Any]: Dictionary containing transformed session data with:
+ - Mapped alphabet letters to likelihood values
+ - Target text from parameters
+ - Removed debugging attributes
+ Raises:
+ FileNotFoundError: If required files are not found in data_dir.
+ """
parameters = load_json_parameters(os.path.join(data_dir,
DEFAULT_PARAMETERS_FILENAME),
value_cast=True)
@@ -46,7 +74,22 @@ def session_data(data_dir: str) -> Dict:
@dataclass(frozen=True)
class EvidenceRecord:
- """Record summarizing Inquiry evidence."""
+ """Record summarizing Inquiry evidence.
+
+ Attributes:
+ series (int): Series number.
+ inquiry (int): Inquiry number within the series.
+ stim (str): Stimulus (letter or icon).
+ lm (float): Language model probability.
+ eeg (float): EEG evidence value.
+ eye (float): Eye tracking evidence value.
+ btn (float): Button press evidence value.
+ cumulative (float): Cumulative likelihood value.
+ inq_position (Optional[int]): Position in inquiry sequence.
+ is_target (int): Whether this is the target (1) or not (0).
+ presented (int): Whether this was presented (1) or not (0).
+ above_threshold (int): Whether above decision threshold (1) or not (0).
+ """
series: int
inquiry: int
stim: str
@@ -60,12 +103,25 @@ class EvidenceRecord:
presented: int
above_threshold: int
- def __iter__(self):
+ def __iter__(self) -> Iterator[Any]:
+ """Iterate over the record's field values.
+
+ Returns:
+ Iterator[Any]: Iterator over the record's field values.
+ """
return iter([getattr(self, field.name) for field in fields(self)])
def sqlite_ddl(cls: Any, table_name: str) -> str:
- """Sqlite create table statement for the given dataclass"""
+ """Generate SQLite CREATE TABLE statement for a dataclass.
+
+ Args:
+ cls (Any): Dataclass to generate DDL for.
+ table_name (str): Name of the table to create.
+
+ Returns:
+ str: SQLite CREATE TABLE statement.
+ """
conversions = {int: 'integer', str: 'text', float: 'real'}
column_defs = [
@@ -76,13 +132,31 @@ def sqlite_ddl(cls: Any, table_name: str) -> str:
def sqlite_insert(cls: Any, table_name: str) -> str:
- """sqlite INSERT statement for the given dataclass."""
+ """Generate SQLite INSERT statement for a dataclass.
+
+ Args:
+ cls (Any): Dataclass to generate INSERT for.
+ table_name (str): Name of the table to insert into.
+
+ Returns:
+ str: SQLite INSERT statement with placeholders.
+ """
placeholders = ['?' for _ in fields(cls)]
return f"INSERT INTO {table_name} VALUES ({','.join(placeholders)})"
def evidence_records(session: Session) -> List[EvidenceRecord]:
- """Summarize the session evidence data."""
+ """Generate evidence records from session data.
+
+ Args:
+ session (Session): Session data to process.
+
+ Returns:
+ List[EvidenceRecord]: List of evidence records.
+
+ Raises:
+ AssertionError: If session has no evidence or no symbol set.
+ """
assert session.has_evidence(
), "There is no evidence in the provided session"
assert session.symbol_set, "Session must define a symbol_set"
@@ -106,7 +180,8 @@ def evidence_records(session: Session) -> List[EvidenceRecord]:
eye=evidence.get('eye_evidence', {}).get(stim, ''),
btn=evidence.get('btn_evidence', {}).get(stim, ''),
cumulative=evidence['likelihood'][stim],
- inq_position=stimuli.index(stim) if stim in stimuli else None,
+ inq_position=stimuli.index(
+ stim) if stim in stimuli else None,
is_target=int(inquiry.target_letter == stim),
presented=int(stim in stimuli),
above_threshold=int(evidence['likelihood'][stim] >
@@ -114,34 +189,27 @@ def evidence_records(session: Session) -> List[EvidenceRecord]:
return records
-def session_db(session: Session, db_file: str = 'session.db'):
- """Creates a sqlite database from the given session data.
-
- Parameters
- ----------
- session - task data (evidence values, stim times, etc.)
- db_file - path of database to write; defaults to session.db
-
- Database Schema
- ---------------
- evidence:
- - trial integer (0-based)
- - inquiry integer (0-based)
- - letter text (letter or icon)
- - lm real (language model probability for the trial; same for every
- inquiry and only considered in the cumulative value during the
- first inquiry)
- - eeg real (likelihood for the given inquiry; a value of 1.0 indicates
- that the letter was not presented)
- - btn real (button press evidence)
- - eye real (eyetracker evidence)
- - cumulative real (cumulative likelihood for the trial thus far)
- - inq_position integer (inquiry position; null if not presented)
- - is_target integer (boolean; true(1) if this letter is the target)
- - presented integer (boolean; true if the letter was presented in
- this inquiry)
- - above_threshold (boolean; true if cumulative likelihood was above
- the configured threshold)
+def session_db(session: Session, db_file: str = 'session.db') -> None:
+ """Create a SQLite database from session data.
+
+ Args:
+ session (Session): Session data to store in database.
+ db_file (str, optional): Path to database file. Defaults to 'session.db'.
+
+ Database Schema:
+ evidence:
+        - series integer (0-based)
+        - inquiry integer (0-based)
+        - stim text (letter or icon)
+        - lm real (language model probability)
+        - eeg real (likelihood for the inquiry)
+        - eye real (eyetracker evidence)
+        - btn real (button press evidence)
+        - cumulative real (cumulative likelihood)
+        - inq_position integer (inquiry position)
+        - is_target integer (boolean)
+        - presented integer (boolean)
+        - above_threshold integer (boolean)
"""
# Create database
conn = sqlite3.connect(db_file)
@@ -155,8 +223,13 @@ def session_db(session: Session, db_file: str = 'session.db'):
conn.commit()
-def session_csv(session: Session, csv_file='session.csv'):
- """Create a csv file summarizing the evidence data for the given session."""
+def session_csv(session: Session, csv_file: str = 'session.csv') -> None:
+ """Create a CSV file summarizing session evidence data.
+
+ Args:
+ session (Session): Session data to summarize.
+ csv_file (str, optional): Path to CSV file. Defaults to 'session.csv'.
+ """
with open(csv_file, "w", encoding=DEFAULT_ENCODING, newline='') as output:
csv_writer = csv.writer(output, delimiter=',')
@@ -166,8 +239,20 @@ def session_csv(session: Session, csv_file='session.csv'):
csv_writer.writerow(record)
-def write_row(excel_sheet, rownum, data, background=None, border=None):
- """Helper method to write a row to an Excel spreadsheet"""
+def write_row(excel_sheet: Worksheet,
+ rownum: int,
+ data: Union[EvidenceRecord, List[Any]],
+ background: Optional[PatternFill] = None,
+ border: Optional[Border] = None) -> None:
+ """Write a row to an Excel spreadsheet.
+
+ Args:
+ excel_sheet (Worksheet): Worksheet to write to.
+ rownum (int): Row number to write to.
+ data (Union[EvidenceRecord, List[Any]]): Data to write.
+ background (Optional[PatternFill], optional): Background fill. Defaults to None.
+ border (Optional[Border], optional): Cell border. Defaults to None.
+ """
for col, val in enumerate(data, start=1):
cell = excel_sheet.cell(row=rownum, column=col)
cell.value = val
@@ -178,10 +263,15 @@ def write_row(excel_sheet, rownum, data, background=None, border=None):
def session_excel(session: Session,
- excel_file=SESSION_SUMMARY_FILENAME,
- include_charts=True):
- """Create an Excel spreadsheet summarizing the evidence data for the given session."""
-
+ excel_file: str = SESSION_SUMMARY_FILENAME,
+ include_charts: bool = True) -> None:
+ """Create an Excel spreadsheet summarizing session evidence data.
+
+ Args:
+ session (Session): Session data to summarize.
+ excel_file (str, optional): Path to Excel file. Defaults to SESSION_SUMMARY_FILENAME.
+ include_charts (bool, optional): Whether to include charts. Defaults to True.
+ """
# Define styles and borders to use within the spreadsheet.
gray_background = PatternFill(start_color='ededed', fill_type='solid')
white_background = PatternFill(start_color=WHITE, fill_type=None)
@@ -280,7 +370,7 @@ def session_excel(session: Session,
# Freeze header row
sheet.freeze_panes = 'A2'
workbook.save(excel_file)
- print("Wrote output to " + excel_file)
+ logger.info("Wrote output to %s", excel_file)
if __name__ == "__main__":
diff --git a/bcipy/core/stimuli.py b/bcipy/core/stimuli.py
index 58de25f43..b466cd1d4 100644
--- a/bcipy/core/stimuli.py
+++ b/bcipy/core/stimuli.py
@@ -35,45 +35,62 @@
class StimuliOrder(Enum):
- """Stimuli Order.
+ """Defines the ordering of stimuli for inquiry.
- Enum to define the ordering of stimuli for inquiry.
+ Attributes:
+ RANDOM (str): Random ordering of stimuli.
+ ALPHABETICAL (str): Alphabetical ordering of stimuli.
"""
RANDOM = 'random'
ALPHABETICAL = 'alphabetical'
@classmethod
- def list(cls):
- """Returns all enum values as a list"""
+ def list(cls) -> list:
+ """Returns all enum values as a list.
+
+ Returns:
+ list: List of all enum values.
+ """
return list(map(lambda c: c.value, cls))
class TargetPositions(Enum):
- """Target Positions.
+ """Defines the positions of targets within the inquiry.
- Enum to define the positions of targets within the inquiry.
+ Attributes:
+ RANDOM (str): Random positioning of targets.
+ DISTRIBUTED (str): Evenly distributed positioning of targets.
"""
RANDOM = 'random'
DISTRIBUTED = 'distributed'
@classmethod
- def list(cls):
- """Returns all enum values as a list"""
+ def list(cls) -> list:
+ """Returns all enum values as a list.
+
+ Returns:
+ list: List of all enum values.
+ """
return list(map(lambda c: c.value, cls))
class PhotoDiodeStimuli(Enum):
- """Photodiode Stimuli.
+ """Defines unicode stimuli needed for testing system timing.
- Enum to define unicode stimuli needed for testing system timing.
+ Attributes:
+ EMPTY (str): Box with a white border, no fill (□).
+ SOLID (str): Solid white box (■).
"""
-
EMPTY = '\u25A1' # box with a white border, no fill
SOLID = '\u25A0' # solid white box
@classmethod
- def list(cls):
- """Returns all enum values as a list"""
+ def list(cls) -> list:
+ """Returns all enum values as a list.
+
+ Returns:
+ list: List of all enum values.
+ """
return list(map(lambda c: c.value, cls))
@@ -81,19 +98,21 @@ class InquirySchedule(NamedTuple):
"""Schedule for the next inquiries to present, where each inquiry specifies
the stimulus, duration, and color information.
- Attributes
- ----------
- - stimuli: `List[List[str]]`
- - durations: `List[List[float]]`
- - colors: `List[List[str]]`
+ Attributes:
+ stimuli (List[Any]): List of stimuli for each inquiry.
+ durations (Union[List[List[float]], List[float]]): Duration for each stimulus.
+ colors (Union[List[List[str]], List[str]]): Color for each stimulus.
"""
stimuli: List[Any]
durations: Union[List[List[float]], List[float]]
colors: Union[List[List[str]], List[str]]
def inquiries(self) -> Iterator[Tuple]:
- """Generator that iterates through each Inquiry. Yields tuples of
- (stim, duration, color)."""
+ """Generator that iterates through each Inquiry.
+
+ Yields:
+ Tuple: Tuple of (stim, duration, color) for each inquiry.
+ """
count = len(self.stimuli)
index = 0
while index < count:
@@ -103,13 +122,25 @@ def inquiries(self) -> Iterator[Tuple]:
class Reshaper(ABC):
+ """Abstract base class for reshaping data in BCI experiments."""
@abstractmethod
def __call__(self, *args, **kwargs) -> Any:
+ """Reshape data for a specific paradigm.
+
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ Any: Reshaped data.
+ """
...
class InquiryReshaper:
+ """Reshapes EEG data, timing, and labels for inquiries in BCI experiments."""
+
def __call__(self,
trial_targetness_label: List[str],
timing_info: List[float],
@@ -125,30 +156,28 @@ def __call__(self,
"""Extract inquiry data and labels.
Args:
- trial_targetness_label (List[str]): labels each trial as "target", "non-target", "first_pres_target", etc
- timing_info (List[float]): Timestamp of each event in seconds
- eeg_data (np.ndarray): shape (channels, samples) preprocessed EEG data
- sample_rate (int): sample rate of data provided in eeg_data
- trials_per_inquiry (int): number of trials in each inquiry
+ trial_targetness_label (List[str]): Labels each trial as "target", "non-target", etc.
+ timing_info (List[float]): Timestamp of each event in seconds.
+ eeg_data (np.ndarray): Shape (channels, samples) preprocessed EEG data.
+ sample_rate (int): Sample rate of data provided in eeg_data.
+ trials_per_inquiry (int): Number of trials in each inquiry.
offset (float, optional): Any calculated or hypothesized offsets in timings. Defaults to 0.
- channel_map (List[int], optional): Describes which channels to include or discard.
- Defaults to None; all channels will be used.
- poststimulus_length (float, optional): time in seconds needed after the last trial in an inquiry.
- prestimulus_length (float, optional): time in seconds needed before the first trial in an inquiry.
- transformation_buffer (float, optional): time in seconds to buffer the end of the inquiry. Defaults to 0.0.
- target_label (str): label of target symbol. Defaults to "target"
+ channel_map (List[int], optional): Describes which channels to include or discard. Defaults to None.
+ poststimulus_length (float, optional): Time in seconds needed after the last trial. Defaults to 0.5.
+ prestimulus_length (float, optional): Time in seconds needed before the first trial. Defaults to 0.0.
+ transformation_buffer (float, optional): Time in seconds to buffer the end of the inquiry. Defaults to 0.0.
+ target_label (str, optional): Label of target symbol. Defaults to "target".
Returns:
- reshaped_data (np.ndarray): inquiry data of shape (Channels, Inquiries, Samples)
- labels (np.ndarray): integer label for each inquiry. With `trials_per_inquiry=K`,
- a label of [0, K-1] indicates the position of `target_label`, or label of [0 ... 0] indicates
- `target_label` was not present.
- reshaped_trigger_timing (List[List[int]]): For each inquiry, a list of the sample index where each trial
- begins, accounting for the prestim buffer that may have been added to the front of each inquiry.
+ Tuple[np.ndarray, np.ndarray, List[List[float]]]:
+ - reshaped_data: Inquiry data of shape (Channels, Inquiries, Samples)
+ - labels: Integer label for each inquiry
+ - reshaped_trigger_timing: For each inquiry, a list of the sample index where each trial begins
"""
if channel_map:
# Remove the channels that we are not interested in
- channels_to_remove = [idx for idx, value in enumerate(channel_map) if value == 0]
+ channels_to_remove = [idx for idx,
+ value in enumerate(channel_map) if value == 0]
eeg_data = np.delete(eeg_data, channels_to_remove, axis=0)
n_inquiry = len(timing_info) // trials_per_inquiry
@@ -156,15 +185,18 @@ def __call__(self,
prestimulus_samples = int(prestimulus_length * sample_rate)
# triggers in seconds are mapped to triggers in number of samples.
- triggers = list(map(lambda x: int((x + offset) * sample_rate), timing_info))
+ triggers = list(
+ map(lambda x: int((x + offset) * sample_rate), timing_info))
# First, find the longest inquiry in this experiment
# We'll add or remove a few samples from all other inquiries, to match this length
def get_inquiry_len(inq_trigs):
return inq_trigs[-1] - inq_trigs[0]
- longest_inquiry = max(grouper(triggers, trials_per_inquiry, fillvalue='x'), key=lambda xy: get_inquiry_len(xy))
- num_samples_per_inq = get_inquiry_len(longest_inquiry) + trial_duration_samples
+ longest_inquiry = max(grouper(
+ triggers, trials_per_inquiry, fillvalue='x'), key=lambda xy: get_inquiry_len(xy))
+ num_samples_per_inq = get_inquiry_len(
+ longest_inquiry) + trial_duration_samples
buffer_samples = int(transformation_buffer * sample_rate)
# Label for every inquiry
@@ -173,7 +205,8 @@ def get_inquiry_len(inq_trigs):
) # maybe this can be configurable? return either class indexes or labels ('nontarget' etc)
reshaped_data, reshaped_trigger_timing = [], []
for inquiry_idx, trials_within_inquiry in enumerate(
- grouper(zip(trial_targetness_label, triggers), trials_per_inquiry, fillvalue='x')
+ grouper(zip(trial_targetness_label, triggers),
+ trials_per_inquiry, fillvalue='x')
):
first_trigger = trials_within_inquiry[0][1]
@@ -184,7 +217,8 @@ def get_inquiry_len(inq_trigs):
# If prestimulus buffer is used, we add it here so that trigger timings will
# still line up with trial onset
- trial_triggers.append((trigger - first_trigger) + prestimulus_samples)
+ trial_triggers.append(
+ (trigger - first_trigger) + prestimulus_samples)
reshaped_trigger_timing.append(trial_triggers)
start = first_trigger - prestimulus_samples
stop = first_trigger + num_samples_per_inq + buffer_samples
@@ -198,29 +232,19 @@ def extract_trials(
samples_per_trial: int,
inquiry_timing: List[List[int]],
prestimulus_samples: int = 0) -> np.ndarray:
- """Extract Trials.
-
- After using the InquiryReshaper, it may be necessary to further extract the trials for processing.
- Using the number of samples and inquiry timing, the data is reshaped from Channels, Inquiry, Samples to
- Channels, Trials, Samples. These should match with the trials extracted from the TrialReshaper given the same
- slicing parameters.
-
- Parameters
- ----------
- inquiries : np.ndarray
- shape (Channels, Inquiries, Samples)
- samples_per_trial : int
- number of samples per trial
- inquiry_timing : List[List[int]]
- For each inquiry, a list of the sample index where each trial begins
- prestimulus_samples : int, optional
- Number of samples to move the start of each trial in each inquiry, by default 0.
- This is useful if wanting to use baseline intervals before the trial onset, along with the trial data.
-
- Returns
- -------
- np.ndarray
- shape (Channels, Trials, Samples)
+ """Extract trials from inquiry data.
+
+ Args:
+ inquiries (np.ndarray): Shape (Channels, Inquiries, Samples).
+ samples_per_trial (int): Number of samples per trial.
+ inquiry_timing (List[List[int]]): For each inquiry, a list of the sample index where each trial begins.
+ prestimulus_samples (int, optional): Number of samples to move the start of each trial. Defaults to 0.
+
+ Returns:
+ np.ndarray: Shape (Channels, Trials, Samples).
+
+ Raises:
+ BciPyCoreException: If index is out of bounds when extracting trials.
"""
new_trials = []
num_inquiries = inquiries.shape[1]
@@ -242,6 +266,8 @@ def extract_trials(
class GazeReshaper:
+ """Reshapes gaze trajectory data and labels for inquiries in BCI experiments."""
+
def __call__(self,
inq_start_times: List[float],
target_symbols: List[str],
@@ -254,36 +280,21 @@ def __call__(self,
) -> Tuple[dict, list, List[str]]:
"""Extract gaze trajectory data and labels.
- Different from the EEG, gaze inquiry windows start with the first highlighted symbol and end with the
- last highlighted symbol in the inquiry. Each inquiry has a length of (trial duration x num of trials)
- seconds. Labels are provided in 'target_symbols'. It returns a Dict, where keys are the target symbols
- and the values are inquiries (appended in order of appearance) where the corresponding target symbol is
- prompted.
-
- Optional outputs:
- reshape_data is the list of data reshaped into (Inquiries, Channels, Samples), where inquirires are
- appended in chronological order.
- labels returns the list of target symbols in each inquiry.
-
- Parameters
- ----------
- inq_start_times (List[float]): Timestamp of each event in seconds
- target_symbols (List[str]): Prompted symbol in each inquiry
- gaze_data (np.ndarray): shape (channels, samples) eye tracking data
- sample_rate (int): sample rate of eye tracker data
- stimulus_duration (float): duration of flash time (in seconds) for each trial
- num_stimuli_per_inquiry (int): number of stimuli in each inquiry (default: 10)
- symbol_set (List[str]): list of all symbols for the task
- channel_map (List[int], optional): Describes which channels to include or discard.
- Defaults to None; all channels will be used.
-
- Returns
- -------
- data_by_targets (dict): Dictionary where keys consist of the symbol set, and values
- the appended inquiries for each symbol. dict[Key] = (np.ndarray) of shape (Channels, Samples)
-
- reshaped_data (List[float]) [optional]: inquiry data of shape (Inquiries, Channels, Samples)
- labels (List[str]) [optional] : Target symbol in each inquiry.
+ Args:
+ inq_start_times (List[float]): Timestamp of each event in seconds.
+ target_symbols (List[str]): Prompted symbol in each inquiry.
+ gaze_data (np.ndarray): Shape (channels, samples) eye tracking data.
+ sample_rate (int): Sample rate of eye tracker data.
+ stimulus_duration (float): Duration of flash time (in seconds) for each trial.
+ num_stimuli_per_inquiry (int): Number of stimuli in each inquiry.
+ symbol_set (List[str], optional): List of all symbols for the task. Defaults to alphabet().
+ channel_map (List[int], optional): Describes which channels to include or discard. Defaults to None.
+
+ Returns:
+ Tuple[dict, list, List[str]]:
+ - data_by_targets: Dictionary where keys are the symbol set, and values are the appended inquiries for each symbol.
+ - reshaped_data: Inquiry data of shape (Inquiries, Channels, Samples).
+ - labels: Target symbol in each inquiry.
"""
# Find the timestamp value closest to (& greater than) inq_start_times.
# Lsl timestamps are the last row in the gaze_data
@@ -313,7 +324,8 @@ def __call__(self,
# A better way of handling this buffer would be subtracting the flash time of the
# second symbol from the first symbol, which gives a more accurate representation of
# "stimulus duration".
- window_length = (stimulus_duration + buffer) * num_stimuli_per_inquiry # in seconds
+ window_length = (stimulus_duration + buffer) * \
+ num_stimuli_per_inquiry # in seconds
reshaped_data = []
# Merge the inquiries if they have the same target letter:
@@ -332,13 +344,14 @@ def __call__(self,
return data_by_targets_dict, reshaped_data, labels
def centralize_all_data(self, data: np.ndarray, symbol_pos: np.ndarray) -> np.ndarray:
- """ Using the symbol locations in matrix, centralize all data (in Tobii units).
- This data will only be used in certain model types.
+ """Centralize all data using symbol locations in matrix (Tobii units).
+
Args:
- data (np.ndarray): Data in shape of num_samples x num_dimensions
- symbol_pos (np.ndarray(float)): Array of the current symbol posiiton in Tobii units
+ data (np.ndarray): Data in shape of num_samples x num_dimensions.
+ symbol_pos (np.ndarray): Array of the current symbol position in Tobii units.
+
Returns:
- new_data (np.ndarray): Centralized data in shape of num_samples x num_dimensions
+ np.ndarray: Centralized data in shape of num_samples x num_dimensions.
"""
new_data = np.copy(data)
for i in range(len(data)):
@@ -348,6 +361,8 @@ def centralize_all_data(self, data: np.ndarray, symbol_pos: np.ndarray) -> np.nd
class TrialReshaper(Reshaper):
+ """Reshapes EEG data, timing, and labels for individual trials in BCI experiments."""
+
def __call__(self,
trial_targetness_label: list,
timing_info: list,
@@ -360,28 +375,26 @@ def __call__(self,
target_label: str = "target") -> Tuple[np.ndarray, np.ndarray]:
"""Extract trial data and labels.
- Parameters
- ----------
- trial_targetness_label (list): labels each trial as "target", "non-target", "first_pres_target", etc
- timing_info (list): Timestamp of each event in seconds
- eeg_data (np.ndarray): shape (channels, samples) preprocessed EEG data
- sample_rate (int): sample rate of preprocessed EEG data
- trials_per_inquiry (int, optional): unused, kept here for consistent interface with `inquiry_reshaper`
- offset (float, optional): Any calculated or hypothesized offsets in timings.
- Defaults to 0.
- channel_map (List, optional): Describes which channels to include or discard.
- Defaults to None; all channels will be used.
- poststimulus_length (float, optional): [description]. Defaults to 0.5.
- target_label (str): label of target symbol. Defaults to "target"
-
- Returns
- -------
- trial_data (np.ndarray): shape (channels, trials, samples) reshaped data
- labels (np.ndarray): integer label for each trial
+ Args:
+ trial_targetness_label (list): Labels each trial as "target", "non-target", etc.
+ timing_info (list): Timestamp of each event in seconds.
+ eeg_data (np.ndarray): Shape (channels, samples) preprocessed EEG data.
+ sample_rate (int): Sample rate of preprocessed EEG data.
+ offset (float, optional): Any calculated or hypothesized offsets in timings. Defaults to 0.
+ channel_map (List, optional): Describes which channels to include or discard. Defaults to None.
+ poststimulus_length (float, optional): Time in seconds needed after the last trial. Defaults to 0.5.
+ prestimulus_length (float, optional): Time in seconds needed before the first trial. Defaults to 0.0.
+ target_label (str, optional): Label of target symbol. Defaults to "target".
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray]:
+ - trial_data: Reshaped data of shape (channels, trials, samples)
+ - labels: Integer label for each trial
"""
# Remove the channels that we are not interested in
if channel_map:
- channels_to_remove = [idx for idx, value in enumerate(channel_map) if value == 0]
+ channels_to_remove = [idx for idx,
+ value in enumerate(channel_map) if value == 0]
eeg_data = np.delete(eeg_data, channels_to_remove, axis=0)
# Number of samples we are interested per trial
@@ -389,7 +402,8 @@ def __call__(self,
prestim_samples = int(prestimulus_length * sample_rate)
# triggers in seconds are mapped to triggers in number of samples.
- triggers = list(map(lambda x: int((x + offset) * sample_rate), timing_info))
+ triggers = list(
+ map(lambda x: int((x + offset) * sample_rate), timing_info))
# Label for every trial in 0 or 1
targetness_labels = np.zeros(len(triggers), dtype=np.longlong)
@@ -399,13 +413,22 @@ def __call__(self,
targetness_labels[trial_idx] = 1
# For every channel append filtered channel data to trials
- reshaped_trials.append(eeg_data[:, trigger - prestim_samples: trigger + poststim_samples])
+ reshaped_trials.append(
+ eeg_data[:, trigger - prestim_samples: trigger + poststim_samples])
return np.stack(reshaped_trials, 1), targetness_labels
def update_inquiry_timing(timing: List[List[int]], downsample: int) -> List[List[int]]:
- """Update inquiry timing to reflect downsampling."""
+ """Update inquiry timing to reflect downsampling.
+
+ Args:
+ timing (List[List[int]]): Original timing values for each inquiry.
+ downsample (int): Downsampling factor.
+
+ Returns:
+ List[List[int]]: Updated timing values for each inquiry.
+ """
for i, inquiry in enumerate(timing):
for j, time in enumerate(inquiry):
timing[i][j] = int(time // downsample)
@@ -420,14 +443,24 @@ def mne_epochs(mne_data: RawArray,
baseline: Optional[Tuple[Any, float]] = None,
reject_by_annotation: bool = False,
preload: bool = False) -> Epochs:
- """MNE Epochs.
+ """Create MNE Epochs from a RawArray and trigger information.
- Using an MNE RawArray, reshape the data given trigger information. If two labels present [0, 1],
- each may be accessed by numbered order. Ex. first_class = epochs['1'], second_class = epochs['2']
+ Args:
+ mne_data (RawArray): MNE RawArray object.
+ trial_length (float): Length of each trial in seconds.
+ trigger_timing (Optional[List[float]], optional): List of trigger times. Defaults to None.
+ trigger_labels (Optional[List[int]], optional): List of trigger labels. Defaults to None.
+ baseline (Optional[Tuple[Any, float]], optional): Baseline interval. Defaults to None.
+ reject_by_annotation (bool, optional): Whether to reject epochs by annotation. Defaults to False.
+ preload (bool, optional): Whether to preload the data. Defaults to False.
+
+ Returns:
+ Epochs: MNE Epochs object.
"""
old_annotations = mne_data.annotations
if trigger_timing and trigger_labels:
- new_annotations = Annotations(trigger_timing, [trial_length] * len(trigger_timing), trigger_labels)
+ new_annotations = Annotations(
+ trigger_timing, [trial_length] * len(trigger_timing), trigger_labels)
all_annotations = new_annotations + old_annotations
else:
all_annotations = old_annotations
@@ -449,15 +482,20 @@ def mne_epochs(mne_data: RawArray,
baseline=baseline,
tmax=trial_length,
tmin=tmin,
- proj=False, # apply SSP projection to data. Defaults to True in Epochs.
+ # apply SSP projection to data. Defaults to True in Epochs.
+ proj=False,
reject_by_annotation=reject_by_annotation,
preload=preload)
def alphabetize(stimuli: List[str]) -> List[str]:
- """Alphabetize.
+    """Return the list of stimuli sorted alphabetically.
- Given a list of string stimuli, return a list of sorted stimuli by alphabet.
+ Args:
+ stimuli (List[str]): List of string stimuli.
+
+ Returns:
+ List[str]: Alphabetically sorted list of stimuli.
"""
return sorted(stimuli, key=lambda x: re.sub(r'[^a-zA-Z0-9 \n\.]', 'ZZ', x).lower())
@@ -469,23 +507,20 @@ def inq_generator(query: List[str],
stim_jitter: float = 0,
stim_order: StimuliOrder = StimuliOrder.RANDOM,
is_txt: bool = True) -> InquirySchedule:
- """Given the query set, prepares the stimuli, color and timing
-
- Parameters
- ----------
- query(list[str]): list of queries to be shown
- timing(list[float]): Task specific timing for generator
- color(list[str]): Task specific color for generator
- First element is the target, second element is the fixation
- Observe that [-1] element represents the trial information
- Return
- ------
- schedule_inq(tuple(
- samples[list[list[str]]]: list of inquiries
- timing(list[list[float]]): list of timings
- color(list(list[str])): list of colors)): scheduled inquiries
- """
+ """Prepare the stimuli, color, and timing for a set of inquiries.
+ Args:
+ query (List[str]): List of queries to be shown.
+ timing (List[float], optional): Task specific timing for generator. Defaults to [1, 0.2].
+ color (List[str], optional): Task specific color for generator. Defaults to ['red', 'white'].
+ inquiry_count (int, optional): Number of inquiries to generate. Defaults to 1.
+ stim_jitter (float, optional): Jitter to apply to stimulus timing. Defaults to 0.
+ stim_order (StimuliOrder, optional): Ordering of stimuli. Defaults to StimuliOrder.RANDOM.
+ is_txt (bool, optional): Whether the stimuli are text. Defaults to True.
+
+ Returns:
+ InquirySchedule: Scheduled inquiries with samples, timing, and color.
+ """
if stim_order == StimuliOrder.ALPHABETICAL:
query = alphabetize(query)
else:
@@ -496,18 +531,14 @@ def inq_generator(query: List[str],
# Init some lists to construct our stimuli with
samples, times, colors = [], [], []
for _ in range(inquiry_count):
-
# append a fixation cross. if not text, append path to image fixation
sample = [get_fixation(is_txt)]
-
# construct the sample from the query
sample += [i for i in query]
samples.append(sample)
-
times.append([timing[i] for i in range(len(timing) - 1)])
base_timing = timing[-1]
times[-1] += jittered_timing(base_timing, stim_jitter, stim_length)
-
# append colors
colors.append([color[i] for i in range(len(color) - 1)] +
[color[-1]] * stim_length)
@@ -518,33 +549,26 @@ def best_selection(selection_elements: list,
val: list,
len_query: int,
always_included: Optional[List[str]] = None) -> list:
- """Best Selection.
-
- Given set of elements and a value function over the set, picks the len_query
- number of elements with the best value.
-
- Parameters
- ----------
- selection_elements(list[str]): the set of elements
- val(list[float]): values for the corresponding elements
- len_query(int): number of elements to be picked from the set
- always_included(list[str]): subset of elements that should always be
- included in the result. Defaults to None.
- Return
- ------
- best_selection(list[str]): elements from selection_elements with the best values
- """
+ """Pick the len_query number of elements with the best value.
+ Args:
+ selection_elements (list): The set of elements.
+ val (list): Values for the corresponding elements.
+ len_query (int): Number of elements to be picked from the set.
+ always_included (Optional[List[str]], optional): Subset of elements that should always be included. Defaults to None.
+
+ Returns:
+ list: Elements from selection_elements with the best values.
+ """
always_included = always_included or []
# pick the top n items sorted by value in decreasing order
elem_val = dict(zip(selection_elements, val))
- best = sorted(selection_elements, key=elem_val.get, reverse=True)[0:len_query]
-
+ best = sorted(selection_elements, key=elem_val.get,
+ reverse=True)[0:len_query]
replacements = [
item for item in always_included
if item not in best and item in selection_elements
][0:len_query]
-
if replacements:
best[-len(replacements):] = replacements
return best
@@ -559,67 +583,49 @@ def best_case_rsvp_inq_gen(alp: list,
stim_order: StimuliOrder = StimuliOrder.RANDOM,
is_txt: bool = True,
inq_constants: Optional[List[str]] = None) -> InquirySchedule:
- """Best Case RSVP Inquiry Generation.
-
- Generates RSVPKeyboard inquiry by picking n-most likely letters.
-
- Parameters
- ----------
- alp(list[str]): alphabet (can be arbitrary)
- session_stimuli(ndarray[float]): quantifier metric for query selection
- dim(session_stimuli) = card(alp)!
- timing(list[float]): Task specific timing for generator
- color(list[str]): Task specific color for generator
- First element is the target, second element is the fixation
- Observe that [-1] element represents the trial information
- inquiry_count(int): number of random stimuli to be created
- stim_per_inquiry(int): number of trials in a inquiry
- stim_order(StimuliOrder): ordering of stimuli in the inquiry
- inq_constants(list[str]): list of letters that should always be
- included in every inquiry. If provided, must be alp items.
- Return
- ------
- schedule_inq(tuple(
- samples[list[list[str]]]: list of inquiries
- timing(list[list[float]]): list of timings
- color(list(list[str])): list of colors)): scheduled inquiries
- """
+ """Generate RSVPKeyboard inquiry by picking n-most likely letters.
+
+ Args:
+ alp (list): Alphabet (can be arbitrary).
+ session_stimuli (np.ndarray): Quantifier metric for query selection.
+ timing (List[float], optional): Task specific timing for generator. Defaults to [1, 0.2].
+ color (List[str], optional): Task specific color for generator. Defaults to ['red', 'white'].
+ stim_number (int, optional): Number of random stimuli to be created. Defaults to 1.
+ stim_length (int, optional): Number of trials in an inquiry. Defaults to 10.
+ stim_order (StimuliOrder, optional): Ordering of stimuli. Defaults to StimuliOrder.RANDOM.
+ is_txt (bool, optional): Whether the stimuli are text. Defaults to True.
+ inq_constants (Optional[List[str]], optional): Letters that should always be included. Defaults to None.
+ Returns:
+ InquirySchedule: Scheduled inquiries with samples, timing, and color.
+ """
if len(alp) != len(session_stimuli):
raise BciPyCoreException((
f'Missing information about alphabet.'
f'len(alp):{len(alp)} and len(session_stimuli):{len(session_stimuli)} should be same!'))
-
if inq_constants and not set(inq_constants).issubset(alp):
raise BciPyCoreException('Inquiry constants must be alphabet items.')
-
# query for the best selection
query = best_selection(
alp,
session_stimuli,
stim_length,
inq_constants)
-
if stim_order == StimuliOrder.ALPHABETICAL:
query = alphabetize(query)
else:
random.shuffle(query)
-
# Init some lists to construct our stimuli with
samples, times, colors = [], [], []
for _ in range(stim_number):
-
# append a fixation cross. if not text, append path to image fixation
sample = [get_fixation(is_txt)]
-
# construct the sample from the query
sample += [i for i in query]
samples.append(sample)
-
# append timing
times.append([timing[i] for i in range(len(timing) - 1)] +
[timing[-1]] * stim_length)
-
# append colors
colors.append([color[i] for i in range(len(color) - 1)] +
[color[-1]] * stim_length)
@@ -637,37 +643,22 @@ def generate_calibration_inquiries(
target_positions: TargetPositions = TargetPositions.RANDOM,
percentage_without_target: int = 0,
is_txt: bool = True) -> InquirySchedule:
- """
- Generates inquiries with target letters in all possible positions.
-
- This function attempts to display all symbols as targets an equal number of
- times when stim_order is RANDOM. When the stim_order is ALPHABETICAL there
- is much more variation in the counts (target distribution takes priority)
- and some symbols may not appear as targets depending on the inquiry_count.
- The frequency that each symbol is displayed as a nontarget should follow a
- uniform distribution.
-
- Parameters
- ----------
- alp(list[str]): stimuli
- timing(list[float]): Task specific timing for generator.
- [target, fixation, stimuli]
- jitter(int): jitter for stimuli timing. If None, no jitter is applied.
- color(list[str]): Task specific color for generator
- [target, fixation, stimuli]
- inquiry_count(int): number of inquiries in a calibration
- stim_per_inquiry(int): number of stimuli in each inquiry
- stim_order(StimuliOrder): ordering of stimuli in the inquiry
- target_positions(TargetPositions): positioning of targets to select for the inquiries
- percentage_without_target(int): percentage of inquiries for which target letter flashed is not in inquiry
- is_txt(bool): whether the stimuli type is text. False would be an image stimuli.
-
- Return
- ------
- schedule_inq(tuple(
- samples[list[list[str]]]: list of inquiries
- timing(list[list[float]]): list of timings
- color(list(list[str])): list of colors)): scheduled inquiries
+ """Generate inquiries with target letters in all possible positions.
+
+ Args:
+ alp (List[str]): Stimuli.
+ timing (Optional[List[float]], optional): Task specific timing for generator. Defaults to None.
+ jitter (Optional[int], optional): Jitter for stimuli timing. Defaults to None.
+ color (Optional[List[str]], optional): Task specific color for generator. Defaults to None.
+ inquiry_count (int, optional): Number of inquiries in a calibration. Defaults to 100.
+ stim_per_inquiry (int, optional): Number of stimuli in each inquiry. Defaults to 10.
+ stim_order (StimuliOrder, optional): Ordering of stimuli. Defaults to StimuliOrder.RANDOM.
+ target_positions (TargetPositions, optional): Positioning of targets. Defaults to TargetPositions.RANDOM.
+ percentage_without_target (int, optional): Percentage of inquiries without a target. Defaults to 0.
+ is_txt (bool, optional): Whether the stimuli type is text. Defaults to True.
+
+ Returns:
+ InquirySchedule: Scheduled inquiries with samples, timing, and color.
"""
if timing is None:
timing = [0.5, 1, 0.2]
@@ -678,7 +669,6 @@ def generate_calibration_inquiries(
) == 3, "timing must include values for [target, fixation, stimuli]"
time_target, time_fixation, time_stim = timing
fixation = get_fixation(is_txt)
-
target_indexes = generate_target_positions(inquiry_count, stim_per_inquiry,
percentage_without_target,
target_positions)
@@ -700,15 +690,12 @@ def generate_calibration_inquiries(
next_targets=targets,
last_target=target)
samples.append([target, fixation, *inquiry])
-
times = [[
time_target, time_fixation,
*generate_inquiry_stim_timing(time_stim, stim_per_inquiry, jitter)
] for _ in range(inquiry_count)]
-
inquiry_colors = color[0:2] + [color[-1]] * stim_per_inquiry
colors = [inquiry_colors for _ in range(inquiry_count)]
-
return InquirySchedule(samples, times, colors)
@@ -717,9 +704,11 @@ def inquiry_target_counts(inquiries: List[List[str]],
"""Count the number of times each symbol was presented as a target.
Args:
- inquiries - list of inquiries where each inquiry is structured as
- [target, fixation, *stim]
- symbols - list of all possible symbols
+ inquiries (List[List[str]]): List of inquiries where each inquiry is structured as [target, fixation, *stim].
+ symbols (List[str]): List of all possible symbols.
+
+ Returns:
+ dict: Dictionary mapping each symbol to its target count.
"""
target_presentations = [inq[0] for inq in inquiries if inq[0] in inq[2:]]
counter = dict.fromkeys(symbols, 0)
@@ -730,7 +719,15 @@ def inquiry_target_counts(inquiries: List[List[str]],
def inquiry_nontarget_counts(inquiries: List[List[str]],
symbols: List[str]) -> dict:
- """Count the number of times each symbol was presented as a nontarget."""
+ """Count the number of times each symbol was presented as a nontarget.
+
+ Args:
+ inquiries (List[List[str]]): List of inquiries.
+ symbols (List[str]): List of all possible symbols.
+
+ Returns:
+ dict: Dictionary mapping each symbol to its nontarget count.
+ """
counter = dict.fromkeys(symbols, 0)
for inq in inquiries:
target, _fixation, *stimuli = inq
@@ -742,16 +739,20 @@ def inquiry_nontarget_counts(inquiries: List[List[str]],
def inquiry_stats(inquiries: List[List[str]],
symbols: List[str]) -> Dict[str, Dict[str, float]]:
- """Descriptive stats for the number of times each target and nontarget
- symbol is shown in the inquiries"""
+ """Descriptive stats for the number of times each target and nontarget symbol is shown in the inquiries.
+
+ Args:
+ inquiries (List[List[str]]): List of inquiries.
+ symbols (List[str]): List of all possible symbols.
+ Returns:
+ Dict[str, Dict[str, float]]: Dictionary with stats for target and nontarget symbols.
+ """
target_stats = dict(
Series(Counter(inquiry_target_counts(inquiries, symbols))).describe())
-
nontarget_stats = dict(
Series(Counter(inquiry_nontarget_counts(inquiries,
symbols))).describe())
-
return {
'target_symbols': target_stats,
'nontarget_symbols': nontarget_stats
@@ -762,15 +763,15 @@ def generate_inquiries(symbols: List[str], inquiry_count: int,
stim_per_inquiry: int,
stim_order: StimuliOrder) -> List[List[str]]:
"""Generate a list of inquiries. For each inquiry no symbols are repeated.
- Inquiries do not include the target or fixation. Symbols should be
- distributed uniformly across inquiries.
Args:
+ symbols (List[str]): Values from which to select.
+ inquiry_count (int): Total number of inquiries to generate.
+ stim_per_inquiry (int): Length of each inquiry.
+ stim_order (StimuliOrder): Ordering of results.
- symbols - values from which to select
- inquiry_count - total number of inquiries to generate
- stim_per_inquiry - length of each inquiry
- stim_order - ordering of results
+ Returns:
+ List[List[str]]: List of generated inquiries.
"""
return [
generate_inquiry(symbols=symbols,
@@ -781,13 +782,15 @@ def generate_inquiries(symbols: List[str], inquiry_count: int,
def generate_inquiry(symbols: List[str], length: int,
stim_order: StimuliOrder) -> List[str]:
- """Generate an inquiry from the list of symbols. No symbols are repeated
- in the output list. Output does not include the target or fixation.
+ """Generate an inquiry from the list of symbols. No symbols are repeated in the output list.
Args:
- symbols - values from which to select
- length - number of items in the return list
- stim_order - ordering of results
+ symbols (List[str]): Values from which to select.
+ length (int): Number of items in the return list.
+ stim_order (StimuliOrder): Ordering of results.
+
+ Returns:
+ List[str]: Generated inquiry.
"""
inquiry = random.sample(symbols, k=length)
if stim_order == StimuliOrder.ALPHABETICAL:
@@ -800,20 +803,17 @@ def inquiry_target(inquiry: List[str],
symbols: List[str],
next_targets: Optional[List[str]] = None,
last_target: Optional[str] = None) -> str:
- """Returns the target for the given inquiry. If the optional
- target position is not provided a target will randomly be selected from
- the list of symbols and will not be in the inquiry.
+ """Returns the target for the given inquiry.
Args:
- inquiry - list of symbols to be presented
- target_position - optional position within the list of the target sym
- symbols - used if position is not provided to select a random symbol
- as the target.
- next_targets - list of targets from which to select
- last_target - target from the previous inquiry; used to avoid selecting
- the same target consecutively.
-
- Returns target symbol
+ inquiry (List[str]): List of symbols to be presented.
+ target_position (Optional[int]): Optional position within the list of the target symbol.
+ symbols (List[str]): Used if position is not provided to select a random symbol as the target.
+ next_targets (Optional[List[str]], optional): List of targets from which to select. Defaults to None.
+ last_target (Optional[str], optional): Target from the previous inquiry. Defaults to None.
+
+ Returns:
+ str: Target symbol.
"""
if target_position is None:
return random.choice(list(set(symbols) - set(inquiry)))
@@ -854,11 +854,12 @@ def generate_inquiry_stim_timing(time_stim: float, length: int,
"""Generate stimuli timing values for a given inquiry.
Args:
- time_stim: seconds to display each stimuli
- length: Number of timings to generate
- jitter: whether the timing should be jittered.
+ time_stim (float): Seconds to display each stimulus.
+ length (int): Number of timings to generate.
+ jitter (bool): Whether the timing should be jittered.
- Returns list of times (in seconds)
+ Returns:
+ List[float]: List of times (in seconds).
"""
if jitter:
return jittered_timing(time_stim, jitter, length)
@@ -868,10 +869,15 @@ def generate_inquiry_stim_timing(time_stim: float, length: int,
def jittered_timing(time: float, jitter: float,
stim_count: int) -> List[float]:
- """Jittered timing.
+ """Generate a list of jittered timing values for stimuli.
+
+ Args:
+ time (float): Base time for each stimulus.
+ jitter (float): Jitter to apply.
+ stim_count (int): Number of stimuli.
- Using a base time and a jitter, generate a list (with length stim_count) of
- timing that is uniformly distributed.
+ Returns:
+ List[float]: List of jittered timing values.
"""
assert time > jitter, (
f'Jitter timing [{jitter}] must be less than stimuli timing =[{time}] in the inquiry.'
@@ -883,15 +889,14 @@ def jittered_timing(time: float, jitter: float,
def compute_counts(inquiry_count: int,
percentage_without_target: int) -> Tuple[int, int]:
- """Determine the number of inquiries that should display targets and the
- number that should not.
+ """Determine the number of inquiries that should display targets and the number that should not.
Args:
- inquiry_count: Number of inquiries in calibration
- percentage_without_target: percentage of inquiries for which
- target letter flashed is not in inquiry
+ inquiry_count (int): Number of inquiries in calibration.
+ percentage_without_target (int): Percentage of inquiries without a target.
- Returns tuple of (target_count, no_target_count)
+ Returns:
+ Tuple[int, int]: Tuple of (target_count, no_target_count).
"""
no_target_count = int(inquiry_count * (percentage_without_target / 100))
target_count = inquiry_count - no_target_count
@@ -901,17 +906,16 @@ def compute_counts(inquiry_count: int,
def generate_target_positions(inquiry_count: int, stim_per_inquiry: int,
percentage_without_target: int,
distribution: TargetPositions) -> List[int]:
- """
- Generates target positions distributed according to the provided parameter.
+ """Generate target positions distributed according to the provided parameter.
Args:
- inquiry_count: Number of inquiries in calibration
- stim_per_inquiry: Number of stimuli in each inquiry
- percentage_without_target: percentage of inquiries for which
- target letter flashed is not in inquiry
- distribution: specifies how targets should be distributed
+ inquiry_count (int): Number of inquiries in calibration.
+ stim_per_inquiry (int): Number of stimuli in each inquiry.
+ percentage_without_target (int): Percentage of inquiries without a target.
+ distribution (TargetPositions): Specifies how targets should be distributed.
- Returns list of indexes
+ Returns:
+ List[int]: List of indexes for target positions.
"""
if distribution is TargetPositions.DISTRIBUTED:
return distributed_target_positions(inquiry_count, stim_per_inquiry,
@@ -922,65 +926,51 @@ def generate_target_positions(inquiry_count: int, stim_per_inquiry: int,
def distributed_target_positions(inquiry_count: int, stim_per_inquiry: int,
percentage_without_target: int) -> list:
- """Distributed Target Positions.
-
- Generates evenly distributed target positions, including target letter
- not flashed at all, and shuffles them.
+ """Generate evenly distributed target positions, including target letter not flashed at all, and shuffle them.
Args:
- inquiry_count(int): Number of inquiries in calibration
- stim_per_inquiry(int): Number of stimuli in each inquiry
- percentage_without_target(int): percentage of inquiries for which
- target letter flashed is not in inquiry
+ inquiry_count (int): Number of inquiries in calibration.
+ stim_per_inquiry (int): Number of stimuli in each inquiry.
+ percentage_without_target (int): Percentage of inquiries without a target.
- Return distributed_target_positions(list): targets: array of target
- indexes to be chosen
+ Returns:
+        list: Array of target indexes to be chosen.
"""
-
targets = []
-
# find number of target and no_target inquiries
target_count, no_target_count = compute_counts(inquiry_count,
percentage_without_target)
-
# find number each target position is repeated, and remaining number
num_pos = (int)(target_count / stim_per_inquiry)
num_rem_pos = (target_count % stim_per_inquiry)
-
# add correct number of None's for nontarget inquiries
targets = [NO_TARGET_INDEX] * no_target_count
-
# add distributed list of target positions
targets.extend(list(range(stim_per_inquiry)) * num_pos)
-
# pick leftover positions randomly
rem_pos = list(range(stim_per_inquiry))
random.shuffle(rem_pos)
rem_pos = rem_pos[0:num_rem_pos]
targets.extend(rem_pos)
-
# shuffle targets
random.shuffle(targets)
-
return targets
def random_target_positions(inquiry_count: int, stim_per_inquiry: int,
percentage_without_target: int) -> list:
- """Generates randomly distributed target positions, including target letter
- not flashed at all, and shuffles them.
+ """Generate randomly distributed target positions, including target letter not flashed at all, and shuffle them.
Args:
- inquiry_count(int): Number of inquiries in calibration
- stim_per_inquiry(int): Number of stimuli in each inquiry
- percentage_without_target(int): percentage of inquiries for which
- target letter flashed is not in inquiry
+ inquiry_count (int): Number of inquiries in calibration.
+ stim_per_inquiry (int): Number of stimuli in each inquiry.
+ percentage_without_target (int): Percentage of inquiries without a target.
- Return list of target indexes to be chosen
+ Returns:
+ list: List of target indexes to be chosen.
"""
target_count, no_target_count = compute_counts(inquiry_count,
percentage_without_target)
-
target_indexes = [NO_TARGET_INDEX] * no_target_count
target_indexes.extend(
random.choices(range(stim_per_inquiry), k=target_count))
@@ -990,20 +980,18 @@ def random_target_positions(inquiry_count: int, stim_per_inquiry: int,
def generate_targets(symbols: List[str], inquiry_count: int,
percentage_without_target: int) -> List[str]:
- """Generates list of target symbols. Generates an equal number of each
- target. The resulting list may be less than the inquiry_count. Used for
- sampling without replacement to get approximately equal numbers of each
- target.
+ """Generate a list of target symbols for calibration inquiries.
Args:
- symbols:
- inquiry_count: number of inquiries in calibration
- percentage_without_target: percentage of inquiries for which
- target letter flashed is not in inquiry
+ symbols (List[str]): List of possible symbols.
+ inquiry_count (int): Number of inquiries in calibration.
+ percentage_without_target (int): Percentage of inquiries without a target.
+
+ Returns:
+ List[str]: List of target symbols.
"""
target_count, no_target_count = compute_counts(inquiry_count,
percentage_without_target)
-
# each symbol should appear at least once
symbol_count = int(target_count / len(symbols)) or 1
targets = symbols * symbol_count
@@ -1012,19 +1000,13 @@ def generate_targets(symbols: List[str], inquiry_count: int,
def target_index(inquiry: List[str]) -> Optional[int]:
- """Given an inquiry, return the index of the target within the choices and
- None if the target is not included as a choice.
-
- Parameters
- ----------
- inquiry - list of [target, fixation, *choices]
-
- >>> inquiry = ['T', '+', 'G', 'J', 'K', 'L', 'M', 'Q', 'T', 'V', 'X', '<']
- >>> target_index(inquiry)
- 6
- >>> inquiry = ['A', '+', 'G', 'J', 'K', 'L', 'M', 'Q', 'T', 'V', 'X', '<']
- >>> target_index(inquiry)
- None
+ """Return the index of the target within the choices, or None if not present.
+
+ Args:
+ inquiry (List[str]): List of [target, fixation, *choices].
+
+ Returns:
+ Optional[int]: Index of the target in choices, or None if not present.
"""
assert len(inquiry) > 3, "Not enough choices"
target, _fixation, *choices = inquiry
@@ -1035,19 +1017,15 @@ def target_index(inquiry: List[str]) -> Optional[int]:
def get_task_info(experiment_length: int, task_color: str) -> Tuple[List[str], List[str]]:
- """Get Task Info.
+ """Generate fixed RSVPKeyboard task text and color information for display.
- Generates fixed RSVPKeyboard task text and color information for
- display.
Args:
- experiment_length(int): Number of inquiries for the experiment
- task_color(str): Task information display color
+ experiment_length (int): Number of inquiries for the experiment.
+ task_color (str): Task information display color.
- Return get_task_info((tuple): task_text: array of task text to display
- task_color: array of colors for the task text
- )
+ Returns:
+ Tuple[List[str], List[str]]: Tuple of task text and color arrays.
"""
-
# Do list comprehensions to get the arrays for the task we need.
task_text_list = ['%s/%s' % (stim + 1, experiment_length)
for stim in range(experiment_length)]
@@ -1057,11 +1035,16 @@ def get_task_info(experiment_length: int, task_color: str) -> Tuple[List[str], L
def resize_image(image_path: str, screen_size: tuple, sti_height: float) -> Tuple[float, float]:
- """Resize Image.
+    """Return the display width and height for an image, given the screen size and stimuli height.
+
+ Args:
+ image_path (str): Path to the image file.
+ screen_size (tuple): Screen size as (width, height).
+ sti_height (float): Desired stimuli height.
- Returns the width and height that a given image should be displayed at
- given the screen size, size of the original image, and stimuli height
- parameter"""
+ Returns:
+ Tuple[float, float]: Width and height for displaying the image.
+ """
# Retrieve image width and height
with Image.open(image_path) as pillow_image:
image_width, image_height = pillow_image.size
@@ -1094,43 +1077,35 @@ def play_sound(sound_file_path: str,
experiment_clock=None,
trigger_name: Optional[str] = None,
timing: list = []) -> list:
- """Play Sound.
-
- Using soundevice and soundfile, play a sound giving options to buffer times between
- loading sound into memory and after playing. If desired, marker writers or list based
- timing with psychopy clocks may be passed and sound timing returned.
-
-
- PARAMETERS
- ----------
- :param: sound_file_path
- :param: dtype: type of sound ex. float32.
- :param: track_timing: whether or not to track timing of sound playin
- :param: sound_callback: trigger based callback (see MarkerWriter and NullMarkerWriter)
- :param: sound_load_buffer_time: time to wait after loading file before playing
- :param: experiment_clock: psychopy clock to get time of sound stimuli
- :param: trigger_name: name of the sound trigger
- :param: timing: list of triggers in the form of trigger name, trigger timing
- :resp: timing
- """
+ """Play a sound file and optionally track timing and triggers.
+
+ Args:
+ sound_file_path (str): Path to the sound file.
+ dtype (str, optional): Type of sound (e.g., 'float32'). Defaults to 'float32'.
+ track_timing (bool, optional): Whether to track timing of sound playing. Defaults to False.
+ sound_callback (optional): Callback for sound triggers. Defaults to None.
+ sound_load_buffer_time (float, optional): Time to wait after loading file before playing. Defaults to 0.5.
+ experiment_clock (optional): Clock to get time of sound stimuli. Defaults to None.
+ trigger_name (Optional[str], optional): Name of the sound trigger. Defaults to None.
+ timing (list, optional): List of triggers in the form of trigger name, trigger timing. Defaults to [].
+ Returns:
+ list: Timing information for sound triggers.
+ """
try:
# load in the sound file and wait some time before playing
data, fs = sf.read(sound_file_path, dtype=dtype)
core.wait(sound_load_buffer_time)
-
except Exception as e:
error_message = f'Sound file could not be found or initialized. \n Exception={e}'
log.exception(error_message)
raise BciPyCoreException(error_message)
-
# if timing is wanted, get trigger timing for this sound stimuli
if track_timing:
# if there is a timing callback for sound, evoke it
if sound_callback is not None:
sound_callback(experiment_clock, trigger_name)
timing.append([trigger_name, experiment_clock.getTime()])
-
# play our loaded sound and wait for some time before it's finished
# NOTE: there is a measurable delay for calling sd.play. (~ 0.1 seconds;
# which I believe happens prior to the sound playing).
@@ -1142,15 +1117,13 @@ def play_sound(sound_file_path: str,
def soundfiles(directory: str) -> Iterator[str]:
- """Creates a generator that cycles through sound files (.wav) in a
- directory and returns the path to next sound file on each iteration.
+ """Return an iterator cycling through .wav files in a directory.
+
+ Args:
+ directory (str): Path to the directory containing .wav files.
- Parameters:
- -----------
- directory - path to the directory which contains .wav files
Returns:
- --------
- iterator that infinitely cycles through the filenames.
+ Iterator[str]: Iterator that infinitely cycles through the filenames.
"""
if not path.isdir(directory):
error_message = f'Invalid directory=[{directory}] for sound files.'
@@ -1162,9 +1135,13 @@ def soundfiles(directory: str) -> Iterator[str]:
def get_fixation(is_txt: bool) -> str:
- """Get Fixation.
+ """Return the correct stimulus fixation given the type (text or image).
+
+ Args:
+ is_txt (bool): Whether the fixation is text or image.
- Return the correct stimulus fixation given the type (text or image).
+ Returns:
+ str: Fixation stimulus (text or image path).
"""
if is_txt:
return DEFAULT_TEXT_FIXATION
diff --git a/bcipy/core/symbols.py b/bcipy/core/symbols.py
index dca9feb4c..1b642e4fe 100644
--- a/bcipy/core/symbols.py
+++ b/bcipy/core/symbols.py
@@ -1,25 +1,35 @@
"""Defines helper methods and variables related to input symbols"""
import os
from string import ascii_uppercase
-from typing import Callable
+from typing import Any, Callable, List, Optional
SPACE_CHAR = '_'
BACKSPACE_CHAR = '<'
-def alphabet(parameters=None, include_path=True, backspace=BACKSPACE_CHAR, space=SPACE_CHAR):
- """Alphabet.
+def alphabet(parameters: Optional[Any] = None, include_path: bool = True,
+ backspace: str = BACKSPACE_CHAR, space: str = SPACE_CHAR) -> List[str]:
+ """Standardizes and returns the alphabet symbols used in BciPy.
- Function used to standardize the symbols we use as alphabet.
+ The symbols can either be text (uppercase ASCII letters, backspace, and space)
+ or paths to image files, depending on the `parameters` and `is_txt_stim` setting.
- Returns
- -------
- array of letters.
+ Args:
+ parameters (Optional[Any], optional): A dictionary-like object containing configuration
+ parameters, specifically 'is_txt_stim' and 'path_to_presentation_images'.
+ Defaults to None.
+ include_path (bool, optional): If True and image stimuli are used, returns full paths to images.
+ If False, returns just the image filenames without extensions. Defaults to True.
+ backspace (str, optional): The character representing backspace. Defaults to `BACKSPACE_CHAR`.
+ space (str, optional): The character representing space. Defaults to `SPACE_CHAR`.
+
+ Returns:
+ List[str]: A list of alphabet symbols (either letters or image paths).
"""
if parameters and not parameters['is_txt_stim']:
# construct an array of paths to images
path = parameters['path_to_presentation_images']
- stimulus_array = []
+ stimulus_array: List[str] = []
for stimulus_filename in sorted(os.listdir(path)):
# PLUS.png is reserved for the fixation symbol
if stimulus_filename.endswith(
@@ -36,9 +46,21 @@ def alphabet(parameters=None, include_path=True, backspace=BACKSPACE_CHAR, space
def qwerty_order(is_txt_stim: bool = True,
space: str = SPACE_CHAR,
- backspace: str = BACKSPACE_CHAR) -> Callable:
- """Returns a function that can be used to sort the alphabet symbols
- in QWERTY order. Note that sorting only works for text stim.
+ backspace: str = BACKSPACE_CHAR) -> Callable[[str], int]:
+ """Returns a function that can be used to sort alphabet symbols in QWERTY order.
+
+ Note that sorting only works for text stimuli.
+
+ Args:
+ is_txt_stim (bool, optional): If True, indicates text stimuli. Defaults to True.
+ space (str, optional): The character representing space. Defaults to `SPACE_CHAR`.
+ backspace (str, optional): The character representing backspace. Defaults to `BACKSPACE_CHAR`.
+
+ Returns:
+ Callable[[str], int]: A function that takes a symbol string and returns its QWERTY index.
+
+ Raises:
+ NotImplementedError: If `is_txt_stim` is False, as QWERTY ordering is not implemented for images.
"""
if not is_txt_stim:
raise NotImplementedError('QWERTY ordering not implemented for images')
@@ -52,9 +74,21 @@ def qwerty_order(is_txt_stim: bool = True,
def frequency_order(
is_txt_stim: bool = True,
space: str = SPACE_CHAR,
- backspace: str = BACKSPACE_CHAR) -> Callable:
- """Returns a function that can be used to sort the alphabet symbols
- in most frequently used order in the English language.
+ backspace: str = BACKSPACE_CHAR) -> Callable[[str], int]:
+ """Returns a function that can be used to sort alphabet symbols by frequency of use in English."
+
+ Note that sorting only works for text stimuli.
+
+ Args:
+ is_txt_stim (bool, optional): If True, indicates text stimuli. Defaults to True.
+ space (str, optional): The character representing space. Defaults to `SPACE_CHAR`.
+ backspace (str, optional): The character representing backspace. Defaults to `BACKSPACE_CHAR`.
+
+ Returns:
+ Callable[[str], int]: A function that takes a symbol string and returns its frequency order index.
+
+ Raises:
+ NotImplementedError: If `is_txt_stim` is False, as frequency ordering is not implemented for images.
"""
if not is_txt_stim:
raise NotImplementedError(
diff --git a/bcipy/core/tests/test_list.py b/bcipy/core/tests/test_list.py
index 88dcbd3e7..8759b2795 100644
--- a/bcipy/core/tests/test_list.py
+++ b/bcipy/core/tests/test_list.py
@@ -95,7 +95,8 @@ def test_pairwise(self):
"""Test pairwise iterator"""
iterable = 'ABCDEFG'
response = pairwise(iterable)
- expected = [('A', 'B'), ('B', 'C'), ('C', 'D'), ('D', 'E'), ('E', 'F'), ('F', 'G')]
+ expected = [('A', 'B'), ('B', 'C'), ('C', 'D'),
+ ('D', 'E'), ('E', 'F'), ('F', 'G')]
self.assertListEqual(expected, list(response))
diff --git a/bcipy/core/tests/test_raw_data.py b/bcipy/core/tests/test_raw_data.py
index 6d8294be0..2de57891e 100644
--- a/bcipy/core/tests/test_raw_data.py
+++ b/bcipy/core/tests/test_raw_data.py
@@ -312,7 +312,8 @@ def test_data_by_channel_applies_transformation(self):
transform = mock()
# note data here should be returned as a nd.array. for mocking we don't care as much
- when(RawData).apply_transform(any(), transform).thenReturn((data, self.sample_rate))
+ when(RawData).apply_transform(
+ any(), transform).thenReturn((data, self.sample_rate))
resp, fs = data.by_channel(transform=transform)
self.assertEqual(self.sample_rate, fs)
@@ -369,8 +370,10 @@ def test_data_by_channel_map_applies_transformation(self):
transform = mock()
expected_output, expected_fs = data.by_channel()
# note data here should be returned as a nd.array. for mocking we don't care as much
- when(RawData).by_channel(transform).thenReturn((expected_output, expected_fs))
- _, channels, fs = data.by_channel_map(channel_map=channel_map, transform=transform)
+ when(RawData).by_channel(transform).thenReturn(
+ (expected_output, expected_fs))
+ _, channels, fs = data.by_channel_map(
+ channel_map=channel_map, transform=transform)
self.assertEqual(expected_fs, fs)
self.assertEqual(expected_channels, channels)
@@ -389,7 +392,8 @@ def test_get_1020_channels(self):
def test_get_1020_channel_map(self):
"""Tests that the 10-20 channel map is correctly generated."""
# all but the last channel are valid 10-20 channels
- channels = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'invalid']
+ channels = ['Fp1', 'Fp2', 'F3', 'F4', 'C3',
+ 'C4', 'P3', 'P4', 'O1', 'invalid']
channel_map = get_1020_channel_map(channels)
self.assertEqual(10, len(channel_map))
self.assertEqual(0, channel_map[-1])
diff --git a/bcipy/core/tests/test_report.py b/bcipy/core/tests/test_report.py
index 4837bbdcc..a181d8c7f 100644
--- a/bcipy/core/tests/test_report.py
+++ b/bcipy/core/tests/test_report.py
@@ -61,7 +61,8 @@ def test_add_section(self):
self.assertEqual(report.sections, [report_section])
another_report_section = SessionReportSection(summary=summary)
report.add(another_report_section)
- self.assertEqual(report.sections, [report_section, another_report_section])
+ self.assertEqual(report.sections, [
+ report_section, another_report_section])
def test_save(self):
report = Report(self.temp_dir)
@@ -69,12 +70,14 @@ def test_save(self):
report_section = SessionReportSection(summary)
report.add(report_section)
report.save()
- self.assertTrue(os.path.exists(os.path.join(self.temp_dir, report.name)))
+ self.assertTrue(os.path.exists(
+ os.path.join(self.temp_dir, report.name)))
def test_save_no_sections(self):
report = Report(self.temp_dir)
report.save()
- self.assertTrue(os.path.exists(os.path.join(self.temp_dir, report.name)))
+ self.assertTrue(os.path.exists(
+ os.path.join(self.temp_dir, report.name)))
def test_complile_adds_section_and_header(self):
report = Report(self.temp_dir)
diff --git a/bcipy/core/tests/test_stimuli.py b/bcipy/core/tests/test_stimuli.py
index 3c5861e28..5f53a289d 100644
--- a/bcipy/core/tests/test_stimuli.py
+++ b/bcipy/core/tests/test_stimuli.py
@@ -500,11 +500,13 @@ def test_random_target_positions_all_nontargets(self):
def test_generate_targets(self):
"""Test target generation"""
symbols = ['A', 'B', 'C', 'D']
- targets = generate_targets(symbols, inquiry_count=9, percentage_without_target=0)
+ targets = generate_targets(
+ symbols, inquiry_count=9, percentage_without_target=0)
self.assertEqual(len(targets), 8)
self.assertTrue(all(val == 2 for val in Counter(targets).values()))
- targets = generate_targets(symbols, inquiry_count=9, percentage_without_target=50)
+ targets = generate_targets(
+ symbols, inquiry_count=9, percentage_without_target=50)
self.assertEqual(len(targets), 4)
self.assertTrue(all(val == 1 for val in Counter(targets).values()))
@@ -536,7 +538,8 @@ def test_inquiry_target_with_none_position(self):
target = inquiry_target(inquiry, None, symbols, next_targets)
self.assertTrue(target not in inquiry)
self.assertTrue(target in symbols)
- self.assertSequenceEqual(inquiry, ['C', 'A', 'F', 'E'], 'inquiry should not have changed')
+ self.assertSequenceEqual(
+ inquiry, ['C', 'A', 'F', 'E'], 'inquiry should not have changed')
def test_inquiry_target_missing(self):
"""Test inquiry target where none of the next_targets are present in
@@ -550,8 +553,10 @@ def test_inquiry_target_missing(self):
next_targets=next_targets)
self.assertTrue(target in inquiry)
self.assertTrue(target not in next_targets)
- self.assertSequenceEqual(inquiry, ['C', 'A', 'F', 'E'], 'inquiry should not have changed')
- self.assertSequenceEqual(next_targets, ['Q', 'D'], 'next_targets should not have changed')
+ self.assertSequenceEqual(
+ inquiry, ['C', 'A', 'F', 'E'], 'inquiry should not have changed')
+ self.assertSequenceEqual(
+ next_targets, ['Q', 'D'], 'next_targets should not have changed')
def test_inquiry_target_no_targets(self):
"""Test inquiry_target when no next_targets are provided"""
@@ -562,8 +567,10 @@ def test_inquiry_target_no_targets(self):
target_position=0,
symbols=symbols,
next_targets=next_targets)
- self.assertEqual(target, 'C', 'should have used the target_position to get the target')
- self.assertSequenceEqual(inquiry, ['C', 'A', 'F', 'E'], 'inquiry should not have changed')
+ self.assertEqual(
+ target, 'C', 'should have used the target_position to get the target')
+ self.assertSequenceEqual(
+ inquiry, ['C', 'A', 'F', 'E'], 'inquiry should not have changed')
def test_inquiry_target_last_target(self):
"""Test inquiry target behavior."""
@@ -776,7 +783,8 @@ def test_best_case_inq_gen_is_random(self):
is_txt=True,
inq_constants=['<'])
samps.add(tuple(samples[0]))
- self.assertTrue(len(samps) > 1, '`best_case_rsvp_inq_gen` Should produce random results')
+ self.assertTrue(
+ len(samps) > 1, '`best_case_rsvp_inq_gen` Should produce random results')
class TestJitteredTiming(unittest.TestCase):
@@ -894,7 +902,8 @@ def test_trial_reshaper_with_no_channel_map(self):
poststimulus_length=trial_length_s
)
trial_length_samples = int(sample_rate * trial_length_s)
- expected_shape = (self.channel_number, len(self.target_info), trial_length_samples)
+ expected_shape = (self.channel_number, len(
+ self.target_info), trial_length_samples)
self.assertTrue(np.all(labels == [1, 0, 0]))
self.assertTrue(reshaped_trials.shape == expected_shape)
@@ -972,7 +981,8 @@ def test_inquiry_reshaper_with_no_channel_map(self):
channel_map=None,
poststimulus_length=self.trial_length
)
- expected_shape = (self.n_channel, self.n_inquiry, self.samples_per_inquiry)
+ expected_shape = (self.n_channel, self.n_inquiry,
+ self.samples_per_inquiry)
self.assertTrue(reshaped_data.shape == expected_shape)
self.assertTrue(np.all(labels == self.true_labels))
diff --git a/bcipy/core/tests/test_triggers.py b/bcipy/core/tests/test_triggers.py
index 00ddccf3d..9f33fd0b1 100644
--- a/bcipy/core/tests/test_triggers.py
+++ b/bcipy/core/tests/test_triggers.py
@@ -170,8 +170,10 @@ def setUp(self, mock_file):
self.flush = FlushFrequency.END
self.file = f'{self.path_name}/{self.file_name}.txt'
# with patch('builtins.open', mock_open(read_data='data')) as _:
- self.handler = TriggerHandler(self.path_name, self.file_name, self.flush)
- self.mock_file.assert_called_once_with(self.file, 'w+', encoding=self.handler.encoding)
+ self.handler = TriggerHandler(
+ self.path_name, self.file_name, self.flush)
+ self.mock_file.assert_called_once_with(
+ self.file, 'w+', encoding=self.handler.encoding)
def tearDown(self):
unstub()
@@ -179,7 +181,8 @@ def tearDown(self):
def test_file_exist_exception(self):
with open(self.file, 'w+', encoding=self.handler.encoding) as _:
with self.assertRaises(Exception):
- TriggerHandler(self.path_name, self.file_name, FlushFrequency.END)
+ TriggerHandler(self.path_name, self.file_name,
+ FlushFrequency.END)
os.remove(self.file)
def test_add_triggers_returns_list_of_triggers(self):
@@ -478,7 +481,8 @@ def test_starting_offsets_by_device_defaults(self):
"""Test default values"""
triggers = read_data('''J prompt 6.15
+ fixation 8.11'''.split('\n'))
- offsets = starting_offsets_by_device(triggers, device_types=['EEG', 'EYETRACKER'])
+ offsets = starting_offsets_by_device(
+ triggers, device_types=['EEG', 'EYETRACKER'])
self.assertEqual(len(offsets), 2)
self.assertEqual(offsets['EEG'].time, 0.0)
self.assertEqual(offsets['EYETRACKER'].time, 0.0)
diff --git a/bcipy/core/triggers.py b/bcipy/core/triggers.py
index 906982ee1..95363f820 100644
--- a/bcipy/core/triggers.py
+++ b/bcipy/core/triggers.py
@@ -1,3 +1,8 @@
+"""Trigger utilities for BciPy core.
+
+This module provides classes and functions for managing triggers and calibration events in BciPy experiments.
+"""
+
import logging
import os
from enum import Enum
@@ -42,31 +47,50 @@ class CalibrationType(Enum):
IMAGE = 'image'
@classmethod
- def list(cls):
- """Returns all enum values as a list"""
+ def list(cls) -> List[str]:
+ """Returns all enum values as a list.
+
+ Returns:
+ List[str]: List of all enum values.
+ """
return list(map(lambda c: c.value, cls))
class TriggerCallback:
+ """Callback handler for trigger events.
+
+ Attributes:
+ timing (Optional[Tuple[str, float]]): Timing information for the trigger.
+ first_time (bool): Flag indicating if this is the first trigger.
+ """
timing: Optional[Tuple[str, float]] = None
first_time: bool = True
def callback(self, clock: Clock, stimuli: str) -> None:
+ """Callback function for trigger events.
+
+ Args:
+ clock (Clock): Clock instance for timing.
+ stimuli (str): Stimulus identifier.
+ """
if self.first_time:
self.timing = (stimuli, clock.getTime())
self.first_time = False
- def reset(self):
+ def reset(self) -> None:
+ """Reset the callback state."""
self.timing = None
self.first_time = True
-def _calibration_trigger(experiment_clock: Clock,
- trigger_type: str = CalibrationType.TEXT.value,
- trigger_name: str = 'calibration',
- trigger_time: float = 1,
- display=None,
- on_trigger=None) -> Tuple[str, float]:
+def _calibration_trigger(
+ experiment_clock: Clock,
+ trigger_type: str = CalibrationType.TEXT.value,
+ trigger_name: str = 'calibration',
+ trigger_time: float = 1,
+ display: Optional[visual.Window] = None,
+ on_trigger: Optional[Callable[[str], None]] = None
+) -> Tuple[str, float]:
"""Calibration Trigger.
Outputs triggers for the purpose of calibrating data and stimuli.
@@ -74,20 +98,22 @@ def _calibration_trigger(experiment_clock: Clock,
code aims to operationalize the approach to finding the correct DAQ samples in
relation to our trigger code.
- PARAMETERS
- ---------
- experiment_clock(clock): clock with getTime() method, which is used in the code
- to report timing of stimuli
- trigger_type(string): type of trigger that is desired (text, image, etc)
- trigger_name(string): name of the trigger used for callbacks / labeling
- trigger_time(float): time to display the trigger. Can also be used as a buffer.
- display(DisplayWindow): a window that can display stimuli. Currently, a Psychopy window.
- on_trigger(function): optional callback; if present gets called
- when the calibration trigger is fired; accepts a single
- parameter for the timing information.
- Return:
- timing(array): timing values for the calibration triggers to be written to trigger file or
- used to calculate offsets.
+ Args:
+ experiment_clock (Clock): Clock with getTime() method, which is used in the code
+ to report timing of stimuli.
+ trigger_type (str): Type of trigger that is desired (text, image, etc).
+ trigger_name (str): Name of the trigger used for callbacks / labeling.
+ trigger_time (float): Time to display the trigger. Can also be used as a buffer.
+ display (Optional[visual.Window]): A window that can display stimuli. Currently, a Psychopy window.
+ on_trigger (Optional[Callable[[str], None]]): Optional callback; if present gets called
+ when the calibration trigger is fired; accepts a single parameter for the timing information.
+
+ Returns:
+ Tuple[str, float]: Timing values for the calibration triggers to be written to trigger file or
+ used to calculate offsets.
+
+ Raises:
+ BciPyCoreException: If trigger type is invalid or display is required but not provided.
"""
trigger_callback = TriggerCallback()
@@ -109,9 +135,11 @@ def _calibration_trigger(experiment_clock: Clock,
pos=(-.5, -.5),
mask=None,
ori=0.0)
- calibration_box.size = resize_image(CALIBRATION_IMAGE_PATH, display.size, 0.75)
+ calibration_box.size = resize_image(
+ CALIBRATION_IMAGE_PATH, display.size, 0.75)
- display.callOnFlip(trigger_callback.callback, experiment_clock, trigger_name)
+ display.callOnFlip(trigger_callback.callback,
+ experiment_clock, trigger_name)
if on_trigger is not None:
display.callOnFlip(on_trigger, trigger_name)
@@ -137,7 +165,14 @@ def _calibration_trigger(experiment_clock: Clock,
def trigger_durations(params: Parameters) -> Dict[str, float]:
- """Duration for each type of trigger given in seconds."""
+ """Get duration for each type of trigger given in seconds.
+
+ Args:
+ params (Parameters): Parameters containing timing information.
+
+ Returns:
+ Dict[str, float]: Dictionary mapping trigger types to their durations in seconds.
+ """
return {
'offset': 0.0,
'preview': params['preview_inquiry_length'],
@@ -149,9 +184,7 @@ def trigger_durations(params: Parameters) -> Dict[str, float]:
class TriggerType(Enum):
- """
- Enum for the primary types of Triggers.
- """
+ """Enum for the primary types of Triggers."""
NONTARGET = "nontarget"
TARGET = "target"
@@ -165,54 +198,86 @@ class TriggerType(Enum):
@classmethod
def list(cls) -> List[str]:
- """Returns all enum values as a list"""
+ """Returns all enum values as a list.
+
+ Returns:
+ List[str]: List of all enum values.
+ """
return list(map(lambda c: c.value, cls))
@classmethod
def pre_fixation(cls) -> List['TriggerType']:
"""Returns the subset of TriggerTypes that occur before and including
- the FIXATION trigger."""
+ the FIXATION trigger.
+
+ Returns:
+ List[TriggerType]: List of trigger types that occur before fixation.
+ """
return [
TriggerType.FIXATION, TriggerType.PROMPT, TriggerType.SYSTEM,
TriggerType.OFFSET
]
def __str__(self) -> str:
+ """String representation of the trigger type.
+
+ Returns:
+ str: String representation of the trigger type.
+ """
return f'{self.value}'
class Trigger(NamedTuple):
- """
- Object that encompasses data for a single trigger instance.
+ """Object that encompasses data for a single trigger instance.
+
+ Attributes:
+ label (str): Label for the trigger.
+ type (TriggerType): Type of the trigger.
+ time (float): Timestamp of the trigger.
"""
label: str
type: TriggerType
time: float
- def __repr__(self):
+ def __repr__(self) -> str:
+ """String representation of the trigger.
+
+ Returns:
+ str: String representation of the trigger.
+ """
return f'Trigger: label=[{self.label}] type=[{self.type}] time=[{self.time}]'
- def with_offset(self, offset: float):
- """Construct a copy of this Trigger with the offset adjusted."""
+ def with_offset(self, offset: float) -> 'Trigger':
+ """Construct a copy of this Trigger with the offset adjusted.
+
+ Args:
+ offset (float): Offset to apply to the trigger time.
+
+ Returns:
+ Trigger: New trigger instance with adjusted time.
+ """
return Trigger(self.label, self.type, self.time + offset)
@classmethod
- def from_list(cls, lst: List[str]):
+ def from_list(cls, lst: List[str]) -> 'Trigger':
"""Constructs a Trigger from a serialized representation.
- Parameters
- ----------
- lst - serialized representation [label, type, stamp]
+ Args:
+ lst (List[str]): Serialized representation [label, type, stamp].
+
+ Returns:
+ Trigger: New trigger instance.
+
+ Raises:
+ AssertionError: If input list does not have exactly 3 elements.
"""
assert len(lst) == 3, "Input must have a label, type, and stamp"
return cls(lst[0], TriggerType(lst[1]), float(lst[2]))
class FlushFrequency(Enum):
- """
- Enum that defines how often list of Triggers will be written and dumped.
- """
+ """Enum that defines how often list of Triggers will be written and dumped."""
EVERY = "flush after every trigger addition"
END = "flush at end of session"
@@ -221,14 +286,15 @@ class FlushFrequency(Enum):
def read_data(lines: Iterable[str]) -> List[Trigger]:
"""Read raw trigger data from the given source.
- Parameters
- ----------
- data - iterable object where each item is a str with data for a single
+ Args:
+ lines (Iterable[str]): Iterable object where each item is a str with data for a single
trigger.
- Returns
- -------
- list of all Triggers in the data.
+ Returns:
+ List[Trigger]: List of all Triggers in the data.
+
+ Raises:
+ BciPyCoreException: If there is an error reading a trigger from any line.
"""
triggers = []
for i, line in enumerate(lines):
@@ -244,6 +310,13 @@ def read_data(lines: Iterable[str]) -> List[Trigger]:
def offset_label(device_type: Optional[str] = None, prefix: str = 'starting_offset') -> str:
"""Compute the offset label for the given device.
+
+ Args:
+ device_type (Optional[str]): Type of device. If None or 'EEG', returns default prefix.
+ prefix (str): Prefix for the offset label. Defaults to 'starting_offset'.
+
+ Returns:
+ str: Offset label for the device.
"""
if not device_type or device_type == 'EEG':
return prefix
@@ -251,7 +324,18 @@ def offset_label(device_type: Optional[str] = None, prefix: str = 'starting_offs
def offset_device(label: str, prefix: str = 'starting_offset') -> str:
- """Given an label, determine the device type"""
+ """Given a label, determine the device type.
+
+ Args:
+ label (str): Label to parse.
+ prefix (str): Expected prefix of the label. Defaults to 'starting_offset'.
+
+ Returns:
+ str: Device type extracted from the label.
+
+ Raises:
+ AssertionError: If label does not start with the given prefix.
+ """
assert label.startswith(
prefix), "Label must start with the given prefix"
try:
@@ -262,12 +346,19 @@ def offset_device(label: str, prefix: str = 'starting_offset') -> str:
def starting_offsets_by_device(
- triggers: List[Trigger],
- device_types: Optional[List[str]] = None) -> Dict[str, Trigger]:
+ triggers: List[Trigger],
+ device_types: Optional[List[str]] = None
+) -> Dict[str, Trigger]:
"""Returns a dict of starting_offset triggers keyed by device type.
- If device_types are provided, an entry is created for each one, using a
- default offset of 0.0 if a match is not found.
+ Args:
+ triggers (List[Trigger]): List of triggers to search through.
+ device_types (Optional[List[str]]): List of device types to include in the result.
+ If provided, an entry is created for each one, using a default offset of 0.0
+ if a match is not found.
+
+ Returns:
+ Dict[str, Trigger]: Dictionary mapping device types to their offset triggers.
"""
offset_triggers = {}
for trg in triggers:
@@ -285,26 +376,23 @@ def starting_offsets_by_device(
return offset_triggers
-def find_starting_offset(triggers: List[Trigger],
- device_type: Optional[str] = None) -> Trigger:
+def find_starting_offset(
+ triggers: List[Trigger],
+ device_type: Optional[str] = None
+) -> Trigger:
"""Given a list of raw trigger data, determine the starting offset for the
given device. The returned trigger has the timestamp of the first sample
recorded for the device.
- If no device is provided the EEG offset will be used. If there are
- no offset triggers in the given data a Trigger with offset of 0.0 will be
- returned.
-
- Parameters
- ----------
- triggers - list of trigger data; should include Triggers of
- TriggerType.OFFSET
- device_type - each device will generally have a different offset. This
+ Args:
+ triggers (List[Trigger]): List of trigger data; should include Triggers of
+ TriggerType.OFFSET.
+ device_type (Optional[str]): Each device will generally have a different offset. This
parameter is used to determine which trigger to use. If not given
- the EEG offset will be used by default. Ex. 'EYETRACKER'
- Returns
- -------
- The Trigger for the first matching offset for the given device, or a
+ the EEG offset will be used by default. Ex. 'EYETRACKER'.
+
+ Returns:
+ Trigger: The Trigger for the first matching offset for the given device, or a
Trigger with offset of 0.0 if a matching offset was not found.
"""
label = offset_label(device_type)
@@ -318,12 +406,14 @@ def find_starting_offset(triggers: List[Trigger],
def read(path: str) -> List[Trigger]:
"""Read all Triggers from the given text file.
- Parameters
- ----------
- path - trigger (.txt) file to read
- Returns
- -------
- triggers
+ Args:
+ path (str): Trigger (.txt) file to read.
+
+ Returns:
+ List[Trigger]: List of triggers read from the file.
+
+ Raises:
+ FileNotFoundError: If the file does not exist or is not a .txt file.
"""
if not path.endswith('.txt') or not os.path.exists(path):
raise FileNotFoundError(
@@ -333,21 +423,21 @@ def read(path: str) -> List[Trigger]:
return triggers
-def apply_offsets(triggers: List[Trigger],
- starting_offset: Trigger,
- static_offset: float = 0.0) -> List[Trigger]:
+def apply_offsets(
+ triggers: List[Trigger],
+ starting_offset: Trigger,
+ static_offset: float = 0.0
+) -> List[Trigger]:
"""Returns a list of triggers with timestamps adjusted relative to the
device start time. Offset triggers are filtered out if present.
- Parameters
- ----------
- triggers - list of triggers
- starting_offset - offset from the device start time.
- static_offset - the measured static system offset
+ Args:
+ triggers (List[Trigger]): List of triggers to adjust.
+ starting_offset (Trigger): Offset from the device start time.
+ static_offset (float): The measured static system offset.
- Returns
- -------
- a list of triggers with timestamps relative to the starting_offset
+ Returns:
+ List[Trigger]: List of triggers with timestamps relative to the starting_offset.
"""
total_offset = starting_offset.time + static_offset
return [
@@ -356,28 +446,59 @@ def apply_offsets(triggers: List[Trigger],
]
-def exclude_types(triggers: List[Trigger],
- types: Optional[List[TriggerType]] = None) -> List[Trigger]:
- """Filter the list of triggers to exclude the provided types"""
+def exclude_types(
+ triggers: List[Trigger],
+ types: Optional[List[TriggerType]] = None
+) -> List[Trigger]:
+ """Filter the list of triggers to exclude the provided types.
+
+ Args:
+ triggers (List[Trigger]): List of triggers to filter.
+ types (Optional[List[TriggerType]]): List of trigger types to exclude.
+
+ Returns:
+ List[Trigger]: Filtered list of triggers.
+ """
if not types:
return triggers
return [trg for trg in triggers if trg.type not in types]
class TriggerHandler:
- """
- Class that contains methods to work with Triggers, including adding and
+ """Class that contains methods to work with Triggers, including adding and
writing triggers and loading triggers from a txt file.
+
+ Attributes:
+ encoding (str): File encoding to use.
+ path (str): Path to the trigger file.
+ file_name (str): Name of the trigger file.
+ flush (FlushFrequency): Frequency at which to flush triggers to file.
+ triggers (List[Trigger]): List of triggers being handled.
+ file_path (str): Full path to the trigger file.
+ file (TextIO): File handle for the trigger file.
"""
encoding = DEFAULT_ENCODING
- def __init__(self,
- path: str,
- file_name: str,
- flush: FlushFrequency):
+ def __init__(
+ self,
+ path: str,
+ file_name: str,
+ flush: FlushFrequency
+ ) -> None:
+ """Initialize the TriggerHandler.
+
+ Args:
+ path (str): Path to the trigger file.
+ file_name (str): Name of the trigger file.
+ flush (FlushFrequency): Frequency at which to flush triggers to file.
+
+ Raises:
+ Exception: If the file already exists.
+ """
self.path = path
- self.file_name = f'{file_name}.txt' if not file_name.endswith('.txt') else file_name
+ self.file_name = f'{file_name}.txt' if not file_name.endswith(
+ '.txt') else file_name
self.flush = flush
self.triggers: List[Trigger] = []
self.file_path = f'{self.path}/{self.file_name}'
@@ -390,33 +511,29 @@ def __init__(self,
self.file = open(self.file_path, 'w+', encoding=self.encoding)
def close(self) -> None:
- """Close.
-
- Ensures all data is written and file is closed properly.
- """
+ """Close the trigger file and ensure all data is written."""
self.write()
self.file.close()
def write(self) -> None:
- """
- Writes current Triggers in self.triggers[] to .txt file in self.file_name.
+ """Writes current Triggers in self.triggers[] to .txt file in self.file_name.
File writes in the format "label, targetness, time".
"""
-
for trigger in self.triggers:
- self.file.write(f'{trigger.label} {trigger.type.value} {trigger.time}\n')
+ self.file.write(
+ f'{trigger.label} {trigger.type.value} {trigger.time}\n')
self.triggers = []
@staticmethod
def read_text_file(path: str) -> Tuple[List[Trigger], float]:
"""Read Triggers from the given text file.
- Parameters
- ----------
- path - trigger (.txt) file to read
- Returns
- -------
- triggers, offset
+
+ Args:
+ path (str): Trigger (.txt) file to read.
+
+ Returns:
+ Tuple[List[Trigger], float]: List of triggers and offset time.
"""
triggers = read(path)
offset = find_starting_offset(triggers)
@@ -424,36 +541,26 @@ def read_text_file(path: str) -> Tuple[List[Trigger], float]:
return triggers, offset.time
@staticmethod
- def load(path: str,
- offset: float = 0.0,
- exclusion: Optional[List[TriggerType]] = None,
- device_type: Optional[str] = None) -> List[Trigger]:
- """
- Loads a list of triggers from a .txt of triggers.
-
- Exclusion based on type only (ex. exclusion=[TriggerType.Fixation])
-
- 1. Checks if .txt file exists at path
- 2. Loads the triggers data as a list of lists
- 3. If offset provided, adds it to the time as float
- 4. If exclusion provided, filters those triggers
- 5. Casts all loaded and modified triggers to Trigger
- 6. Returns as a List[Triggers]
-
- Parameters
- ----------
- path (str): name or file path of .txt trigger file to be loaded.
- Input string must include file extension (.txt).
- offset (Optional float): if desired, time offset for all loaded triggers,
- positive number for adding time, negative number for subtracting time.
- exclusion (Optional List[TriggerType]): if desired, list of TriggerType's
- to be removed from the loaded trigger list.
- device_type : optional; if specified looks for the starting_offset for
- a given device; default is to use the EEG offset.
-
- Returns
- -------
- List of Triggers from loaded .txt file with desired modifications
+ def load(
+ path: str,
+ offset: float = 0.0,
+ exclusion: Optional[List[TriggerType]] = None,
+ device_type: Optional[str] = None
+ ) -> List[Trigger]:
+ """Loads a list of triggers from a .txt of triggers.
+
+ Args:
+ path (str): Name or file path of .txt trigger file to be loaded.
+ Input string must include file extension (.txt).
+ offset (float): If desired, time offset for all loaded triggers,
+ positive number for adding time, negative number for subtracting time.
+ exclusion (Optional[List[TriggerType]]): If desired, list of TriggerType's
+ to be removed from the loaded trigger list.
+ device_type (Optional[str]): If specified looks for the starting_offset for
+ a given device; default is to use the EEG offset.
+
+ Returns:
+ List[Trigger]: List of Triggers from loaded .txt file with desired modifications.
"""
excluded_types = exclusion or []
triggers = read(path)
@@ -463,17 +570,14 @@ def load(path: str,
static_offset=offset)
def add_triggers(self, triggers: List[Trigger]) -> List[Trigger]:
- """
- Adds provided list of Triggers to self.triggers.
+ """Adds provided list of Triggers to self.triggers.
- Parameters
- ----------
- triggers (List[Triggers]): list of Trigger objects to be added to the
- handler's list of Triggers (self.triggers).
+ Args:
+ triggers (List[Trigger]): List of Trigger objects to be added to the
+ handler's list of Triggers (self.triggers).
- Returns
- -------
- Returns list of Triggers currently part of Handler
+ Returns:
+ List[Trigger]: Returns list of Triggers currently part of Handler.
"""
self.triggers.extend(triggers)
@@ -483,11 +587,22 @@ def add_triggers(self, triggers: List[Trigger]) -> List[Trigger]:
return self.triggers
-def convert_timing_triggers(timing: List[Tuple[str, float]], target_stimuli: str,
- trigger_type: Callable) -> List[Trigger]:
+def convert_timing_triggers(
+ timing: List[Tuple[str, float]],
+ target_stimuli: str,
+ trigger_type: Callable
+) -> List[Trigger]:
"""Convert Stimuli Times to Triggers.
Using the stimuli presentation times provided by the display, convert them into BciPy Triggers.
+
+ Args:
+ timing (List[Tuple[str, float]]): List of (symbol, time) tuples.
+ target_stimuli (str): Target stimulus identifier.
+ trigger_type (Callable): Function to determine trigger type.
+
+ Returns:
+ List[Trigger]: List of converted triggers.
"""
return [
Trigger(symbol, trigger_type(symbol, target_stimuli, i), time)
@@ -495,29 +610,30 @@ def convert_timing_triggers(timing: List[Tuple[str, float]], target_stimuli: str
]
-def load_triggers(trigger_path: str,
- remove_pre_fixation: bool = True,
- offset: float = 0.0,
- exclusion: Optional[List[TriggerType]] = None,
- device_type: Optional[str] = None,
- apply_starting_offset: bool = True) -> List[Trigger]:
+def load_triggers(
+ trigger_path: str,
+ remove_pre_fixation: bool = True,
+ offset: float = 0.0,
+ exclusion: Optional[List[TriggerType]] = None,
+ device_type: Optional[str] = None,
+ apply_starting_offset: bool = True
+) -> List[Trigger]:
"""Trigger Decoder.
Given a path to trigger data, this method loads valid Triggers.
- Parameters
- ----------
- trigger_path: path to triggers file
- remove_pre_fixation: boolean to determine whether any stimuli before a fixation + system should be removed
- offset: static offset value to apply to triggers.
- exclusion: any TriggerTypes to be filtered from data returned
- device_type: used to determine which starting_offset value to use; if
+ Args:
+ trigger_path (str): Path to triggers file.
+ remove_pre_fixation (bool): Boolean to determine whether any stimuli before a fixation + system should be removed.
+ offset (float): Static offset value to apply to triggers.
+ exclusion (Optional[List[TriggerType]]): Any TriggerTypes to be filtered from data returned.
+ device_type (Optional[str]): Used to determine which starting_offset value to use; if
a 'starting_offset' trigger is found it will be applied.
- apply_starting_offset: if False, does not apply the starting offset for
+ apply_starting_offset (bool): If False, does not apply the starting offset for
the given device_type.
- Returns
- -------
- list of Triggers
+
+ Returns:
+ List[Trigger]: List of processed triggers.
"""
excluded_types = exclusion or []
excluded_types += TriggerType.pre_fixation() if remove_pre_fixation else [
@@ -535,29 +651,29 @@ def load_triggers(trigger_path: str,
def trigger_decoder(
- trigger_path: str,
- remove_pre_fixation: bool = True,
- offset: float = 0.0,
- exclusion: Optional[List[TriggerType]] = None,
- device_type: Optional[str] = None,
- apply_starting_offset: bool = True) -> Tuple[list, list, list]:
+ trigger_path: str,
+ remove_pre_fixation: bool = True,
+ offset: float = 0.0,
+ exclusion: Optional[List[TriggerType]] = None,
+ device_type: Optional[str] = None,
+ apply_starting_offset: bool = True
+) -> Tuple[List[str], List[float], List[str]]:
"""Trigger Decoder.
Given a path to trigger data, this method loads valid Triggers and returns their type, timing and label.
- Parameters
- ----------
- trigger_path: path to triggers file
- remove_pre_fixation: boolean to determine whether any stimuli before a fixation + system should be removed
- offset: static offset value to apply to triggers.
- exclusion: any TriggerTypes to be filtered from data returned
- device_type: used to determine which starting_offset value to use; if
+ Args:
+ trigger_path (str): Path to triggers file.
+ remove_pre_fixation (bool): Boolean to determine whether any stimuli before a fixation + system should be removed.
+ offset (float): Static offset value to apply to triggers.
+ exclusion (Optional[List[TriggerType]]): Any TriggerTypes to be filtered from data returned.
+ device_type (Optional[str]): Used to determine which starting_offset value to use; if
a 'starting_offset' trigger is found it will be applied.
- apply_starting_offset: if False, does not apply the starting offset for
+ apply_starting_offset (bool): If False, does not apply the starting offset for
the given device_type.
- Returns
- -------
- tuple: trigger_type, trigger_timing, trigger_label
+
+ Returns:
+ Tuple[List[str], List[float], List[str]]: Tuple containing trigger types, timings, and labels.
"""
triggers = load_triggers(trigger_path, remove_pre_fixation, offset,
exclusion, device_type, apply_starting_offset)
diff --git a/bcipy/demo/bci_main_demo.py b/bcipy/demo/bci_main_demo.py
index cc51c5385..f776b7b30 100644
--- a/bcipy/demo/bci_main_demo.py
+++ b/bcipy/demo/bci_main_demo.py
@@ -1,9 +1,11 @@
from bcipy.config import DEFAULT_PARAMETERS_PATH
from bcipy.main import bci_main
-parameter_location = DEFAULT_PARAMETERS_PATH # Path to a valid BciPy parameters file
+# Path to a valid BciPy parameters file
+parameter_location = DEFAULT_PARAMETERS_PATH
user = 'test_demo_user' # User ID
-experiment_id = 'default' # This will run two tasks: RSVP Calibration and Matrix Calibration
+# This will run two tasks: RSVP Calibration and Matrix Calibration
+experiment_id = 'default'
alert = False # Set to True to alert user when tasks are complete
visualize = False # Set to True to visualize data at the end of a task
fake_data = True # Set to True to use fake acquisition data during the session
diff --git a/bcipy/display/README.md b/bcipy/display/README.md
index f09c5fea0..5df034936 100644
--- a/bcipy/display/README.md
+++ b/bcipy/display/README.md
@@ -1,22 +1,160 @@
# Display Module
-The Display module defines the visual presentation logic needed for any tasks in BciPy. Most displays take in a large number of user-configured parameters (including text size and font) and handle
- passing back timestamps from their stimuli presentations for classification.
-
-### Structure
-`display`
- `main`: Initializes a display window and contains useful display objects
- `paradigm`: top level module holding all bcipy related display objects
- `rsvp`: RSVP related display objects and functions.
- `mode`: defines task specific displays
- `matrix`: matrix related display objects and functions. Currently, only single character presentation is available.
- `mode`: defines task specific displays
- `tests`: tests for display module
- `demo`: demo code for the display module
-
-### Guidelines
-
-- Add new modes in their own submodule
-- Inherit base classes defined in display.py where possible
-- Test timing between your code and the devices you're using
- - consult psychopy (or other display codebase) for best practices for your OS
+The Display module is a core component of BciPy that handles all visual presentation logic for BCI tasks. It provides a flexible and configurable framework for creating and managing visual stimuli, with precise timing control and synchronization capabilities. Please note that the VEPDisplay is still in progress. Please use with discretion!
+
+## Structure
+
+The module is organized into several key components:
+
+- `main/`: Core display initialization and window management
+- `paradigm/`: BCI paradigm-specific display implementations
+ - `rsvp/`: RSVP Keyboard display components
+ - `matrix/`: Matrix Speller display components
+  - `vep/`: *WIP* Visual Evoked Potential display components.
+- `components/`: Reusable display components
+- `tests/`: Unit and integration tests
+- `demo/`: Example implementations and usage
+
+## Core Concepts
+
+### Display System
+
+The display system is built on several key abstractions:
+
+#### 1. Display Base Class
+
+The base `Display` class provides core functionality for all displays. It is defined in `main.py`.
+
+- Window management and initialization
+- Stimulus presentation timing
+- Task bar and information display
+- Trigger handling and calibration
+
+#### 2. Stimuli Properties
+
+Configure visual properties of stimuli:
+
+```python
+from bcipy.display import StimuliProperties
+
+properties = StimuliProperties(
+ stim_font='Arial',
+ stim_height=32,
+ stim_pos=(0, 0)
+)
+```
+
+#### 3. Information Properties
+
+Manage task information display:
+
+```python
+from bcipy.display import InformationProperties
+
+info = InformationProperties(
+ info_color='white',
+ info_text="Task Progress",
+ info_font='Consolas',
+ info_height=24,
+ info_pos=(0, 0)
+)
+```
+
+### Key Features
+
+#### Window Management
+
+- PsychoPy window initialization
+- Screen resolution handling
+- Fullscreen/windowed modes
+- Refresh rate synchronization
+
+#### Stimulus Presentation
+
+- Precise timing control
+- Animation and transitions
+- Trigger synchronization
+- Event logging
+
+#### Task Display
+
+- Progress tracking
+- User feedback
+- Custom UI elements
+- Dynamic updates
+
+#### Layout System
+
+- Flexible positioning
+- Responsive design
+- Component alignment
+- Screen space optimization
+
+## Supported Paradigms
+
+### RSVP Keyboard
+
+The RSVP Keyboard is an EEG-based typing system that presents symbols sequentially at a single location. Users select symbols by attending to their target and eliciting a P300 response.
+
+Key features:
+
+- Single-location presentation
+- Temporal separation of stimuli
+- P300-based selection
+- Configurable timing
+
+### Matrix Speller
+
+The Matrix Speller presents symbols in a grid layout, highlighting subsets of symbols to elicit P300 responses for selection.
+
+Key features:
+
+- Grid-based layout
+- P300-based selection
+- Configurable matrix size
+
+## Development Guidelines
+
+1. **Adding New Paradigms**
+ - Create a new submodule in `paradigm/`
+ - Inherit from base `Display` class
+ - Implement required interface methods
+ - Add comprehensive tests
+
+2. **Timing Considerations**
+ - Test timing with actual hardware
+ - Account for refresh rate variations
+ - Log timing events for analysis
+ - Use PsychoPy's timing functions
+
+3. **Best Practices**
+ - Follow PsychoPy guidelines
+ - Document timing parameters
+ - Include example configurations
+ - Add unit tests for timing
+
+## References
+
+1. RSVP Keyboard:
+
+```text
+Orhan, U., et al. (2012). RSVP Keyboard: An EEG Based Typing Interface.
+IEEE International Conference on Acoustics, Speech, and Signal Processing.
+```
+
+2. Matrix Speller:
+
+```text
+Farwell, L. A., & Donchin, E. (1988). Talking off the top of your head:
+toward a mental prosthesis utilizing event-related brain potentials.
+Electroencephalography and clinical Neurophysiology, 70(6), 510-523.
+```
+
+## Support
+
+For issues, questions, or contributions:
+
+- Open an issue on GitHub
+- Check existing documentation
+- Review test examples
+- Consult the demo implementations
diff --git a/bcipy/display/components/button_press_handler.py b/bcipy/display/components/button_press_handler.py
index 3a087f74d..42563a362 100644
--- a/bcipy/display/components/button_press_handler.py
+++ b/bcipy/display/components/button_press_handler.py
@@ -1,6 +1,11 @@
-"""Handles button press interactions"""
+"""Handles button press interactions.
+
+This module provides classes for handling button press events in the BciPy system.
+It includes abstract base classes and concrete implementations for different types
+of button press handling strategies.
+"""
from abc import ABC, abstractmethod
-from typing import List, Optional, Type
+from typing import Any, List, Optional, Type
from psychopy import event
from psychopy.core import CountdownTimer
@@ -10,100 +15,178 @@
class ButtonPressHandler(ABC):
- """Handles button press events."""
-
- def __init__(self,
- max_wait: float,
- key_input: str,
- clock: Optional[Clock] = None,
- timer: Type[CountdownTimer] = CountdownTimer):
- """
- Parameters
- ----------
- wait_length - maximum number of seconds to wait for a key press
- key_input - key that we are listening for.
- clock - clock used to associate the key event with a timestamp
+ """Handles button press events.
+
+ This is an abstract base class that defines the interface for handling button press
+ events. It provides functionality for waiting for and processing button presses
+ within a specified time window.
+
+ Attributes:
+ max_wait (float): Maximum number of seconds to wait for a key press.
+ key_input (str): Key that we are listening for.
+ clock (Clock): Clock used to associate the key event with a timestamp.
+ response (Optional[List]): List containing the latest response information.
+ _timer (Optional[CountdownTimer]): Timer for tracking wait period.
+ make_timer (Type[CountdownTimer]): Factory for creating timer instances.
+ """
+
+ def __init__(
+ self,
+ max_wait: float,
+ key_input: str,
+ clock: Optional[Clock] = None,
+ timer: Type[CountdownTimer] = CountdownTimer
+ ) -> None:
+ """Initialize the ButtonPressHandler.
+
+ Args:
+ max_wait (float): Maximum number of seconds to wait for a key press.
+ key_input (str): Key that we are listening for.
+ clock (Optional[Clock]): Clock used to associate the key event with a timestamp.
+ If None, a new Clock instance will be created.
+ timer (Type[CountdownTimer]): Factory for creating timer instances.
+ Defaults to CountdownTimer.
"""
self.max_wait = max_wait
self.key_input = key_input
self.clock = clock or Clock()
- self.response: Optional[List] = None
+ self.response: Optional[List[Any]] = None
self._timer: Optional[CountdownTimer] = None
self.make_timer = timer
@property
def response_label(self) -> Optional[str]:
- """Label for the latest button press"""
+ """Get the label for the latest button press.
+
+ Returns:
+ Optional[str]: The label of the latest button press, or None if no response.
+ """
return self.response[0] if self.response else None
@property
def response_timestamp(self) -> Optional[float]:
- """Timestamp for the latest button response"""
+ """Get the timestamp for the latest button response.
+
+ Returns:
+ Optional[float]: The timestamp of the latest button press, or None if no response.
+ """
return self.response[1] if self.response else None
def _reset(self) -> None:
- """Reset any existing events and timers."""
+ """Reset any existing events and timers.
+
+ This method clears any existing keyboard events, resets the timer,
+ and clears the current response.
+ """
self._timer = self.make_timer(self.max_wait)
self.response = None
event.clearEvents(eventType='keyboard')
self._timer.reset()
def await_response(self) -> None:
- """Wait for a button response for a maximum number of seconds. Wait
- period could end early if the class determines that some other
- criteria have been met (such as an acceptable response)."""
+ """Wait for a button response for a maximum number of seconds.
+ Wait period could end early if the class determines that some other
+ criteria have been met (such as an acceptable response).
+ """
self._reset()
while self._should_keep_waiting() and self._within_wait_period():
self._check_key_press()
def has_response(self) -> bool:
- """Whether a response has been provided"""
+ """Check whether a response has been provided.
+
+ Returns:
+ bool: True if a response has been provided, False otherwise.
+ """
return self.response is not None
def _check_key_press(self) -> None:
- """Check for any key press events and set the latest as the response."""
+ """Check for any key press events and set the latest as the response.
+
+ This method updates the response attribute with the latest key press
+ information if a valid key press is detected.
+ """
self.response = get_key_press(
key_list=[self.key_input],
clock=self.clock,
)
def _within_wait_period(self) -> bool:
- """Check that we are within the allotted time for a response."""
+ """Check that we are within the allotted time for a response.
+
+ Returns:
+ bool: True if we are still within the wait period, False otherwise.
+ """
return (self._timer is not None) and (self._timer.getTime() > 0)
def _should_keep_waiting(self) -> bool:
- """Check that we should keep waiting for responses."""
+ """Check that we should keep waiting for responses.
+
+ Returns:
+ bool: True if we should continue waiting, False otherwise.
+ """
return not self.has_response()
@abstractmethod
def accept_result(self) -> bool:
- """Should the result of a button press be affirmative"""
+ """Determine if the result of a button press should be affirmative.
+
+ Returns:
+ bool: True if the result should be considered affirmative, False otherwise.
+ """
+ pass
class AcceptButtonPressHandler(ButtonPressHandler):
- """ButtonPressHandler where a matching button press indicates an affirmative result."""
+ """ButtonPressHandler where a matching button press indicates an affirmative result.
+
+ This handler considers a button press as an affirmative response.
+ """
def accept_result(self) -> bool:
- """Should the result of a button press be affirmative"""
+ """Determine if the result of a button press should be affirmative.
+
+ Returns:
+ bool: True if a response has been provided, False otherwise.
+ """
return self.has_response()
class RejectButtonPressHandler(ButtonPressHandler):
- """ButtonPressHandler where a matching button press indicates a rejection."""
+ """ButtonPressHandler where a matching button press indicates a rejection.
+
+ This handler considers a button press as a rejection response.
+ """
def accept_result(self) -> bool:
- """Should the result of a button press be affirmative"""
+ """Determine if the result of a button press should be affirmative.
+
+ Returns:
+ bool: True if no response has been provided, False otherwise.
+ """
return not self.has_response()
class PreviewOnlyButtonPressHandler(ButtonPressHandler):
- """ButtonPressHandler that waits for the entire span of the configured max_wait."""
+ """ButtonPressHandler that waits for the entire span of the configured max_wait.
+
+ This handler always waits for the full duration regardless of button presses.
+ """
def _should_keep_waiting(self) -> bool:
+ """Check that we should keep waiting for responses.
+
+ Returns:
+ bool: Always returns True to ensure full wait duration.
+ """
return True
def accept_result(self) -> bool:
- """Should the result of a button press be affirmative"""
+ """Determine if the result of a button press should be affirmative.
+
+ Returns:
+ bool: Always returns True.
+ """
return True
diff --git a/bcipy/display/components/layout.py b/bcipy/display/components/layout.py
index e0887a79f..0c6fa86cb 100644
--- a/bcipy/display/components/layout.py
+++ b/bcipy/display/components/layout.py
@@ -1,11 +1,21 @@
# mypy: disable-error-code="override"
-"""Defines common functionality for GUI layouts."""
+"""Defines common functionality for GUI layouts.
+
+This module provides classes and functions for managing layout and positioning
+of GUI elements in the BciPy system. It includes utilities for alignment,
+scaling, and positioning of components within containers.
+"""
from enum import Enum
from typing import List, Optional, Protocol, Tuple
class Container(Protocol):
- """Protocol for an enclosing container with units and size."""
+ """Protocol for an enclosing container with units and size.
+
+ Attributes:
+ size (Tuple[float, float]): Size of the container as (width, height).
+ units (str): Units used for measurements (e.g., 'norm', 'height').
+ """
size: Tuple[float, float]
units: str
@@ -18,7 +28,11 @@ class Container(Protocol):
class Alignment(Enum):
- """Specifies how elements should be aligned spatially"""
+ """Specifies how elements should be aligned spatially.
+
+ This enum defines the possible alignment options for positioning elements
+ within a container.
+ """
CENTERED = 1
LEFT = 2
RIGHT = 3
@@ -26,49 +40,104 @@ class Alignment(Enum):
BOTTOM = 5
@classmethod
- def horizontal(cls):
- """Subset used for horizontal alignment"""
+ def horizontal(cls) -> List['Alignment']:
+ """Get subset used for horizontal alignment.
+
+ Returns:
+ List[Alignment]: List of horizontal alignment options.
+ """
return [Alignment.CENTERED, Alignment.LEFT, Alignment.RIGHT]
@classmethod
- def vertical(cls):
- """Subset used for vertical alignment"""
+ def vertical(cls) -> List['Alignment']:
+ """Get subset used for vertical alignment.
+
+ Returns:
+ List[Alignment]: List of vertical alignment options.
+ """
return [Alignment.CENTERED, Alignment.TOP, Alignment.BOTTOM]
# Positioning functions
def above(y_coordinate: float, amount: float) -> float:
- """Returns a new y_coordinate value that is above the provided value
- by the given amount."""
+ """Returns a new y_coordinate value that is above the provided value.
+
+ Args:
+ y_coordinate (float): Base y-coordinate.
+ amount (float): Distance to move upward.
+
+ Returns:
+ float: New y-coordinate value.
+
+ Raises:
+ AssertionError: If amount is negative.
+ """
assert amount >= 0, 'Amount must be positive'
return y_coordinate + amount
def below(y_coordinate: float, amount: float) -> float:
- """Returns a new y_coordinate value that is below the provided value
- by the given amount."""
+ """Returns a new y_coordinate value that is below the provided value.
+
+ Args:
+ y_coordinate (float): Base y-coordinate.
+ amount (float): Distance to move downward.
+
+ Returns:
+ float: New y-coordinate value.
+
+ Raises:
+ AssertionError: If amount is negative.
+ """
assert amount >= 0, 'Amount must be positive'
return y_coordinate - amount
def right_of(x_coordinate: float, amount: float) -> float:
- """Returns a new x_coordinate value that is to the right of the
- provided value by the given amount."""
+ """Returns a new x_coordinate value that is to the right of the provided value.
+
+ Args:
+ x_coordinate (float): Base x-coordinate.
+ amount (float): Distance to move right.
+
+ Returns:
+ float: New x-coordinate value.
+
+ Raises:
+ AssertionError: If amount is negative.
+ """
assert amount >= 0, 'Amount must be positive'
return x_coordinate + amount
def left_of(x_coordinate: float, amount: float) -> float:
- """Returns a new x_coordinate value that is to the left of the
- provided value by the given amount."""
+ """Returns a new x_coordinate value that is to the left of the provided value.
+
+ Args:
+ x_coordinate (float): Base x-coordinate.
+ amount (float): Distance to move left.
+
+ Returns:
+ float: New x-coordinate value.
+
+ Raises:
+ AssertionError: If amount is negative.
+ """
assert amount >= 0, 'Amount must be positive'
return x_coordinate - amount
def envelope(pos: Tuple[float, float],
size: Tuple[float, float]) -> List[Tuple[float, float]]:
- """Compute the vertices for the envelope of a shape centered at pos with
- the given size."""
+ """Compute the vertices for the envelope of a shape.
+
+ Args:
+ pos (Tuple[float, float]): Center position of the shape.
+ size (Tuple[float, float]): Size of the shape as (width, height).
+
+ Returns:
+ List[Tuple[float, float]]: List of vertices defining the shape's envelope.
+ """
width, height = size
half_w = width / 2
half_h = height / 2
@@ -81,8 +150,16 @@ def envelope(pos: Tuple[float, float],
def scaled_size(height: float,
window_size: Tuple[float, float],
units: str = 'norm') -> Tuple[float, float]:
- """Scales the provided height value to reflect the aspect ratio of a
- visual.Window. Used for creating squared stimulus. Returns (w,h) tuple"""
+ """Scales the provided height value to reflect the aspect ratio of a window.
+
+ Args:
+ height (float): Height value to scale.
+ window_size (Tuple[float, float]): Window dimensions as (width, height).
+ units (str): Units to use for scaling. Defaults to 'norm'.
+
+ Returns:
+ Tuple[float, float]: Scaled size as (width, height).
+ """
if units == 'height':
width = height
return (width, height)
@@ -95,8 +172,16 @@ def scaled_size(height: float,
def scaled_height(width: float,
window_size: Tuple[float, float],
units: str = 'norm') -> float:
- """Given a width, find the equivalent height scaled to the aspect ratio of
- a window with the given size"""
+ """Given a width, find the equivalent height scaled to the aspect ratio.
+
+ Args:
+ width (float): Width value to scale.
+ window_size (Tuple[float, float]): Window dimensions as (width, height).
+ units (str): Units to use for scaling. Defaults to 'norm'.
+
+ Returns:
+ float: Scaled height value.
+ """
if units == 'height':
return width
win_width, win_height = window_size
@@ -105,24 +190,56 @@ def scaled_height(width: float,
def scaled_width(height: float,
window_size: Tuple[float, float],
- units: str = 'norm'):
- """Given a height, find the equivalent width scaled to the aspect ratio of
- a window with the given size"""
+ units: str = 'norm') -> float:
+ """Given a height, find the equivalent width scaled to the aspect ratio.
+
+ Args:
+ height (float): Height value to scale.
+ window_size (Tuple[float, float]): Window dimensions as (width, height).
+ units (str): Units to use for scaling. Defaults to 'norm'.
+
+ Returns:
+ float: Scaled width value.
+ """
width, _height = scaled_size(height, window_size, units)
return width
class Layout(Container):
"""Class with methods for positioning elements within a parent container.
+
+ This class provides functionality for managing the layout and positioning
+ of GUI elements within a container, including methods for resizing and
+ alignment.
+
+ Attributes:
+ units (str): Units used for measurements (e.g., 'norm', 'height').
+ parent (Optional[Container]): Parent container if any.
+ top (float): Top boundary position.
+ left (float): Left boundary position.
+ bottom (float): Bottom boundary position.
+ right (float): Right boundary position.
"""
- def __init__(self,
- parent: Optional[Container] = None,
- left: float = DEFAULT_LEFT,
- top: float = DEFAULT_TOP,
- right: float = DEFAULT_RIGHT,
- bottom: float = DEFAULT_BOTTOM,
- units: str = "norm"):
+ def __init__(
+ self,
+ parent: Optional[Container] = None,
+ left: float = DEFAULT_LEFT,
+ top: float = DEFAULT_TOP,
+ right: float = DEFAULT_RIGHT,
+ bottom: float = DEFAULT_BOTTOM,
+ units: str = "norm"
+ ) -> None:
+ """Initialize the Layout.
+
+ Args:
+ parent (Optional[Container]): Parent container. Defaults to None.
+ left (float): Left boundary position. Defaults to DEFAULT_LEFT.
+ top (float): Top boundary position. Defaults to DEFAULT_TOP.
+ right (float): Right boundary position. Defaults to DEFAULT_RIGHT.
+ bottom (float): Bottom boundary position. Defaults to DEFAULT_BOTTOM.
+ units (str): Units to use. Defaults to "norm".
+ """
self.units: str = units
self.parent = parent
self.top = top
@@ -131,8 +248,12 @@ def __init__(self,
self.right = right
self.check_invariants()
- def check_invariants(self):
- """Check that all invariants hold true."""
+ def check_invariants(self) -> None:
+ """Check that all invariants hold true.
+
+ Raises:
+ AssertionError: If any invariant is violated.
+ """
# https://psychopy.org/general/units.html#units
assert self.units in ['height',
'norm'], "Units must be 'height' or 'norm'"
@@ -162,8 +283,17 @@ def check_invariants(self):
1], "Height must be greater than 0 and fit within the parent height."
def scaled_size(self, height: float) -> Tuple[float, float]:
- """Returns the (w,h) value scaled to reflect the aspect ratio of a
- visual.Window. Used for creating squared stimulus"""
+ """Returns the (w,h) value scaled to reflect the aspect ratio.
+
+ Args:
+ height (float): Height value to scale.
+
+ Returns:
+ Tuple[float, float]: Scaled size as (width, height).
+
+ Raises:
+ AssertionError: If parent is not configured.
+ """
if self.units == 'height':
width = height
return (width, height)
@@ -172,52 +302,92 @@ def scaled_size(self, height: float) -> Tuple[float, float]:
@property
def size(self) -> Tuple[float, float]:
- """Layout size."""
+ """Get the layout size.
+
+ Returns:
+ Tuple[float, float]: Size as (width, height).
+ """
return (self.width, self.height)
@property
def width(self) -> float:
- """Width in norm units of this component."""
+ """Get the width in norm units of this component.
+
+ Returns:
+ float: Width value.
+ """
return self.right - self.left
@property
def height(self) -> float:
- """Height in norm units of this component."""
+ """Get the height in norm units of this component.
+
+ Returns:
+ float: Height value.
+ """
return self.top - self.bottom
@property
def left_top(self) -> Tuple[float, float]:
- """Top left position"""
+ """Get the top left position.
+
+ Returns:
+ Tuple[float, float]: Position as (x, y).
+ """
return (self.left, self.top)
@property
def right_bottom(self) -> Tuple[float, float]:
- """Bottom right position"""
+ """Get the bottom right position.
+
+ Returns:
+ Tuple[float, float]: Position as (x, y).
+ """
return (self.right, self.bottom)
@property
def horizontal_middle(self) -> float:
- """x-axis value in norm units for the midpoint of this component"""
+ """Get the x-axis value for the midpoint of this component.
+
+ Returns:
+ float: X-coordinate of the midpoint.
+ """
return (self.left + self.right) / 2
@property
def vertical_middle(self) -> float:
- """x-axis value in norm units for the midpoint of this component."""
+ """Get the y-axis value for the midpoint of this component.
+
+ Returns:
+ float: Y-coordinate of the midpoint.
+ """
return (self.top + self.bottom) / 2
@property
def center(self) -> Tuple[float, float]:
- """Center point of the component in norm units. Returns a (x,y) tuple."""
+ """Get the center point of the component.
+
+ Returns:
+ Tuple[float, float]: Center position as (x, y).
+ """
return (self.horizontal_middle, self.vertical_middle)
@property
def left_middle(self) -> Tuple[float, float]:
- """Point centered on the left-most edge."""
+ """Get the point centered on the left-most edge.
+
+ Returns:
+ Tuple[float, float]: Position as (x, y).
+ """
return (self.left, self.vertical_middle)
@property
def right_middle(self) -> Tuple[float, float]:
- """Point centered on the right-most edge."""
+ """Get the point centered on the right-most edge.
+
+ Returns:
+ Tuple[float, float]: Position as (x, y).
+ """
return (self.right, self.vertical_middle)
def resize_width(self,
@@ -225,10 +395,13 @@ def resize_width(self,
alignment: Alignment = Alignment.CENTERED) -> None:
"""Adjust the width of the current layout.
- Parameters
- ----------
- width_pct - percentage of the current width
- alignment - specifies how the remaining width should be aligned.
+ Args:
+ width_pct (float): Percentage of the current width.
+ alignment (Alignment): Specifies how the remaining width should be aligned.
+ Defaults to Alignment.CENTERED.
+
+ Raises:
+ AssertionError: If width_pct is not positive or alignment is invalid.
"""
assert 0 < width_pct, 'width_pct must be greater than 0'
assert alignment in Alignment.horizontal()
@@ -256,10 +429,13 @@ def resize_height(self,
alignment: Alignment = Alignment.CENTERED) -> None:
"""Adjust the height of the current layout.
- Parameters
- ----------
- height_pct - percentage of the current width
- alignment - specifies how the remaining width should be aligned.
+ Args:
+ height_pct (float): Percentage of the current height.
+ alignment (Alignment): Specifies how the remaining height should be aligned.
+ Defaults to Alignment.CENTERED.
+
+ Raises:
+ AssertionError: If height_pct is not positive or alignment is invalid.
"""
assert 0 < height_pct, 'height_pct must be greater than 0'
assert alignment in Alignment.vertical()
@@ -286,12 +462,14 @@ def resize_height(self,
# Factory functions
def at_top(parent: Container, height: float) -> Layout:
- """Constructs a layout of a given height that spans the full width of the
- window and is positioned at the top.
+ """Constructs a layout of a given height that spans the full width of the window.
- Parameters
- ----------
- height - value in 'norm' units
+ Args:
+ parent (Container): Parent container.
+ height (float): Height value in 'norm' units.
+
+ Returns:
+ Layout: New layout instance positioned at the top.
"""
top = DEFAULT_TOP
return Layout(parent=parent,
@@ -302,8 +480,15 @@ def at_top(parent: Container, height: float) -> Layout:
def at_bottom(parent: Container, height: float) -> Layout:
- """Constructs a layout of a given height that spans the full width of the
- window and is positioned at the bottom"""
+ """Constructs a layout of a given height that spans the full width of the window.
+
+ Args:
+ parent (Container): Parent container.
+ height (float): Height value in 'norm' units.
+
+ Returns:
+ Layout: New layout instance positioned at the bottom.
+ """
bottom = DEFAULT_BOTTOM
return Layout(parent=parent,
left=DEFAULT_LEFT,
@@ -315,17 +500,17 @@ def at_bottom(parent: Container, height: float) -> Layout:
def centered(parent: Optional[Container] = None,
width_pct: float = 1.0,
height_pct: float = 1.0) -> Layout:
- """Constructs a layout that is centered on the screen. Default size is
- fullscreen but optional parameters can be used to adjust the width and
- height.
-
- Parameters
- ----------
- parent - optional parent
- width_pct - optional; sets the width to a given percentage of
- fullscreen.
- height_pct - optional; sets the height to a given percentage of
- fullscreen.
+ """Constructs a layout that is centered on the screen.
+
+ Args:
+ parent (Optional[Container]): Optional parent container.
+ width_pct (float): Optional; sets the width to a given percentage of fullscreen.
+ Defaults to 1.0.
+ height_pct (float): Optional; sets the height to a given percentage of fullscreen.
+ Defaults to 1.0.
+
+ Returns:
+ Layout: New centered layout instance.
"""
container = Layout(parent=parent)
container.resize_width(width_pct, alignment=Alignment.CENTERED)
@@ -334,8 +519,14 @@ def centered(parent: Optional[Container] = None,
def from_envelope(verts: List[Tuple[float, float]]) -> Layout:
- """Constructs a layout from a list of vertices which comprise a shape's
- envelope."""
+ """Constructs a layout from a list of vertices which comprise a shape's envelope.
+
+ Args:
+ verts (List[Tuple[float, float]]): List of vertices defining the shape's envelope.
+
+ Returns:
+ Layout: New layout instance based on the envelope.
+ """
x_coords, y_coords = zip(*verts)
return Layout(left=min(x_coords),
top=max(y_coords),
@@ -344,13 +535,19 @@ def from_envelope(verts: List[Tuple[float, float]]) -> Layout:
def height_units(window_size: Tuple[float, float]) -> Layout:
- """Constructs a layout with height units using the given Window
- dimensions
+ """Constructs a layout with height units using the given Window dimensions.
+
+ Args:
+ window_size (Tuple[float, float]): Window dimensions as (width, height).
+
+ Returns:
+ Layout: New layout instance using height units.
- for an aspect ratio of 4:3
- 4 widths / 3 height = 1.333
- 1.333 / 2 = 0.667
- so, left is -0.667 and right is 0.667
+ Note:
+ For an aspect ratio of 4:3:
+ 4 widths / 3 height = 1.333
+ 1.333 / 2 = 0.667
+ so, left is -0.667 and right is 0.667
"""
win_width, win_height = window_size
right = (win_width / win_height) / 2
diff --git a/bcipy/display/components/task_bar.py b/bcipy/display/components/task_bar.py
index 654acd859..4ade5574f 100644
--- a/bcipy/display/components/task_bar.py
+++ b/bcipy/display/components/task_bar.py
@@ -1,6 +1,11 @@
-"""Task bar component"""
+"""Task bar component.
-from typing import Dict, List, Optional
+This module provides components for displaying task-related information in a task window.
+It includes base task bar functionality and specialized implementations for different
+types of tasks like calibration and copy phrase tasks.
+"""
+
+from typing import Any, Dict, List, Optional
from psychopy import visual
from psychopy.visual.basevisual import BaseVisualStim
@@ -9,27 +14,43 @@
class TaskBar:
- """Component for displaying task-related information in a task window. The
- component elements are positioned at the top of the window.
-
- Parameters
- ----------
- win - visual.Window on which to render elements
- colors - Ordered list of colors to apply to task stimuli
- font - Font to apply to all task stimuli
- height - Height of all task text stimuli
- text - Task text to apply to stimuli
- padding - used in conjunction with the text height to compute the
- overall height of the task bar.
+ """Component for displaying task-related information in a task window.
+
+ The component elements are positioned at the top of the window.
+
+ Attributes:
+ win (visual.Window): Window on which to render elements.
+ colors (List[str]): Ordered list of colors to apply to task stimuli.
+ font (str): Font to apply to all task stimuli.
+ height (float): Height of all task text stimuli.
+ padding (float): Padding used in conjunction with text height to compute
+ the overall height of the task bar.
+ text (str): Task text to apply to stimuli.
+ layout (layout.Layout): Layout manager for positioning elements.
+ stim (Dict[str, BaseVisualStim]): Dictionary of visual stimuli.
"""
- def __init__(self,
- win: visual.Window,
- colors: Optional[List[str]] = None,
- font: str = 'Courier New',
- height: float = 0.1,
- text: str = '',
- padding: Optional[float] = None):
+ def __init__(
+ self,
+ win: visual.Window,
+ colors: Optional[List[str]] = None,
+ font: str = 'Courier New',
+ height: float = 0.1,
+ text: str = '',
+ padding: Optional[float] = None
+ ) -> None:
+ """Initialize the TaskBar.
+
+ Args:
+ win (visual.Window): Window on which to render elements.
+ colors (Optional[List[str]]): Ordered list of colors to apply to task stimuli.
+ Defaults to ['white'].
+ font (str): Font to apply to all task stimuli. Defaults to 'Courier New'.
+ height (float): Height of all task text stimuli. Defaults to 0.1.
+ text (str): Task text to apply to stimuli. Defaults to ''.
+ padding (Optional[float]): Padding used in conjunction with text height.
+ If None, defaults to height/2.
+ """
self.win = win
self.colors = colors or ['white']
self.font = font
@@ -41,35 +62,50 @@ def __init__(self,
@property
def height_pct(self) -> float:
- """Percentage of the total window that the task bar occupies.
+ """Get the percentage of the total window that the task bar occupies.
- Returns
- -------
- percentage ; value will be between 0 and 1.
+ Returns:
+ float: Percentage value between 0 and 1.
"""
win_layout = layout.Layout(self.win)
return self.compute_height() / win_layout.height
- def compute_height(self):
- """Computes the component height using the provided config."""
+ def compute_height(self) -> float:
+ """Compute the component height using the provided config.
+
+ Returns:
+ float: Total height of the task bar.
+ """
return self.height + self.padding
def init_stim(self) -> Dict[str, BaseVisualStim]:
- """Initialize the stimuli elements."""
+ """Initialize the stimuli elements.
+
+ Returns:
+ Dict[str, BaseVisualStim]: Dictionary of initialized visual stimuli.
+ """
task = self.text_stim()
return {'task_text': task, 'border': self.border_stim()}
- def draw(self):
+ def draw(self) -> None:
"""Draw the task bar elements."""
for _key, stim in self.stim.items():
stim.draw()
- def update(self, text: str = ''):
- """Update the task bar to display the given text."""
+ def update(self, text: str = '') -> None:
+ """Update the task bar to display the given text.
+
+ Args:
+ text (str): New text to display. Defaults to ''.
+ """
self.stim['task_text'].text = text
def border_stim(self) -> visual.Line:
- """Create the task bar outline"""
+ """Create the task bar outline.
+
+ Returns:
+ visual.Line: Line stimulus representing the task bar border.
+ """
# pylint: disable=not-callable
return visual.Line(
win=self.win,
@@ -79,15 +115,26 @@ def border_stim(self) -> visual.Line:
lineColor=self.colors[0])
def text_stim(self, **kwargs) -> visual.TextStim:
- """Constructs a TextStim. Uses the config to set default properties
- which may be overridden by providing keyword args.
- """
+ """Construct a TextStim with default properties.
+
+ Uses the config to set default properties which may be overridden by
+ providing keyword args.
+
+ Args:
+ **kwargs: Additional properties to override defaults.
+ Returns:
+ visual.TextStim: Configured text stimulus.
+ """
props = {**self.default_text_props(), **kwargs}
return visual.TextStim(**props)
- def default_text_props(self) -> dict:
- """Default properties for constructing a TextStim."""
+ def default_text_props(self) -> Dict[str, Any]:
+ """Get default properties for constructing a TextStim.
+
+ Returns:
+ Dict[str, Any]: Dictionary of default text stimulus properties.
+ """
return {
'win': self.win,
'text': self.text,
@@ -100,82 +147,125 @@ def default_text_props(self) -> dict:
class CalibrationTaskBar(TaskBar):
- """Task bar for Calibration tasks. Displays the count of inquiries.
-
- Parameters
- ----------
- win - visual.Window on which to render elements
- inquiry_count - total number of inquiries to display
- current_index - index of the current inquiry
- **config - display config (colors, font, height)
+ """Task bar for Calibration tasks.
+
+ Displays the count of inquiries in the format "current/total".
+
+ Attributes:
+ inquiry_count (int): Total number of inquiries to display.
+ current_index (int): Index of the current inquiry.
"""
- def __init__(self,
- win: visual.Window,
- inquiry_count: int,
- current_index: int = 1,
- **config):
+ def __init__(
+ self,
+ win: visual.Window,
+ inquiry_count: int,
+ current_index: int = 1,
+ **config: Any
+ ) -> None:
+ """Initialize the CalibrationTaskBar.
+
+ Args:
+ win (visual.Window): Window on which to render elements.
+ inquiry_count (int): Total number of inquiries to display.
+ current_index (int): Index of the current inquiry. Defaults to 1.
+ **config: Additional display configuration (colors, font, height).
+ """
self.inquiry_count = inquiry_count
self.current_index = current_index
super().__init__(win, **config)
def init_stim(self) -> Dict[str, BaseVisualStim]:
- """Initialize the stimuli elements."""
+ """Initialize the stimuli elements.
- task = self.text_stim(text=self.displayed_text(),
- pos=self.layout.left_middle,
- anchorHoriz='left',
- alignText='left')
+ Returns:
+ Dict[str, BaseVisualStim]: Dictionary of initialized visual stimuli.
+ """
+ task = self.text_stim(
+ text=self.displayed_text(),
+ pos=self.layout.left_middle,
+ anchorHoriz='left',
+ alignText='left'
+ )
return {'task_text': task, 'border': self.border_stim()}
- def update(self, text: str = ''):
- """Update the displayed text"""
+ def update(self, text: str = '') -> None:
+ """Update the displayed text.
+
+ Args:
+            text (str): Ignored; displayed text is recomputed from current_index. Defaults to ''.
+ """
self.current_index += 1
self.stim['task_text'].text = self.displayed_text()
def displayed_text(self) -> str:
- """Text to display. Computed from the current_index and inquiry_count. Ex. '2/100'"""
+ """Get the text to display.
+
+ Returns:
+            str: Text in the format " current/total" with a leading space (e.g., " 2/100").
+ """
return f" {self.current_index}/{self.inquiry_count}"
class CopyPhraseTaskBar(TaskBar):
- """Task bar for the Copy Phrase Task
-
- Parameters
- ----------
- win - visual.Window on which to render elements
- task_text - text for the participant to spell
- spelled_text - text that has already been spelled
- **config - display config (colors, font, height)
+ """Task bar for the Copy Phrase Task.
+
+ Displays both the target text and the currently spelled text.
+
+ Attributes:
+ task_text (str): Text for the participant to spell.
+ spelled_text (str): Text that has already been spelled.
"""
- def __init__(self,
- win: visual.Window,
- task_text: str = '',
- spelled_text: str = '',
- **config):
+ def __init__(
+ self,
+ win: visual.Window,
+ task_text: str = '',
+ spelled_text: str = '',
+ **config: Any
+ ) -> None:
+ """Initialize the CopyPhraseTaskBar.
+
+ Args:
+ win (visual.Window): Window on which to render elements.
+ task_text (str): Text for the participant to spell. Defaults to ''.
+ spelled_text (str): Text that has already been spelled. Defaults to ''.
+ **config: Additional display configuration (colors, font, height).
+ """
self.task_text = task_text
self.spelled_text = spelled_text
super().__init__(win, **config)
- def compute_height(self):
- """Computes the component height using the provided config."""
- # height is doubled to account for task_text and spelled_text being on
- # separate lines.
+ def compute_height(self) -> float:
+ """Compute the component height using the provided config.
+
+ Height is doubled to account for task_text and spelled_text being on
+ separate lines.
+
+ Returns:
+ float: Total height of the task bar.
+ """
return (self.height * 2) + self.padding
def init_stim(self) -> Dict[str, BaseVisualStim]:
- """Initialize the stimuli elements."""
-
- task = self.text_stim(text=self.task_text,
- pos=self.layout.center,
- anchorVert='bottom')
+ """Initialize the stimuli elements.
- spelled = self.text_stim(text=self.displayed_text(),
- pos=self.layout.center,
- color=self.colors[-1],
- anchorVert='top')
+ Returns:
+ Dict[str, BaseVisualStim]: Dictionary of initialized visual stimuli.
+ """
+ task = self.text_stim(
+ text=self.task_text,
+ pos=self.layout.center,
+ anchorVert='bottom'
+ )
+
+ spelled = self.text_stim(
+ text=self.displayed_text(),
+ pos=self.layout.center,
+ color=self.colors[-1],
+ anchorVert='top'
+ )
return {
'task_text': task,
@@ -183,13 +273,21 @@ def init_stim(self) -> Dict[str, BaseVisualStim]:
'border': self.border_stim()
}
- def update(self, text: str = ''):
- """Update the task bar to display the given text."""
+ def update(self, text: str = '') -> None:
+ """Update the task bar to display the given text.
+
+ Args:
+ text (str): New spelled text to display. Defaults to ''.
+ """
self.spelled_text = text
self.stim['spelled_text'].text = self.displayed_text()
- def displayed_text(self):
- """Spelled text padded for alignment."""
+ def displayed_text(self) -> str:
+ """Get the spelled text padded for alignment.
+
+ Returns:
+ str: Spelled text padded with spaces to match task_text length.
+ """
diff = len(self.task_text) - len(self.spelled_text)
if (diff > 0):
return self.spelled_text + (' ' * diff)
diff --git a/bcipy/display/demo/components/demo_layouts.py b/bcipy/display/demo/components/demo_layouts.py
index f0e2179f3..5c944f2e0 100644
--- a/bcipy/display/demo/components/demo_layouts.py
+++ b/bcipy/display/demo/components/demo_layouts.py
@@ -267,7 +267,8 @@ def demo_matrix_positions(win: visual.Window):
symbols = alphabet()
norm_layout = centered(parent=win, width_pct=0.7, height_pct=0.75)
- positions = symbol_positions(norm_layout, symbol_set=symbols, rows=5, columns=6)
+ positions = symbol_positions(
+ norm_layout, symbol_set=symbols, rows=5, columns=6)
for sym, pos in zip(symbols, positions):
stim = visual.TextStim(win,
diff --git a/bcipy/display/demo/matrix/demo_calibration_matrix.py b/bcipy/display/demo/matrix/demo_calibration_matrix.py
index 3e737644d..839a69554 100644
--- a/bcipy/display/demo/matrix/demo_calibration_matrix.py
+++ b/bcipy/display/demo/matrix/demo_calibration_matrix.py
@@ -35,7 +35,8 @@
win = init_display_window(window_parameters)
win.recordFrameIntervals = False
-task_bar = CalibrationTaskBar(win, inquiry_count=4, current_index=0, font='Arial')
+task_bar = CalibrationTaskBar(
+ win, inquiry_count=4, current_index=0, font='Arial')
preview_config = PreviewParams(show_preview_inquiry=True,
preview_inquiry_length=2,
preview_inquiry_key_input='return',
diff --git a/bcipy/display/main.py b/bcipy/display/main.py
index 49f5fa074..859093aa6 100644
--- a/bcipy/display/main.py
+++ b/bcipy/display/main.py
@@ -1,7 +1,13 @@
+"""Main display module.
+
+This module provides the core display functionality for BciPy, including base classes
+and utilities for creating and managing visual stimuli in BCI paradigms.
+"""
+
# mypy: disable-error-code="assignment,empty-body"
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Any, List, NamedTuple, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union
from psychopy import visual
@@ -13,9 +19,22 @@
class Display(ABC):
- """Display.
-
- Base class for BciPy displays. This defines the logic necessary for task executions that require a display.
+ """Base class for BciPy displays.
+
+ This abstract class defines the core interface and functionality necessary for
+ task executions that require a display. It provides methods for stimulus
+ presentation, timing control, and task management.
+
+ Attributes:
+ window (visual.Window): PsychoPy window for display.
+ timing_clock (Clock): Clock for timing control.
+ experiment_clock (Clock): Clock for experiment timing.
+ stimuli_inquiry (List[str]): List of stimuli to present.
+ stimuli_colors (List[str]): List of colors for each stimulus.
+ stimuli_timing (List[float]): List of presentation durations.
+ task (Any): Task-related information.
+ info_text (List[Any]): Information text to display.
+ first_stim_time (float): Time of first stimulus presentation.
"""
window: visual.Window = None
@@ -30,107 +49,142 @@ class Display(ABC):
@abstractmethod
def do_inquiry(self) -> List[Tuple[str, float]]:
- """Do inquiry.
+ """Perform an inquiry of stimuli.
Animates an inquiry of stimuli and returns a list of stimuli trigger timing.
+
+ Returns:
+ List[Tuple[str, float]]: List of (stimulus, timing) pairs.
"""
...
@abstractmethod
- def wait_screen(self, *args, **kwargs) -> None:
- """Wait Screen.
+ def wait_screen(self, *args: Any, **kwargs: Any) -> None:
+ """Display a wait screen.
Define what happens on the screen when a user pauses a session.
+
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
"""
...
@abstractmethod
- def update_task_bar(self, *args, **kwargs) -> None:
- """Update Task.
+ def update_task_bar(self, *args: Any, **kwargs: Any) -> None:
+ """Update task bar display.
+
+ Update any taskbar-related display items not related to the inquiry.
+ Example: stimuli count 1/200.
- Update any taskbar-related display items not related to the inquiry. Ex. stimuli count 1/200.
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
"""
...
- def schedule_to(self, stimuli: list, timing: list, colors: list) -> None:
- """Schedule To.
+ def schedule_to(self, stimuli: List[str], timing: List[float], colors: List[str]) -> None:
+ """Schedule stimuli elements.
Schedule stimuli elements (works as a buffer) before calling do_inquiry.
+
+ Args:
+ stimuli (List[str]): List of stimuli to present.
+ timing (List[float]): List of presentation durations.
+ colors (List[str]): List of colors for each stimulus.
"""
...
def draw_static(self) -> None:
- """Draw Static.
+ """Draw static elements.
Displays task information not related to the inquiry.
"""
...
- def preview_inquiry(self, *args, **kwargs) -> List[float]:
- """Preview Inquiry.
+ def preview_inquiry(self, *args: Any, **kwargs: Any) -> List[float]:
+ """Preview an inquiry before presentation.
+
+ Display an inquiry or instruction beforehand to the user. This should be called
+ before do_inquiry. This can be used to determine if the desired stimuli is present
+ before displaying them more laboriously or prompting users before the inquiry.
+
+ Note:
+ All stimuli elements (stimuli, timing, colors) must be set on the display
+ before calling this method. This implies something like schedule_to is called.
- Display an inquiry or instruction beforehand to the user. This should be called before do_inquiry.
- This can be used to determine if the desired stimuli is present before displaying them more laboriusly
- or prompting users before the inquiry.
- All stimuli elements (stimuli, timing, colors) must be set on the display before calling this method.
- This implies, something like schedule_to is called.
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ List[float]: List of timing information for the preview.
"""
...
-def init_display_window(parameters):
- """
- Init Display Window.
+def init_display_window(parameters: Dict[str, Any]) -> visual.Window:
+ """Initialize the main display window.
- Function to Initialize main display window
- needed for all later stimuli presentation.
+ Function to initialize main display window needed for all later stimuli presentation.
See Psychopy official documentation for more information and working demos:
http://www.psychopy.org/api/visual/window.html
- """
- # Check is full_screen mode is set and get necessary values
- if parameters['full_screen']:
+ Args:
+ parameters (Dict[str, Any]): Dictionary containing window configuration parameters.
+ Returns:
+ visual.Window: Initialized PsychoPy window for display.
+ """
+ # Check if full_screen mode is set and get necessary values
+ if parameters['full_screen']:
# set window attributes based on resolution
screen_info = get_screen_info()
window_height = screen_info.height
window_width = screen_info.width
-
- # set full screen mode to true (removes os dock, explorer etc.)
full_screen = True
-
- # otherwise, get user defined window attributes
else:
-
# set window attributes directly from parameters file
window_height = parameters['window_height']
window_width = parameters['window_width']
-
- # make sure full screen is set to false
full_screen = False
# Initialize PsychoPy Window for Main Display of Stimuli
display_window = visual.Window(
- size=[window_width,
- window_height],
+ size=[window_width, window_height],
screen=parameters['stim_screen'],
allowGUI=False,
useFBO=False,
fullscr=full_screen,
allowStencil=False,
monitor='mainMonitor',
- winType='pyglet', units='norm', waitBlanking=False,
+ winType='pyglet',
+ units='norm',
+ waitBlanking=False,
color=parameters['background_color'])
- # Return display window to caller
return display_window
class StimuliProperties:
- """"Stimuli Properties.
-
- An encapsulation of properties relevant to core stimuli presentation in a paradigm.
+ """Encapsulation of properties for core stimuli presentation.
+
+ This class manages the properties and configuration for presenting stimuli
+ in a paradigm, including text and image-based stimuli.
+
+ Attributes:
+ stim_font (str): Font to use for text stimuli.
+ stim_pos (Union[Tuple[float, float], List[Tuple[float, float]]]): Position(s) for stimuli.
+ stim_height (float): Height of stimuli.
+ stim_inquiry (List[str]): List of stimuli to present.
+ stim_colors (List[str]): List of colors for each stimulus.
+ stim_timing (List[float]): List of presentation durations.
+ is_txt_stim (bool): Whether stimuli are text-based.
+ stim_length (int): Number of stimuli.
+ sti (Optional[Union[visual.TextStim, visual.ImageStim]]): Stimulus object.
+ prompt_time (Optional[float]): Time to display target prompt.
+ layout (Optional[str]): Layout of stimuli (e.g., 'ALPHABET' or 'QWERTY').
"""
def __init__(
@@ -143,20 +197,19 @@ def __init__(
stim_timing: Optional[List[float]] = None,
is_txt_stim: bool = True,
prompt_time: Optional[float] = None,
- layout: Optional[str] = None):
- """Initialize Stimuli Parameters.
-
- stim_font(List[str]): Ordered list of colors to apply to information stimuli
- stim_pos(Tuple[float, float]): Position on window where the stimuli will be presented
- or a list of positions (ex. for matrix displays)
- stim_height(float): Height of all stimuli
- stim_inquiry(List[str]): Ordered list of text to build stimuli with
- stim_colors(List[str]): Ordered list of colors to apply to stimuli
- stim_timing(List[float]): Ordered list of timing to apply to an inquiry using the stimuli
- is_txt_stim(bool): Whether or not this is a text based stimuli (False implies image based)
- prompt_time(float): Time to display target prompt for at the beginning of inquiry
- layout(str): Layout of stimuli on the screen (ex. 'ALPHABET' or 'QWERTY').
- This is only used for matrix displays.
+ layout: Optional[str] = None) -> None:
+ """Initialize Stimuli Properties.
+
+ Args:
+ stim_font (str): Font to use for text stimuli.
+ stim_pos (Union[Tuple[float, float], List[Tuple[float, float]]]): Position(s) for stimuli.
+ stim_height (float): Height of stimuli.
+ stim_inquiry (Optional[List[str]]): List of stimuli to present. Defaults to None.
+ stim_colors (Optional[List[str]]): List of colors for each stimulus. Defaults to None.
+ stim_timing (Optional[List[float]]): List of presentation durations. Defaults to None.
+ is_txt_stim (bool): Whether stimuli are text-based. Defaults to True.
+ prompt_time (Optional[float]): Time to display target prompt. Defaults to None.
+ layout (Optional[str]): Layout of stimuli. Defaults to None.
"""
self.stim_font = stim_font
self.stim_pos = stim_pos
@@ -171,11 +224,17 @@ def __init__(
self.layout = layout
def build_init_stimuli(self, window: visual.Window) -> Union[visual.TextStim, visual.ImageStim]:
- """"Build Initial Stimuli.
+ """Build initial stimulus object.
This method constructs the stimuli object which can be updated later. This is more
- performant than creating a new stimuli each call. It can create either an image or text stimuli
- based on the boolean self.is_txt_stim.
+ performant than creating a new stimuli each call. It can create either an image or
+ text stimuli based on the boolean self.is_txt_stim.
+
+ Args:
+ window (visual.Window): PsychoPy window for display.
+
+ Returns:
+ Union[visual.TextStim, visual.ImageStim]: The created stimulus object.
"""
if self.is_txt_stim:
self.sti = visual.TextStim(
@@ -185,8 +244,10 @@ def build_init_stimuli(self, window: visual.Window) -> Union[visual.TextStim, vi
text='',
font=self.stim_font,
pos=self.stim_pos,
- wrapWidth=None, colorSpace='rgb',
- opacity=1, depth=-6.0)
+ wrapWidth=None,
+ colorSpace='rgb',
+ opacity=1,
+ depth=-6.0)
else:
self.sti = visual.ImageStim(
win=window,
@@ -198,10 +259,18 @@ def build_init_stimuli(self, window: visual.Window) -> Union[visual.TextStim, vi
class InformationProperties:
- """"Information Properties.
-
- An encapsulation of properties relevant to task information presentation in an RSVP paradigm. This could be
- messaging relevant to feedback or static text to remain on screen not related to task tracking.
+ """Encapsulation of properties for task information presentation.
+
+ This class manages the properties and configuration for displaying task-related
+ information, feedback, and static text in an RSVP paradigm.
+
+ Attributes:
+ info_color (List[str]): List of colors for information text.
+ info_text (List[str]): List of information text to display.
+ info_font (List[str]): List of fonts for information text.
+ info_pos (List[Tuple[float, float]]): List of positions for information text.
+ info_height (List[float]): List of heights for information text.
+ text_stim (List[visual.TextStim]): List of text stimulus objects.
"""
def __init__(
@@ -210,14 +279,15 @@ def __init__(
info_text: List[str],
info_font: List[str],
info_pos: List[Tuple[float, float]],
- info_height: List[float]):
- """Initialize Information Parameters.
-
- info_color(List[str]): Ordered list of colors to apply to information stimuli
- info_text(List[str]): Ordered list of text to apply to information stimuli
- info_font(List[str]): Ordered list of font to apply to information stimuli
- info_pos(Tuple[float, float]): Position on window where the Information stimuli will be presented
- info_height(List[float]): Ordered list of height of Information stimuli
+ info_height: List[float]) -> None:
+ """Initialize Information Properties.
+
+ Args:
+ info_color (List[str]): List of colors for information text.
+ info_text (List[str]): List of information text to display.
+ info_font (List[str]): List of fonts for information text.
+ info_pos (List[Tuple[float, float]]): List of positions for information text.
+ info_height (List[float]): List of heights for information text.
"""
self.info_color = info_color
self.info_text = info_text
@@ -226,9 +296,15 @@ def __init__(
self.info_height = info_height
def build_info_text(self, window: visual.Window) -> List[visual.TextStim]:
- """"Build Information Text.
+ """Build information text stimuli.
Constructs a list of Information stimuli to display.
+
+ Args:
+ window (visual.Window): PsychoPy window for display.
+
+ Returns:
+ List[visual.TextStim]: List of text stimulus objects.
"""
self.text_stim = []
for idx in range(len(self.info_text)):
@@ -239,24 +315,39 @@ def build_info_text(self, window: visual.Window) -> List[visual.TextStim]:
text=self.info_text[idx],
font=self.info_font[idx],
pos=self.info_pos[idx],
- wrapWidth=None, colorSpace='rgb',
- opacity=1, depth=-6.0))
+ wrapWidth=None,
+ colorSpace='rgb',
+ opacity=1,
+ depth=-6.0))
return self.text_stim
class ButtonPressMode(Enum):
- """Represents the possible meanings for a button press (when using an Inquiry Preview.)"""
+ """Represents the possible meanings for a button press.
+
+ Used when implementing Inquiry Preview functionality to determine the
+ action to take based on user input.
+ """
NOTHING = 0
ACCEPT = 1
REJECT = 2
class PreviewParams(NamedTuple):
- """Parameters relevant for the Inquiry Preview functionality.
-
- Create from an existing Parameters instance using:
- >>> parameters.instantiate(PreviewParams)
+ """Parameters for Inquiry Preview functionality.
+
+ This class defines the configuration parameters needed for the Inquiry Preview
+ feature, which allows users to preview stimuli before presentation.
+
+ Attributes:
+ show_preview_inquiry (bool): Whether to show preview.
+ preview_inquiry_length (float): Duration of preview.
+ preview_inquiry_key_input (str): Key to use for preview input.
+ preview_inquiry_progress_method (int): Method for handling preview progress.
+ preview_inquiry_isi (float): Inter-stimulus interval for preview.
+ preview_box_text_size (float): Text size for preview box.
"""
+
show_preview_inquiry: bool
preview_inquiry_length: float
preview_inquiry_key_input: str
@@ -266,13 +357,24 @@ class PreviewParams(NamedTuple):
@property
def button_press_mode(self) -> ButtonPressMode:
- """Mode indicated by the inquiry progress method."""
+ """Get the button press mode from the progress method.
+
+ Returns:
+ ButtonPressMode: The mode indicated by the progress method.
+ """
return ButtonPressMode(self.preview_inquiry_progress_method)
def get_button_handler_class(
mode: ButtonPressMode) -> Type[ButtonPressHandler]:
- """Get the appropriate handler constructor for the given button press mode."""
+ """Get the appropriate button handler class for the given mode.
+
+ Args:
+ mode (ButtonPressMode): The button press mode to handle.
+
+ Returns:
+ Type[ButtonPressHandler]: The appropriate handler class.
+ """
mapping = {
ButtonPressMode.NOTHING: PreviewOnlyButtonPressHandler,
ButtonPressMode.ACCEPT: AcceptButtonPressHandler,
@@ -283,7 +385,15 @@ def get_button_handler_class(
def init_preview_button_handler(params: PreviewParams,
experiment_clock: Clock) -> ButtonPressHandler:
- """"Returns a button press handler for inquiry preview."""
+ """Initialize a button press handler for inquiry preview.
+
+ Args:
+ params (PreviewParams): Preview configuration parameters.
+ experiment_clock (Clock): Clock for experiment timing.
+
+ Returns:
+ ButtonPressHandler: Configured button press handler.
+ """
make_handler = get_button_handler_class(params.button_press_mode)
return make_handler(max_wait=params.preview_inquiry_length,
key_input=params.preview_inquiry_key_input,
@@ -291,6 +401,14 @@ def init_preview_button_handler(params: PreviewParams,
class VEPStimuliProperties(StimuliProperties):
+ """Properties for VEP (Visual Evoked Potential) stimuli.
+
+ This class extends StimuliProperties to provide specific functionality
+ for VEP-based paradigms.
+
+ Attributes:
+ animation_seconds (float): Duration of animation.
+ """
def __init__(self,
stim_font: str,
@@ -300,12 +418,18 @@ def __init__(self,
stim_color: List[str],
inquiry: List[List[Any]],
stim_length: int = 1,
- animation_seconds: float = 1.0):
- """Initialize VEP Stimuli Parameters.
- stim_color(List[str]): Ordered list of colors to apply to VEP stimuli
- stim_font(str): Font to apply to all VEP stimuli
- stim_pos(List[Tuple[float, float]]): Position on the screen where to present to VEP text
- stim_height(float): Height of all VEP text stimuli
+ animation_seconds: float = 1.0) -> None:
+ """Initialize VEP Stimuli Properties.
+
+ Args:
+ stim_font (str): Font to use for text stimuli.
+ stim_pos (List[Tuple[float, float]]): Positions for stimuli.
+ stim_height (float): Height of stimuli.
+ timing (List[float]): List of presentation durations.
+ stim_color (List[str]): List of colors for each stimulus.
+ inquiry (List[List[Any]]): List of inquiry stimuli.
+ stim_length (int, optional): Number of stimuli. Defaults to 1.
+ animation_seconds (float, optional): Duration of animation. Defaults to 1.0.
"""
# static properties
self.stim_font = stim_font
@@ -317,11 +441,15 @@ def __init__(self,
# dynamic property. List of length 3. 1. prompt; 2. fixation; 3. inquiry
self.stim_timing = timing
- # dynamic properties, must be a a list of lists where each list is a different box
+ # dynamic properties, must be a list of lists where each list is a different box
self.stim_colors = stim_color
self.stim_inquiry = inquiry
self.animation_seconds = animation_seconds
def build_init_stimuli(self, window: visual.Window) -> None:
- """"Build Initial Stimuli."""
+ """Build initial VEP stimuli.
+
+ Args:
+ window (visual.Window): PsychoPy window for display.
+ """
...
diff --git a/bcipy/display/paradigm/matrix/README.md b/bcipy/display/paradigm/matrix/README.md
index 0fbdf4a89..efbb7d6de 100644
--- a/bcipy/display/paradigm/matrix/README.md
+++ b/bcipy/display/paradigm/matrix/README.md
@@ -6,7 +6,7 @@ The matrix display presents a list of symbols in a grid format. Grid items are e
A matrix display needs a psychopy Window, core.Clock, and configuration for the stimuli (StimuliProperties), task_bar (TaskBar), and information elements (InformationProperties).
-```
+```python
from psychopy import core
import bcipy.display.components.layout as layout
@@ -45,7 +45,7 @@ The number of rows and columns can be specified by using the provided parameters
When using a task, the `matrix_rows` and `matrix_columns` parameters are used for customization.
-```
+```python
matrix_display = MatrixDisplay(win,
experiment_clock,
stim_properties,
@@ -55,14 +55,13 @@ matrix_display = MatrixDisplay(win,
columns=7)
```
-
## Layout
Symbol positions are calculated when the display is initialized. The grid will be centered within the window. By default the grid will take up 75% of the width, and 80% of the height, or whichever is smaller depending on the aspect ratio of your monitor. These values can be adjusted by provided a width_pct and height_pct parameter.
When using a task, the `matrix_width` parameter is used for customization.
-```
+```python
# determine matrix height based on the size of the task_bar
matrix_height_pct = 1 - (2 * task_bar.height_pct)
matrix_display = MatrixDisplay(win,
@@ -93,9 +92,9 @@ You may need to do some trial and error to determine the best matrix configurati
For tasks which use the matrix display, the following parameters are recommended:
-```
+```python
time_fixation: 2
stim_height: 0.17
task_height: 0.1
task_padding: 0.05
-```
\ No newline at end of file
+```
diff --git a/bcipy/display/paradigm/matrix/display.py b/bcipy/display/paradigm/matrix/display.py
index 4e70ebeb3..52432c969 100644
--- a/bcipy/display/paradigm/matrix/display.py
+++ b/bcipy/display/paradigm/matrix/display.py
@@ -1,6 +1,12 @@
-"""Display for presenting stimuli in a grid."""
+"""Display for presenting stimuli in a grid.
+
+This module provides functionality for displaying and managing matrix-style stimuli
+presentations, commonly used in BCI paradigms. It handles the layout, timing, and
+animation of stimuli in a grid format.
+"""
+
import logging
-from typing import Dict, List, NamedTuple, Optional, Tuple
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from psychopy import core, visual
@@ -19,7 +25,13 @@
class SymbolDuration(NamedTuple):
- """Represents a symbol and its associated duration to display"""
+ """Represents a symbol and its associated duration to display.
+
+ Attributes:
+ symbol (str): The symbol to display.
+ duration (float): Duration in seconds to display the symbol.
+ color (str): Color to display the symbol in. Defaults to 'white'.
+ """
symbol: str
duration: float
color: str = 'white'
@@ -30,59 +42,82 @@ class MatrixDisplay(Display):
Animates display objects in matrix grid common to any Matrix task.
- NOTE: The following are recommended parameter values for matrix experiments:
-
- time_fixation: 2
- stim_pos_x: -0.6
- stim_pos_y: 0.4
- stim_height: 0.17
+ Attributes:
+ window (visual.Window): PsychoPy window for display.
+ stimuli_inquiry (List[str]): List of stimuli to present.
+ stimuli_timing (List[float]): List of timing values for each stimulus.
+ stimuli_colors (List[str]): List of colors for each stimulus.
+ stimuli_font (str): Font to use for text stimuli.
+ symbol_set (Optional[List[str]]): Set of symbols to display.
+ sort_order (Callable[[str], int]): Function to determine symbol order.
+ grid_stimuli_height (float): Height of grid stimuli.
+ positions (Dict[int, Tuple[float, float]]): Positions for each symbol.
+ grid_color (str): Default color for grid elements.
+ start_opacity (float): Initial opacity for grid elements.
+ highlight_opacity (float): Opacity for highlighted elements.
+ full_grid_opacity (float): Opacity for full grid display.
+ first_run (bool): Whether this is the first run.
+ first_stim_time (Optional[float]): Time of first stimulus.
+ trigger_type (str): Type of trigger to use.
+ _timing (List[Tuple[str, float]]): List of timing information.
+ experiment_clock (core.Clock): Clock for timing.
+ task_bar (TaskBar): Task bar component.
+ info_text (List[visual.TextStim]): Information text components.
+ stim_registry (Dict[str, visual.TextStim]): Registry of stimuli.
+ should_prompt_target (bool): Whether to prompt for target.
+ preview_params (Optional[PreviewParams]): Preview configuration.
+ preview_button_handler (Optional[Any]): Handler for preview buttons.
+ preview_accepted (bool): Whether preview was accepted.
+
+ Note:
+ The following are recommended parameter values for matrix experiments:
+ - time_fixation: 2
+ - stim_pos_x: -0.6
+ - stim_pos_y: 0.4
+ - stim_height: 0.17
"""
- def __init__(self,
- window: visual.Window,
- experiment_clock: core.Clock,
- stimuli: StimuliProperties,
- task_bar: TaskBar,
- info: InformationProperties,
- rows: int = 5,
- columns: int = 6,
- width_pct: float = 0.75,
- height_pct: float = 0.8,
- trigger_type: str = 'text',
- symbol_set: Optional[List[str]] = alphabet(),
- should_prompt_target: bool = True,
- preview_config: Optional[PreviewParams] = None):
+ def __init__(
+ self,
+ window: visual.Window,
+ experiment_clock: core.Clock,
+ stimuli: StimuliProperties,
+ task_bar: TaskBar,
+ info: InformationProperties,
+ rows: int = 5,
+ columns: int = 6,
+ width_pct: float = 0.75,
+ height_pct: float = 0.8,
+ trigger_type: str = 'text',
+ symbol_set: Optional[List[str]] = alphabet(),
+ should_prompt_target: bool = True,
+ preview_config: Optional[PreviewParams] = None
+ ) -> None:
"""Initialize Matrix display parameters and objects.
- PARAMETERS:
- ----------
- # Experiment
- window(visual.Window): PsychoPy Window
- experiment_clock(core.Clock): Clock used to timestamp display onsets
-
- # Stimuli
- stimuli(StimuliProperties): attributes used for inquiries
-
- # Task
- task_bar(TaskBar): used for task tracking. Ex. 1/100
-
- # Info
- info(InformationProperties): attributes to display informational stimuli alongside task and inquiry stimuli.
-
- trigger_type(str) default 'image': defines the calibration trigger type for the display at the beginning of any
- task. This will be used to reconcile timing differences between acquisition and the display.
- symbol_set default = none : subset of stimuli to be highlighted during an inquiry
- should_prompt_target(bool): when True prompts for the target symbol. Assumes that this is
- the first symbol of each inquiry. For example: [target, fixation, *stim].
- sort_order - optional function to define the position index for each
- symbol. Using a custom function it is possible to skip a position.
- preview_config - optional configuration for previewing inquiries
+ Args:
+ window (visual.Window): PsychoPy Window.
+ experiment_clock (core.Clock): Clock used to timestamp display onsets.
+ stimuli (StimuliProperties): Attributes used for inquiries.
+ task_bar (TaskBar): Used for task tracking. Ex. 1/100.
+ info (InformationProperties): Attributes to display informational stimuli.
+ rows (int): Number of rows in the matrix. Defaults to 5.
+ columns (int): Number of columns in the matrix. Defaults to 6.
+ width_pct (float): Width percentage of the display. Defaults to 0.75.
+ height_pct (float): Height percentage of the display. Defaults to 0.8.
+ trigger_type (str): Defines the calibration trigger type. Defaults to 'text'.
+ symbol_set (Optional[List[str]]): Subset of stimuli to be highlighted.
+ Defaults to alphabet().
+ should_prompt_target (bool): When True prompts for the target symbol.
+ Defaults to True.
+ preview_config (Optional[PreviewParams]): Configuration for previewing inquiries.
+ Defaults to None.
"""
self.window = window
- self.stimuli_inquiry = []
- self.stimuli_timing = []
- self.stimuli_colors = []
+ self.stimuli_inquiry: List[str] = []
+ self.stimuli_timing: List[float] = []
+ self.stimuli_colors: List[str] = []
self.stimuli_font = stimuli.stim_font
assert stimuli.is_txt_stim, "Matrix display is a text only display"
@@ -92,9 +127,11 @@ def __init__(self,
# Set position and parameters for grid of alphabet
self.grid_stimuli_height = stimuli.stim_height
- display_container = layout.centered(parent=window,
- width_pct=width_pct,
- height_pct=height_pct)
+ display_container = layout.centered(
+ parent=window,
+ width_pct=width_pct,
+ height_pct=height_pct
+ )
self.positions = symbol_positions(
display_container, rows, columns, symbol_set)
@@ -107,7 +144,7 @@ def __init__(self,
self.first_run = True
self.first_stim_time = None
self.trigger_type = trigger_type
- self._timing = []
+ self._timing: List[Tuple[str, float]] = []
self.experiment_clock = experiment_clock
@@ -127,10 +164,23 @@ def __init__(self,
)
logger.info(f"Matrix center position: {display_container.center}")
- def build_sort_order(self, stimuli: StimuliProperties) -> List[str]:
- """Build the symbol set for the display."""
+ def build_sort_order(self, stimuli: StimuliProperties) -> Callable[[str], int]:
+ """Build the symbol set for the display.
+
+ Args:
+ stimuli (StimuliProperties): Properties containing layout information.
+
+ Returns:
+ Callable[[str], int]: Function that returns the index for a given symbol.
+
+ Raises:
+ ValueError: If layout or symbol set is not recognized.
+ """
if stimuli.layout == 'ALP':
- return self.symbol_set.index
+ if self.symbol_set:
+ return self.symbol_set.index
+ else:
+ raise ValueError('Symbol set not defined')
elif stimuli.layout == 'QWERTY':
logger.info('Using QWERTY layout')
return qwerty_order()
@@ -142,7 +192,14 @@ def build_sort_order(self, stimuli: StimuliProperties) -> List[str]:
@property
def stim_positions(self) -> Dict[str, Tuple[float, float]]:
- """Returns a dict with the position for each stim"""
+ """Get positions for each stimulus.
+
+ Returns:
+ Dict[str, Tuple[float, float]]: Dictionary mapping symbols to positions.
+
+ Raises:
+ AssertionError: If stim_registry is not initialized.
+ """
assert self.stim_registry, "stim_registry not yet initialized"
return {
sym: tuple(stim.pos)
@@ -151,13 +208,18 @@ def stim_positions(self) -> Dict[str, Tuple[float, float]]:
@property
def preview_enabled(self) -> bool:
- """Should the inquiry preview be enabled."""
- return self.preview_params and self.preview_params.show_preview_inquiry
+ """Check if inquiry preview should be enabled.
+
+ Returns:
+ bool: True if preview is enabled, False otherwise.
+ """
+ return bool(self.preview_params and self.preview_params.show_preview_inquiry)
def capture_grid_screenshot(self, file_path: str) -> None:
- """Capture Grid Screenshot.
+ """Capture a screenshot of the current display.
- Capture a screenshot of the current display and save it to the specified filename.
+ Args:
+ file_path (str): Path where to save the screenshot.
"""
# draw the grid and flip the window
self.draw_grid(opacity=self.full_grid_opacity)
@@ -171,13 +233,17 @@ def capture_grid_screenshot(self, file_path: str) -> None:
capture.save(f'{file_path}/{MATRIX_IMAGE_FILENAME}')
self.task_bar.current_index = tmp_task_bar
- def schedule_to(self, stimuli: list, timing: list, colors: list) -> None:
+ def schedule_to(self, stimuli: List[str], timing: List[float], colors: Optional[List[str]] = None) -> None:
"""Schedule stimuli elements (works as a buffer).
Args:
- stimuli(list[string]): list of stimuli text / name
- timing(list[float]): list of timings of stimuli
- colors(list[string]): list of colors
+ stimuli (List[str]): List of stimuli text / name.
+ timing (List[float]): List of timings of stimuli.
+ colors (Optional[List[str]]): List of colors. Defaults to None.
+
+ Raises:
+ AssertionError: If lengths of stimuli and timing don't match,
+ or if colors are provided but lengths don't match.
"""
assert len(stimuli) == len(
timing), "each stimuli must have a timing value"
@@ -191,27 +257,36 @@ def schedule_to(self, stimuli: list, timing: list, colors: list) -> None:
self.stimuli_colors = [self.grid_color] * len(stimuli)
def symbol_durations(self) -> List[SymbolDuration]:
- """Symbols associated with their duration for the currently configured
- stimuli_inquiry."""
+ """Get symbols associated with their duration for the currently configured stimuli_inquiry.
+
+ Returns:
+ List[SymbolDuration]: List of symbol durations.
+ """
return [
SymbolDuration(*sti) for sti in zip(
self.stimuli_inquiry, self.stimuli_timing, self.stimuli_colors)
]
- def add_timing(self, stimuli: str, stamp: Optional[float] = None):
+ def add_timing(self, stimuli: str, stamp: Optional[float] = None) -> None:
"""Add a new timing entry using the stimuli as a label.
- Useful as a callback function to register a marker at the time it is
- first displayed."""
+ Args:
+ stimuli (str): Label for the timing entry.
+ stamp (Optional[float]): Timestamp. If None, uses current time.
+ """
stamp = stamp or self.experiment_clock.getTime()
self._timing.append((stimuli, stamp))
- def reset_timing(self):
+ def reset_timing(self) -> None:
"""Reset the trigger timing."""
self._timing = []
def do_inquiry(self) -> List[Tuple[str, float]]:
- """Animates an inquiry of stimuli and returns a list of stimuli trigger timing."""
+ """Animate an inquiry of stimuli and return timing information.
+
+ Returns:
+ List[Tuple[str, float]]: List of stimuli trigger timing.
+ """
self.preview_accepted = True
self.reset_timing()
symbol_durations = self.symbol_durations()
@@ -234,36 +309,45 @@ def do_inquiry(self) -> List[Tuple[str, float]]:
return self._timing
def build_grid(self) -> Dict[str, visual.TextStim]:
- """Build the text stimuli to populate the grid."""
+ """Build the text stimuli to populate the grid.
+
+ Returns:
+ Dict[str, visual.TextStim]: Dictionary mapping symbols to text stimuli.
+ """
grid = {}
- for sym in self.symbol_set:
- pos_index = self.sort_order(sym)
- pos = self.positions[pos_index]
- grid[sym] = visual.TextStim(win=self.window,
- font=self.stimuli_font,
- text=sym,
- color=self.grid_color,
- opacity=self.start_opacity,
- pos=pos,
- height=self.grid_stimuli_height)
+ if self.symbol_set:
+ for sym in self.symbol_set:
+ pos_index = self.sort_order(sym)
+ pos = self.positions[pos_index]
+ grid[sym] = visual.TextStim(
+ win=self.window,
+ font=self.stimuli_font,
+ text=sym,
+ color=self.grid_color,
+ opacity=self.start_opacity,
+ pos=pos,
+ height=self.grid_stimuli_height
+ )
return grid
- def draw_grid(self,
- opacity: float = 1,
- color: Optional[str] = 'white',
- highlight: Optional[List[str]] = None,
- highlight_color: Optional[str] = None):
+ def draw_grid(
+ self,
+ opacity: float = 1,
+ color: Optional[str] = 'white',
+ highlight: Optional[List[str]] = None,
+ highlight_color: Optional[str] = None
+ ) -> None:
"""Draw the grid.
- Parameters
- ----------
- opacity - opacity for each item in the matrix
- color - optional color for each item in the matrix
- highlight - optional list of stim labels to be highlighted
- (rendered using the highlight_opacity).
- highlight_color - optional color to use for rendering the
- highlighted stim.
+ Args:
+ opacity (float): Opacity for each item in the matrix. Defaults to 1.
+ color (Optional[str]): Optional color for each item in the matrix.
+ Defaults to 'white'.
+ highlight (Optional[List[str]]): Optional list of stim labels to be highlighted.
+ Defaults to None.
+ highlight_color (Optional[str]): Optional color to use for rendering the
+ highlighted stim. Defaults to None.
"""
for symbol, stim in self.stim_registry.items():
should_highlight = highlight and (symbol in highlight)
@@ -273,32 +356,32 @@ def draw_grid(self,
if highlight_color and should_highlight else color)
stim.draw()
- def prompt_target(self, target: SymbolDuration) -> float:
- """Present the target for the configured length of time. Records the
- stimuli timing information.
+ def prompt_target(self, target: SymbolDuration) -> None:
+ """Present the target for the configured length of time.
- Parameters
- ----------
- target - (symbol, duration) tuple
+ Args:
+ target (SymbolDuration): Target symbol and its duration.
"""
# register any timing and marker callbacks
self.window.callOnFlip(self.add_timing, target.symbol)
- self.draw(grid_opacity=self.start_opacity,
- duration=target.duration,
- highlight=[target.symbol],
- highlight_color=target.color)
+ self.draw(
+ grid_opacity=self.start_opacity,
+ duration=target.duration,
+ highlight=[target.symbol],
+ highlight_color=target.color
+ )
def preview_inquiry(self, stimuli: List[SymbolDuration]) -> bool:
- """"Preview the inquiry and handle any button presses.
- Parameters
- ----------
- stimuli - list of stimuli to highlight (will be flashed in next inquiry)
-
- Returns
- -------
- boolean indicating whether the participant would like to proceed
- with the inquiry (True) or reject the inquiry (False) and
- go on to the next one.
+ """Preview the inquiry and handle any button presses.
+
+ Args:
+ stimuli (List[SymbolDuration]): List of stimuli to highlight.
+
+ Returns:
+ bool: True if participant wants to proceed, False to reject.
+
+ Raises:
+ AssertionError: If preview is not enabled or button handler not initialized.
"""
assert self.preview_enabled, "Preview feature not enabled."
assert self.preview_button_handler, "Button handler must be initialized"
@@ -312,90 +395,115 @@ def preview_inquiry(self, stimuli: List[SymbolDuration]) -> bool:
if handler.has_response():
self.add_timing(handler.response_label, handler.response_timestamp)
- self.draw(grid_opacity=self.start_opacity,
- duration=self.preview_params.preview_inquiry_isi)
+ if self.preview_params:
+ self.draw(
+ grid_opacity=self.start_opacity,
+ duration=self.preview_params.preview_inquiry_isi
+ )
return handler.accept_result()
def draw_preview(self, stimuli: List[SymbolDuration]) -> None:
- """Draw the inquiry preview by highlighting all of the symbols in the
- list."""
- self.draw(grid_opacity=self.start_opacity,
- highlight=[stim.symbol for stim in stimuli])
-
- def draw(self,
- grid_opacity: float,
- grid_color: Optional[str] = None,
- duration: Optional[float] = None,
- highlight: Optional[List[str]] = None,
- highlight_color: Optional[str] = None):
+ """Draw the inquiry preview by highlighting all of the symbols in the list.
+
+ Args:
+ stimuli (List[SymbolDuration]): List of stimuli to highlight.
+ """
+ self.draw(
+ grid_opacity=self.start_opacity,
+ highlight=[stim.symbol for stim in stimuli]
+ )
+
+ def draw(
+ self,
+ grid_opacity: float,
+ grid_color: Optional[str] = None,
+ duration: Optional[float] = None,
+ highlight: Optional[List[str]] = None,
+ highlight_color: Optional[str] = None
+ ) -> None:
"""Draw all screen elements and flip the window.
- Parameters
- ----------
- grid_opacity - opacity value to use on all grid symbols
- grid_color - optional color to use for all grid symbols
- duration - optional seconds to wait after flipping the window.
- highlight - optional list of symbols to highlight in the grid.
- highlight_color - optional color to use for rendering the
- highlighted stim.
+ Args:
+ grid_opacity (float): Opacity value to use on all grid symbols.
+ grid_color (Optional[str]): Optional color to use for all grid symbols.
+ Defaults to None.
+ duration (Optional[float]): Optional seconds to wait after flipping the window.
+ Defaults to None.
+ highlight (Optional[List[str]]): Optional list of symbols to highlight in the grid.
+ Defaults to None.
+ highlight_color (Optional[str]): Optional color to use for rendering the
+ highlighted stim. Defaults to None.
"""
- self.draw_grid(opacity=grid_opacity,
- color=grid_color or self.grid_color,
- highlight=highlight,
- highlight_color=highlight_color)
+ self.draw_grid(
+ opacity=grid_opacity,
+ color=grid_color or self.grid_color,
+ highlight=highlight,
+ highlight_color=highlight_color
+ )
self.draw_components()
self.window.flip()
if duration:
core.wait(duration)
- def animate_scp(self, fixation: SymbolDuration,
- stimuli: List[SymbolDuration]) -> None:
+ def animate_scp(self, fixation: SymbolDuration, stimuli: List[SymbolDuration]) -> None:
"""Animate the given stimuli using single character presentation.
- Flashes each stimuli in stimuli_inquiry for their respective flash
- times and records the timing information.
+ Args:
+ fixation (SymbolDuration): Fixation symbol and duration.
+ stimuli (List[SymbolDuration]): List of stimuli to animate.
"""
-
# Flashing the grid at full opacity is considered fixation.
self.window.callOnFlip(self.add_timing, fixation.symbol)
- self.draw(grid_opacity=self.full_grid_opacity,
- grid_color=(fixation.color if self.should_prompt_target else
- self.grid_color),
- duration=fixation.duration / 2)
- self.draw(grid_opacity=self.start_opacity,
- duration=fixation.duration / 2)
+ self.draw(
+ grid_opacity=self.full_grid_opacity,
+ grid_color=(fixation.color if self.should_prompt_target else
+ self.grid_color),
+ duration=fixation.duration / 2
+ )
+ self.draw(
+ grid_opacity=self.start_opacity,
+ duration=fixation.duration / 2
+ )
for stim in stimuli:
self.window.callOnFlip(self.add_timing, stim.symbol)
- self.draw(grid_opacity=self.start_opacity,
- duration=stim.duration,
- highlight=[stim.symbol],
- highlight_color=stim.color)
+ self.draw(
+ grid_opacity=self.start_opacity,
+ duration=stim.duration,
+ highlight=[stim.symbol],
+ highlight_color=stim.color
+ )
self.draw(self.start_opacity)
def wait_screen(self, message: str, message_color: str) -> None:
- """Wait Screen.
+ """Display a wait screen with a message.
- Define what happens on the screen when a user pauses a session.
+ Args:
+ message (str): Message to display.
+ message_color (str): Color of the message.
"""
self.draw_components()
# Construct the wait message
- wait_message = visual.TextStim(win=self.window,
- font=self.stimuli_font,
- text=message,
- height=.1,
- color=message_color,
- pos=(0, -.5),
- wrapWidth=2)
+ wait_message = visual.TextStim(
+ win=self.window,
+ font=self.stimuli_font,
+ text=message,
+ height=.1,
+ color=message_color,
+ pos=(0, -.5),
+ wrapWidth=2
+ )
# try adding the BciPy logo to the wait screen
try:
- wait_logo = visual.ImageStim(self.window,
- image=BCIPY_LOGO_PATH,
- pos=(0, .25),
- mask=None,
- ori=0.0)
+ wait_logo = visual.ImageStim(
+ self.window,
+ image=BCIPY_LOGO_PATH,
+ pos=(0, .25),
+ mask=None,
+ ori=0.0
+ )
wait_logo.size = resize_image(BCIPY_LOGO_PATH, self.window.size, 1)
wait_logo.draw()
@@ -422,29 +530,29 @@ def draw_components(self) -> None:
info.draw()
def update_task_bar(self, text: str = '') -> None:
- """Update Task.
-
- Update any task related display items not related to the inquiry. Ex. stimuli count 1/200.
+ """Update task related display items.
- PARAMETERS:
-
- text: text for task
+ Args:
+ text (str): Text for task. Defaults to ''.
"""
if self.task_bar:
self.task_bar.update(text)
def _trigger_pulse(self) -> None:
- """Trigger Pulse.
+ """Send a calibration trigger pulse.
This method uses a calibration trigger to determine any functional
- offsets needed for operation with this display. By setting the first_stim_time and searching for the
- same stimuli output to the marker stream, the offsets between these proceses can be reconciled at the
- beginning of an experiment. If drift is detected in your experiment, more frequent pulses and offset
- correction may be required.
+ offsets needed for operation with this display. By setting the first_stim_time
+ and searching for the same stimuli output to the marker stream, the offsets
+ between these processes can be reconciled at the beginning of an experiment.
+ If drift is detected in your experiment, more frequent pulses and offset
+ correction may be required.
"""
- calibration_time = _calibration_trigger(self.experiment_clock,
- trigger_type=self.trigger_type,
- display=self.window)
+ calibration_time = _calibration_trigger(
+ self.experiment_clock,
+ trigger_type=self.trigger_type,
+ display=self.window
+ )
# set the first stim time if not present and first_run to False
if not self.first_stim_time:
diff --git a/bcipy/display/paradigm/matrix/layout.py b/bcipy/display/paradigm/matrix/layout.py
index 994642bf4..8267a3354 100644
--- a/bcipy/display/paradigm/matrix/layout.py
+++ b/bcipy/display/paradigm/matrix/layout.py
@@ -1,4 +1,10 @@
-"""Functions for calculating matrix layouts"""
+"""Functions for calculating matrix layouts.
+
+This module provides functionality for calculating and managing the layout of symbols
+in a matrix-style grid format, commonly used in BCI paradigms. It handles the
+positioning and spacing of elements based on window dimensions and layout requirements.
+"""
+
from typing import List, Optional, Tuple
from bcipy.display.components.layout import (Layout, above, below, left_of,
@@ -6,26 +12,35 @@
scaled_width)
-def symbol_positions(container: Layout,
- rows: int,
- columns: int,
- symbol_set: List[str],
- max_spacing: Optional[float] = None) -> List[Tuple[float, float]]:
- """Compute the positions for arranging a number of symbols in a grid
- layout.
+def symbol_positions(
+ container: Layout,
+ rows: int,
+ columns: int,
+ symbol_set: List[str],
+ max_spacing: Optional[float] = None
+) -> List[Tuple[float, float]]:
+ """Compute the positions for arranging a number of symbols in a grid layout.
+
+ This function calculates the positions for placing symbols in a grid format,
+ taking into account window dimensions, aspect ratio, and spacing requirements.
+ The grid is centered in the container and positions are returned in row-major order.
- Parameters
- ----------
- container - container in which the grid should be placed; must have a
+ Args:
+ container (Layout): Container in which the grid should be placed; must have a
visual.Window parent, which is used to determine the aspect ratio.
- rows - number of rows in the grid
- columns - number of columns in the grid
- symbol_set - list of symbols to place in the grid
- max_spacing - optional max spacing (in layout units) in the height
- direction; width will be normalized to this value if provided
- Returns
- -------
- list of (x,y) tuples with (rows * columns) positions in row,col order
+ rows (int): Number of rows in the grid.
+ columns (int): Number of columns in the grid.
+ symbol_set (List[str]): List of symbols to place in the grid.
+ max_spacing (Optional[float]): Optional max spacing (in layout units) in the height
+ direction; width will be normalized to this value if provided. Defaults to None.
+
+ Returns:
+ List[Tuple[float, float]]: List of (x,y) tuples with (rows * columns) positions
+ in row-major order.
+
+ Raises:
+ AssertionError: If container has no parent, if rows or columns are less than 1,
+ or if there are not enough positions for all symbols.
"""
assert container.parent, "Container must have a parent"
assert rows >= 1 and columns >= 1, "There must be at least one row and one column"
diff --git a/bcipy/display/paradigm/rsvp/display.py b/bcipy/display/paradigm/rsvp/display.py
index e96a28d33..2c075edad 100644
--- a/bcipy/display/paradigm/rsvp/display.py
+++ b/bcipy/display/paradigm/rsvp/display.py
@@ -1,6 +1,13 @@
+"""RSVP display module.
+
+This module provides the base RSVP (Rapid Serial Visual Presentation) display implementation
+which handles the visual presentation of stimuli during RSVP tasks. It provides core functionality
+for stimulus presentation, timing control, and inquiry management.
+"""
+
import logging
import os.path as path
-from typing import List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from psychopy import core, visual
@@ -18,13 +25,45 @@
class RSVPDisplay(Display):
"""RSVP Display Object for inquiry Presentation.
- Animates display objects common to any RSVP task.
+ Animates display objects common to any RSVP task. Handles stimulus presentation,
+ timing control, and inquiry management for RSVP-based BCI paradigms.
+
+ Attributes:
+ window (visual.Window): PsychoPy window for display.
+ window_size (Tuple[float, float]): Size of the display window.
+ refresh_rate (float): Display refresh rate.
+ stimuli_inquiry (List[str]): List of stimuli to present.
+ stimuli_colors (List[str]): List of colors for each stimulus.
+ stimuli_timing (List[float]): List of presentation durations.
+ stimuli_font (str): Font to use for text stimuli.
+ textbox_font (str): Font to use for text boxes.
+ stimuli_height (float): Height of stimuli.
+ stimuli_pos (Tuple[float, float]): Position of stimuli.
+ is_txt_stim (bool): Whether stimuli are text-based.
+ stim_length (int): Length of stimulus list.
+ full_screen (bool): Whether display is fullscreen.
+ preview_params (Optional[PreviewParams]): Preview configuration.
+ preview_button_handler (Optional[Any]): Handler for preview buttons.
+ preview_accepted (bool): Whether preview was accepted.
+ staticPeriod (core.StaticPeriod): Clock for static timing.
+ first_run (bool): Whether this is the first run.
+ first_stim_time (Optional[float]): Time of first stimulus.
+ trigger_type (str): Type of trigger to use.
+ trigger_callback (TriggerCallback): Callback for triggers.
+ experiment_clock (Clock): Clock for experiment timing.
+ first_stim_callback (Callable): Callback for first stimulus.
+ size_list_sti (List[float]): List of stimulus sizes.
+ space_char (str): Character to use for spaces.
+ task_bar (TaskBar): Task bar component.
+ info (InformationProperties): Information display properties.
+ info_text (List[visual.TextStim]): Text stimuli for information.
+ sti (List[visual.TextStim]): Initial stimuli objects.
"""
def __init__(
self,
window: visual.Window,
- static_clock,
+ static_clock: core.StaticPeriod,
experiment_clock: Clock,
stimuli: StimuliProperties,
task_bar: TaskBar,
@@ -32,34 +71,21 @@ def __init__(
preview_config: Optional[PreviewParams] = None,
trigger_type: str = 'image',
space_char: str = SPACE_CHAR,
- full_screen: bool = False):
+ full_screen: bool = False) -> None:
"""Initialize RSVP display parameters and objects.
- PARAMETERS:
- ----------
- # Experiment
- window(visual.Window): PsychoPy Window
- static_clock(core.MonotonicClock): Used to schedule static periods of display time
- experiment_clock(Clock): Clock used to timestamp display onsets
-
- # Stimuli
- stimuli(StimuliProperties): attributes used for inquiries
-
- # Task
- task_bar(TaskBar): used for task tracking. Ex. 1/100
-
- # Info
- info(InformationProperties): attributes to display informational stimuli alongside task and inquiry stimuli.
-
- # Preview Inquiry
- preview_config(PreviewParams) Optional: parameters used to specify the behavior for displaying a preview
- of upcoming stimuli defined via self.stimuli(StimuliProperties). If None a preview is not displayed.
-
- trigger_type(str) default 'image': defines the calibration trigger type for the display at the beginning of any
- task. This will be used to reconcile timing differences between acquisition and the display.
- space_char(str) default SPACE_CHAR: defines the space character to use in the RSVP inquiry.
- full_screen(bool) default False: Whether or not the window is set to a full screen dimension. Used for
- scaling display items as needed.
+ Args:
+ window (visual.Window): PsychoPy Window for display.
+ static_clock (core.StaticPeriod): Used to schedule static periods of display time.
+ experiment_clock (Clock): Clock used to timestamp display onsets.
+ stimuli (StimuliProperties): Attributes used for inquiries.
+ task_bar (TaskBar): Used for task tracking. Ex. 1/100.
+ info (InformationProperties): Attributes to display informational stimuli.
+ preview_config (Optional[PreviewParams]): Parameters for preview functionality.
+ If None, preview is not displayed.
+ trigger_type (str, optional): Defines the calibration trigger type. Defaults to 'image'.
+ space_char (str, optional): Character to use for spaces. Defaults to SPACE_CHAR.
+ full_screen (bool, optional): Whether window is fullscreen. Defaults to False.
"""
self.window = window
self.window_size = self.window.size # [w, h]
@@ -67,14 +93,12 @@ def __init__(
self.logger = logging.getLogger(__name__)
- # Stimuli parameters, these are set on display in order to allow
- # easy updating after definition
+ # Stimuli parameters
self.stimuli_inquiry = stimuli.stim_inquiry
self.stimuli_colors = stimuli.stim_colors
self.stimuli_timing = stimuli.stim_timing
self.stimuli_font = stimuli.stim_font
- # Note: there is a bug in TextBox2 that prevents certain custom fonts from being used. This is to avoid that.
- self.textbox_font = 'Consolas'
+ self.textbox_font = 'Consolas' # Avoid TextBox2 font bug
self.stimuli_height = stimuli.stim_height
self.stimuli_pos = stimuli.stim_pos
self.is_txt_stim = stimuli.is_txt_stim
@@ -97,9 +121,9 @@ def __init__(
self.experiment_clock = experiment_clock
# Callback used on presentation of first stimulus.
- self.first_stim_callback = lambda _sti: None
- self.size_list_sti = [] # TODO force initial size definition
- self.space_char = space_char # TODO remove and force task to define
+ self.first_stim_callback: Callable = lambda _sti: None
+ self.size_list_sti: List[float] = []
+ self.space_char = space_char
self.task_bar = task_bar
@@ -112,10 +136,14 @@ def __init__(
@property
def preview_enabled(self) -> bool:
- """Should the inquiry preview be enabled."""
- return self.preview_params and self.preview_params.show_preview_inquiry
+ """Check if inquiry preview should be enabled.
+
+ Returns:
+ bool: True if preview is enabled and configured, False otherwise.
+ """
+ return bool(self.preview_params and self.preview_params.show_preview_inquiry)
- def draw_static(self):
+ def draw_static(self) -> None:
"""Draw static elements in a stimulus."""
if self.task_bar:
self.task_bar.draw()
@@ -125,13 +153,13 @@ def draw_static(self):
def schedule_to(self,
stimuli: Optional[List[str]] = None,
timing: Optional[List[float]] = None,
- colors: Optional[List[str]] = None):
+ colors: Optional[List[str]] = None) -> None:
"""Schedule stimuli elements (works as a buffer).
Args:
- stimuli(list[string]): list of stimuli text / name
- timing(list[float]): list of timings of stimuli
- colors(list[string]): list of colors
+ stimuli (Optional[List[str]]): List of stimuli text/name.
+ timing (Optional[List[float]]): List of timings of stimuli.
+ colors (Optional[List[str]]): List of colors.
"""
self.stimuli_inquiry = stimuli or []
self.stimuli_timing = timing or []
@@ -139,104 +167,93 @@ def schedule_to(self,
@property
def preview_index(self) -> int:
- """Index within an inquiry at which the inquiry preview should be displayed.
+ """Get index within an inquiry at which the preview should be displayed.
For calibration, we should display it after the target prompt (index = 1).
For copy phrase there is no target prompt so it should display before the
- rest of the inquiry."""
+ rest of the inquiry.
+
+ Returns:
+ int: The index at which to display the preview.
+ """
return 1
def do_inquiry(self) -> List[Tuple[str, float]]:
- """Do inquiry.
+ """Perform an inquiry of flashing letters to achieve RSVP.
- Animates an inquiry of flashing letters to achieve RSVP.
-
-
- RETURNS:
- --------
- timing(list[float]): list of timings of stimuli presented in the inquiry
+ Returns:
+            List[Tuple[str, float]]: List of (label, time) tuples for stimuli presented in the inquiry.
"""
-
- # init an array for timing information
timing: List[Tuple[str, float]] = []
self.preview_accepted = True
if self.first_run:
self._trigger_pulse()
- # generate a inquiry (list of stimuli with meta information)
inquiry = self._generate_inquiry()
- # do the inquiry
for idx, stim_props in enumerate(inquiry):
-
- # If this is the start of an inquiry and a callback registered for first_stim_callback evoke it
if idx == 0 and callable(self.first_stim_callback):
self.first_stim_callback(stim_props['sti'])
- # If previewing the inquiry during calibration, do so after the first stimulus
if self.preview_enabled and idx == self.preview_index:
self.preview_accepted = self.preview_inquiry(timing)
if not self.preview_accepted:
break
- # Reset the timing clock to start presenting
self.window.callOnFlip(
self.trigger_callback.callback,
self.experiment_clock,
stim_props['sti_label'])
- # Draw stimulus for n frames
stim_props['sti'].draw()
self.draw_static()
self.window.flip()
core.wait(stim_props['time_to_present'])
- # append timing information
timing.append(self.trigger_callback.timing)
-
self.trigger_callback.reset()
- # draw in static and flip once more
self.draw_static()
self.window.flip()
return timing
def _trigger_pulse(self) -> None:
- """Trigger Pulse.
+ """Send a calibration trigger pulse.
- This method uses a calibration trigger to determine any functional
- offsets needed for operation with this display. By setting the first_stim_time and searching for the
- same stimuli output to the marker stream, the offsets between these proceses can be reconciled at the
- beginning of an experiment. If drift is detected in your experiment, more frequent pulses and offset
- correction may be required.
+ Uses a calibration trigger to determine any functional offsets needed for
+ operation with this display. Sets first_stim_time and searches for the same
+ stimuli output to the marker stream to reconcile timing offsets.
"""
calibration_time = _calibration_trigger(
self.experiment_clock,
trigger_type=self.trigger_type,
display=self.window)
- # set the first stim time if not present and first_run to False
if not self.first_stim_time:
self.first_stim_time = calibration_time[-1]
self.first_run = False
def preview_inquiry(self, timing: List[Tuple[str, float]]) -> bool:
- """Preview Inquiry.
+ """Preview the inquiry before presentation.
Given an inquiry defined to be presented via do_inquiry(), present the full inquiry
- to the user and allow input on whether the intended letter is present or not before
- going through the rapid serial visual presention.
+ to the user and allow input on whether the intended letter is present or not before
+ going through the rapid serial visual presentation.
+
+ Args:
+ timing (List[Tuple[str, float]]): List to which timing information should be appended.
- Parameters:
- timing - list to which all timing information should be appended.
Returns:
- - A boolean describing whether to present the inquiry (True) or
- generate another (False).
+ bool: True if inquiry should be presented, False if a new one should be generated.
+
+ Raises:
+ AssertionError: If preview is not enabled or button handler is not initialized.
"""
assert self.preview_enabled, "Preview feature not enabled."
assert self.preview_button_handler, "Button handler must be initialized"
+ assert self.preview_params is not None, "Preview parameters must be set"
handler = self.preview_button_handler
self.window.callOnFlip(
@@ -257,19 +274,26 @@ def preview_inquiry(self, timing: List[Tuple[str, float]]) -> bool:
return handler.accept_result()
- def draw_preview(self):
- """Generate and draw the inquiry preview"""
+ def draw_preview(self) -> None:
+ """Generate and draw the inquiry preview."""
content = self._generate_inquiry_preview()
content.draw()
self.draw_static()
self.window.flip()
def _generate_inquiry_preview(self) -> visual.TextBox2:
- """Generate Inquiry Preview.
+ """Generate the inquiry preview box.
- Using the self.stimuli_inquiry list, construct a preview box to display to the user. This method
- assumes the presence of a fixation (+).
+ Using the self.stimuli_inquiry list, construct a preview box to display to the user.
+ Assumes the presence of a fixation (+).
+
+ Returns:
+ visual.TextBox2: The preview text box.
+
+ Raises:
+ AssertionError: If preview parameters are not set.
"""
+ assert self.preview_params is not None, "Preview parameters must be set"
text = ' '.join(self.stimuli_inquiry).split('+ ')[1]
return self._create_stimulus(
@@ -280,10 +304,11 @@ def _generate_inquiry_preview(self) -> visual.TextBox2:
mode='textbox',
align_text='left')
- def _generate_inquiry(self) -> list:
- """Generate inquiry.
+ def _generate_inquiry(self) -> List[Dict[str, Any]]:
+ """Generate stimuli for next RSVP inquiry.
- Generate stimuli for next RSVP inquiry. [A + A, C, Q, D]
+ Returns:
+ List[Dict[str, Any]]: List of stimulus properties for the inquiry.
"""
stim_info = []
for idx, stim in enumerate(self.stimuli_inquiry):
@@ -291,13 +316,9 @@ def _generate_inquiry(self) -> list:
current_stim['time_to_present'] = self.stimuli_timing[idx]
- # check if stimulus needs to use a non-default size
- if self.size_list_sti:
- this_stimuli_size = self.size_list_sti[idx]
- else:
- this_stimuli_size = self.stimuli_height
+ this_stimuli_size = (self.size_list_sti[idx] if self.size_list_sti
+ else self.stimuli_height)
- # Set the Stimuli attrs
if stim.endswith('.png'):
current_stim['sti'] = self._create_stimulus(
mode='image',
@@ -309,32 +330,27 @@ def _generate_inquiry(self) -> list:
current_stim['sti_label'] = path.splitext(
path.basename(stim))[0]
else:
- # text stimulus
- current_stim['sti'] = self._create_stimulus(mode='text', height=this_stimuli_size)
+ current_stim['sti'] = self._create_stimulus(
+ mode='text', height=this_stimuli_size)
txt = stim
- # customize presentation of space char.
current_stim['sti'].text = txt if txt != SPACE_CHAR else self.space_char
current_stim['sti'].color = self.stimuli_colors[idx]
current_stim['sti_label'] = txt
- # test whether the word will be too big for the screen
text_width = current_stim['sti'].boundingBox[0]
if text_width > self.window.size[0]:
screen_info = get_screen_info()
monitor_width = screen_info.width
monitor_height = screen_info.height
text_height = current_stim['sti'].boundingBox[1]
- # If we are in full-screen, text size in Psychopy norm units
- # is monitor width/monitor height
if self.window.size[0] == monitor_width:
new_text_width = monitor_width / monitor_height
else:
- # If not, text width is calculated relative to both
- # monitor size and window size
new_text_width = (
self.window.size[1] / monitor_height) * (
monitor_width / monitor_height)
- new_text_height = (text_height * new_text_width) / text_width
+ new_text_height = (
+ text_height * new_text_width) / text_width
current_stim['sti'].height = new_text_height
stim_info.append(current_stim)
return stim_info
@@ -342,22 +358,22 @@ def _generate_inquiry(self) -> list:
def update_task_bar(self, text: Optional[str] = None) -> None:
"""Update task state.
- Removes letters or appends to the right.
Args:
- text(string): new text for task state
+ text (Optional[str]): New text for task state.
"""
if self.task_bar:
self.task_bar.update(text)
def wait_screen(self, message: str, message_color: str) -> None:
- """Wait Screen.
+        """Display a wait screen with a message and optional logo.
Args:
- message(string): message to be displayed while waiting
- message_color(string): color of the message to be displayed
- """
+ message (str): Message to be displayed while waiting.
+ message_color (str): Color of the message to be displayed.
- # Construct the wait message
+ Raises:
+ Exception: If the logo image cannot be loaded.
+ """
wait_message = visual.TextStim(win=self.window,
font=self.stimuli_font,
text=message,
@@ -369,7 +385,6 @@ def wait_screen(self, message: str, message_color: str) -> None:
opacity=1,
depth=-6.0)
- # try adding the BciPy logo to the wait screen
try:
wait_logo = visual.ImageStim(
self.window,
@@ -384,27 +399,41 @@ def wait_screen(self, message: str, message_color: str) -> None:
wait_logo.draw()
except Exception as e:
- self.logger.exception(f'Cannot load logo image from path=[{BCIPY_LOGO_PATH}]')
+ self.logger.exception(
+ f'Cannot load logo image from path=[{BCIPY_LOGO_PATH}]')
raise e
- # Draw and flip the screen.
wait_message.draw()
self.window.flip()
def _create_stimulus(
self,
- height: int,
+ height: float,
mode: str = 'text',
- stimulus='+',
- color='white',
- stimuli_position=None,
- align_text='center',
- units=None,
- wrap_width=None,
- border=False):
- """Create Stimulus.
-
- Returns a TextStim, ImageStim or TextBox object.
+ stimulus: str = '+',
+ color: str = 'white',
+ stimuli_position: Optional[Tuple[float, float]] = None,
+ align_text: str = 'center',
+ units: Optional[str] = None,
+ wrap_width: Optional[float] = None,
+ border: bool = False) -> Union[visual.TextStim, visual.ImageStim, visual.TextBox2]:
+ """Create a stimulus object.
+
+ Args:
+ height (float): Height of the stimulus.
+ mode (str, optional): Type of stimulus ('text', 'image', or 'textbox').
+ Defaults to 'text'.
+ stimulus (str, optional): Content of the stimulus. Defaults to '+'.
+ color (str, optional): Color of the stimulus. Defaults to 'white'.
+ stimuli_position (Optional[Tuple[float, float]], optional): Position of the stimulus.
+ Defaults to None.
+ align_text (str, optional): Text alignment. Defaults to 'center'.
+ units (Optional[str], optional): Units for size/position. Defaults to None.
+ wrap_width (Optional[float], optional): Width for text wrapping. Defaults to None.
+ border (bool, optional): Whether to show border. Defaults to False.
+
+ Returns:
+ Union[visual.TextStim, visual.ImageStim, visual.TextBox2]: The created stimulus object.
"""
if not stimuli_position:
stimuli_position = self.stimuli_pos
@@ -448,3 +477,6 @@ def _create_stimulus(
alignment=align_text,
editable=False,
)
+ else:
+ raise ValueError(
+ f'RSVPDisplay asked to create a stimulus type=[{mode}] that is not supported.')
diff --git a/bcipy/display/paradigm/rsvp/mode/calibration.py b/bcipy/display/paradigm/rsvp/mode/calibration.py
index 9b76a0563..c1efd1c95 100644
--- a/bcipy/display/paradigm/rsvp/mode/calibration.py
+++ b/bcipy/display/paradigm/rsvp/mode/calibration.py
@@ -1,22 +1,59 @@
+"""RSVP calibration display module.
+
+This module provides the RSVP calibration display implementation which handles the visual
+presentation of stimuli during calibration tasks. It extends the base RSVP display with
+calibration-specific functionality.
+"""
+from typing import Optional
+
+from psychopy import core, visual
+
from bcipy.core.symbols import SPACE_CHAR
+from bcipy.display import InformationProperties, StimuliProperties
+from bcipy.display.components.task_bar import TaskBar
+from bcipy.display.main import PreviewParams
from bcipy.display.paradigm.rsvp.display import RSVPDisplay
class CalibrationDisplay(RSVPDisplay):
- """Calibration Display."""
+ """Calibration Display for RSVP paradigm.
+
+ This class extends the RSVPDisplay to provide calibration-specific functionality.
+ It handles the visual presentation of stimuli during calibration tasks, including
+ preview functionality and timing control.
+
+ Attributes:
+ preview_index (int): Index within an inquiry at which the inquiry preview
+ should be displayed. For calibration, this is set to 1 (after target prompt).
+ """
def __init__(self,
- window,
- static_clock,
- experiment_clock,
- stimuli,
- task_bar,
- info,
- trigger_type='image',
- preview_config=None,
- space_char=SPACE_CHAR,
- full_screen=False):
+ window: visual.Window,
+ static_clock: core.StaticPeriod,
+ experiment_clock: core.Clock,
+ stimuli: StimuliProperties,
+ task_bar: TaskBar,
+ info: InformationProperties,
+ trigger_type: str = 'image',
+ preview_config: Optional[PreviewParams] = None,
+ space_char: str = SPACE_CHAR,
+ full_screen: bool = False) -> None:
+ """Initialize the RSVP calibration display.
+ Args:
+ window (visual.Window): PsychoPy window for display.
+ static_clock (core.StaticPeriod): Clock for static timing.
+ experiment_clock (core.Clock): Clock for experiment timing.
+ stimuli (StimuliProperties): Properties for stimulus presentation.
+ task_bar (TaskBar): Task bar component for progress display.
+ info (InformationProperties): Properties for information display.
+ trigger_type (str, optional): Type of trigger to use. Defaults to 'image'.
+ preview_config (Optional[PreviewParams], optional): Configuration for preview
+ functionality. Defaults to None.
+ space_char (str, optional): Character to use for spaces. Defaults to SPACE_CHAR.
+ full_screen (bool, optional): Whether to display in fullscreen mode.
+ Defaults to False.
+ """
super().__init__(window,
static_clock,
experiment_clock,
@@ -33,5 +70,8 @@ def preview_index(self) -> int:
"""Index within an inquiry at which the inquiry preview should be displayed.
For calibration, we should display it after the target prompt (index = 1).
+
+ Returns:
+ int: The index at which to display the preview (1 for calibration).
"""
return 1
diff --git a/bcipy/display/paradigm/rsvp/mode/copy_phrase.py b/bcipy/display/paradigm/rsvp/mode/copy_phrase.py
index 8ce877c21..a516484fc 100644
--- a/bcipy/display/paradigm/rsvp/mode/copy_phrase.py
+++ b/bcipy/display/paradigm/rsvp/mode/copy_phrase.py
@@ -1,42 +1,72 @@
-from psychopy import visual
+"""RSVP copy phrase display module.
+
+This module provides the RSVP copy phrase display implementation which handles the visual
+presentation of stimuli during copy phrase tasks. It extends the base RSVP display with
+copy phrase-specific functionality.
+
+Note:
+ RSVP Tasks are RSVPDisplay objects with different structure. They share
+ the tasks and the essential elements and stimuli. However, layout, length of
+    stimuli list, update procedures and colors are different. Therefore the
+    modes should be kept carefully separated from each other.
+"""
+
+from typing import Optional
+
+from psychopy import core, visual
from bcipy.core.stimuli import resize_image
from bcipy.core.symbols import SPACE_CHAR
+from bcipy.display import InformationProperties, StimuliProperties
+from bcipy.display.components.task_bar import TaskBar
+from bcipy.display.main import PreviewParams
from bcipy.display.paradigm.rsvp.display import BCIPY_LOGO_PATH, RSVPDisplay
-"""Note:
-
-RSVP Tasks are RSVPDisplay objects with different structure. They share
-the tasks and the essential elements and stimuli. However, layout, length of
-stimuli list, update procedures and colors are different. Therefore each
-mode should be separated from each other carefully.
-Functions:
- update_task_state: update task information of the module
-"""
-
class CopyPhraseDisplay(RSVPDisplay):
- """ Copy Phrase display object of RSVP
+ """Copy Phrase display object of RSVP.
- Custom attributes:
- static_task_text(str): target text for the user to attempt to spell
- static_task_color(str): target text color for the user to attempt to spell
+ This class extends the RSVPDisplay to provide copy phrase-specific functionality.
+ It handles the visual presentation of stimuli during copy phrase tasks, including
+ preview functionality and wait screen display.
+
+ Attributes:
+ starting_spelled_text (str): Initial text that has been spelled.
+ static_task_text (str): Target text for the user to attempt to spell.
+ static_task_color (str): Target text color for the user to attempt to spell.
"""
def __init__(
self,
- window,
- clock,
- experiment_clock,
- stimuli,
- task_bar,
- info,
- starting_spelled_text='',
- trigger_type='image',
- space_char=SPACE_CHAR,
- preview_config=None,
- full_screen=False):
- """ Initializes Copy Phrase Task Objects """
+ window: visual.Window,
+ clock: core.Clock,
+ experiment_clock: core.Clock,
+ stimuli: StimuliProperties,
+ task_bar: TaskBar,
+ info: InformationProperties,
+ starting_spelled_text: str = '',
+ trigger_type: str = 'image',
+ space_char: str = SPACE_CHAR,
+ preview_config: Optional[PreviewParams] = None,
+ full_screen: bool = False) -> None:
+ """Initialize Copy Phrase Task Objects.
+
+ Args:
+ window (visual.Window): PsychoPy window for display.
+ clock (core.Clock): Clock for timing.
+ experiment_clock (core.Clock): Clock for experiment timing.
+ stimuli (StimuliProperties): Properties for stimulus presentation.
+ task_bar (TaskBar): Task bar component for progress display.
+ info (InformationProperties): Properties for information display.
+ starting_spelled_text (str, optional): Initial text that has been spelled.
+ Defaults to ''.
+ trigger_type (str, optional): Type of trigger to use. Defaults to 'image'.
+ space_char (str, optional): Character to use for spaces. Defaults to SPACE_CHAR.
+ preview_config (Optional[PreviewParams], optional): Configuration for preview
+ functionality. Defaults to None.
+ full_screen (bool, optional): Whether to display in fullscreen mode.
+ Defaults to False.
+ """
self.starting_spelled_text = starting_spelled_text
super().__init__(window,
@@ -56,17 +86,22 @@ def preview_index(self) -> int:
For copy phrase there is no target prompt so it should display before
the fixation.
+
+ Returns:
+ int: The index at which to display the preview (0 for copy phrase).
"""
return 0
def wait_screen(self, message: str, message_color: str) -> None:
- """Wait Screen.
+ """Display a wait screen with a message and optional logo.
Args:
- message(string): message to be displayed while waiting
- message_color(string): color of the message to be displayed
- """
+ message (str): Message to be displayed while waiting.
+ message_color (str): Color of the message to be displayed.
+ Raises:
+ Exception: If the logo image cannot be loaded.
+ """
self.draw_static()
# Construct the wait message
@@ -96,7 +131,8 @@ def wait_screen(self, message: str, message_color: str) -> None:
wait_logo.draw()
except Exception as e:
- self.logger.exception(f'Cannot load logo image from path=[{BCIPY_LOGO_PATH}]')
+ self.logger.exception(
+ f'Cannot load logo image from path=[{BCIPY_LOGO_PATH}]')
raise e
# Draw and flip the screen.
diff --git a/bcipy/display/paradigm/vep/display.py b/bcipy/display/paradigm/vep/display.py
index 270f077ca..6f80dfcc3 100644
--- a/bcipy/display/paradigm/vep/display.py
+++ b/bcipy/display/paradigm/vep/display.py
@@ -152,7 +152,8 @@ def check_configuration(self):
def stim_properties(self) -> List[StimProps]:
"""Returns a tuple of (symbol, duration, and color) for each stimuli,
including the target and fixation stim. Stimuli that represent VEP
- boxes will have a list of symbols."""
+ boxes will have a list of symbols.
+ """
stim_num = len(self.stimuli_inquiry)
assert len(self.stimuli_colors
) == stim_num, "Each box should have its own color"
@@ -221,7 +222,8 @@ def prompt_target(self,
target - (symbol, duration, color) tuple
"""
assert isinstance(target.symbol, str), "Target must be a str"
- self.logger.info(f"Target: {target.symbol} at index {target_box_index}")
+ self.logger.info(
+ f"Target: {target.symbol} at index {target_box_index}")
# Show all symbols in the matrix at reduced opacity
for sym in self.symbol_set:
@@ -313,7 +315,8 @@ def animate_inquiry(self, stimuli: List[StimProps]) -> None:
def set_stimuli_colors(self, stim_groups: List[StimProps]) -> None:
"""Update the colors of the stimuli associated with each symbol to
- reflect which box it will be placed in."""
+ reflect which box it will be placed in.
+ """
for group in stim_groups:
for sym in group.symbol:
self.sti[sym].color = group.color
@@ -399,7 +402,8 @@ def add_timing(self, stimuli: str):
"""Add a new timing entry using the stimuli as a label.
Useful as a callback function to register a marker at the time it is
- first displayed."""
+ first displayed.
+ """
self._timing.append(StimTime(stimuli, self.experiment_clock.getTime()))
def reset_timing(self):
@@ -461,8 +465,7 @@ def _trigger_pulse(self) -> None:
def schedule_to(self, stimuli: List[List[Any]], timing: Optional[List[List[float]]]
= None, colors: Optional[List[List[str]]] = None) -> None:
- """Schedule stimuli elements (works as a buffer).
- """
+ """Schedule stimuli elements (works as a buffer)."""
self.stimuli_inquiry = stimuli # type: ignore
assert timing is None or timing == self.stimuli_timing, "Timing values must match pre-configured values"
assert colors is None or colors == self.stimuli_colors, "Colors must match the pre-configured values"
@@ -503,8 +506,7 @@ def _reset_text_boxes(self) -> None:
text_box.borderWidth = self.box_border_width
def _set_inquiry(self, stimuli: List[StimProps]) -> List[visual.TextBox2]:
- """Set the correct inquiry text for each text boxes.
- """
+        """Set the correct inquiry text for each text box."""
for i, sti in enumerate(stimuli):
box = self.text_boxes[i]
text = ' '.join(sti.symbol)
diff --git a/bcipy/display/tests/components/test_layout.py b/bcipy/display/tests/components/test_layout.py
index 846c589c1..d0432d90f 100644
--- a/bcipy/display/tests/components/test_layout.py
+++ b/bcipy/display/tests/components/test_layout.py
@@ -189,7 +189,8 @@ def test_scaled_size(self):
msg="Width should be proportional to the window aspect")
self.assertEqual(
- scaled_size(height=0.2, window_size=(800, 500), units='height'), (0.2, 0.2),
+ scaled_size(height=0.2, window_size=(
+ 800, 500), units='height'), (0.2, 0.2),
msg="Width should be the same in 'height' units")
def test_scaled_height(self):
diff --git a/bcipy/exceptions.py b/bcipy/exceptions.py
index 23c525a03..7b6161389 100644
--- a/bcipy/exceptions.py
+++ b/bcipy/exceptions.py
@@ -1,102 +1,171 @@
+"""Custom exceptions for the BciPy package.
+
+This module contains all custom exceptions used throughout the BciPy application.
+Each exception is designed to provide specific error information for different
+components of the system.
+"""
+from typing import Any
+
+
class BciPyCoreException(Exception):
- """BciPy Core Exception.
+ """Base exception class for BciPy-specific errors.
- Thrown when an error occurs specific to BciPy core concepts.
+ Args:
+ message: A descriptive message explaining the error.
+ errors: Optional additional error information.
+
+ Attributes:
+ message: The error message.
+ errors: Additional error details, if any.
"""
- def __init__(self, message, errors=None):
+ def __init__(self, message: str, errors: Any = None) -> None:
super().__init__(message)
self.message = message
self.errors = errors
class SignalException(BciPyCoreException):
- """
- Signal Exception.
+ """Exception raised for errors in the signal processing module.
+
+ This exception is raised when the signal module encounters errors during
+ processing or analysis of signal data.
- Thrown when signal module encounters an error.
+ Args:
+ message: A descriptive message explaining the signal processing error.
+ errors: Optional additional error information.
"""
- def __init__(self, message, errors=None):
+ def __init__(self, message: str, errors: Any = None) -> None:
super().__init__(message)
self.errors = errors
class FieldException(BciPyCoreException):
- """Field Exception.
+ """Exception raised for errors related to experimental fields.
- Thrown when there is an exception relating to experimental fields.
+ This exception is raised when there are issues with field definitions,
+ validation, or processing in experiments.
+
+ Args:
+ message: A descriptive message explaining the field-related error.
+ errors: Optional additional error information.
"""
...
class ExperimentException(BciPyCoreException):
- """Experiment Exception.
+ """Exception raised for errors related to experiment execution.
+
+ This exception is raised when there are issues with experiment setup,
+ execution, or validation.
- Thrown when there is an exception relating to experiments.
+ Args:
+ message: A descriptive message explaining the experiment-related error.
+ errors: Optional additional error information.
"""
...
class UnregisteredExperimentException(ExperimentException):
- """Unregistered Experiment.
+ """Exception raised when attempting to use an unregistered experiment.
- Thrown when experiment is not registered in the provided experiment path.
- """
+ This exception is raised when trying to access or execute an experiment
+ that has not been registered in the provided experiment path.
+ Args:
+ message: A descriptive message explaining which experiment was not found.
+ errors: Optional additional error information.
+ """
...
class UnregisteredFieldException(FieldException):
- """Unregistered Field.
+ """Exception raised when attempting to use an unregistered field.
- Thrown when field is not registered in the provided field path.
- """
+ This exception is raised when trying to access a field that has not been
+ registered in the provided field path.
+ Args:
+ message: A descriptive message explaining which field was not found.
+ errors: Optional additional error information.
+ """
...
class InvalidExperimentException(ExperimentException):
- """Invalid Experiment Exception.
+ """Exception raised when experiment data is in an invalid format.
- Thrown when providing experiment data in the incorrect format.
- """
+ This exception is raised when experiment configuration or data does not
+ meet the required format specifications.
+ Args:
+ message: A descriptive message explaining the format error.
+ errors: Optional additional error information.
+ """
...
class InvalidFieldException(FieldException):
- """Invalid Field Exception.
+ """Exception raised when field data is in an invalid format.
- Thrown when providing field data in the incorrect format.
- """
+ This exception is raised when field configuration or data does not
+ meet the required format specifications.
+ Args:
+ message: A descriptive message explaining the format error.
+ errors: Optional additional error information.
+ """
...
class TaskConfigurationException(BciPyCoreException):
- """Task Configuration Exception.
+ """Exception raised when task configuration is invalid.
+
+ This exception is raised when attempting to run a task with invalid
+ or incompatible configuration settings.
- Thrown when attempting to run a task with invalid configurations"""
+ Args:
+ message: A descriptive message explaining the configuration error.
+ errors: Optional additional error information.
+ """
...
class KenLMInstallationException(BciPyCoreException):
- """KenLM Installation Exception.
+ """Exception raised when KenLM module is not properly installed.
- Thrown when attempting to import kenlm without installing the module"""
+ This exception is raised when attempting to use KenLM functionality
+ without having the required module installed.
+
+ Args:
+ message: A descriptive message explaining the installation issue.
+ errors: Optional additional error information.
+ """
...
class InvalidSymbolSetException(BciPyCoreException):
- """Invalid Symbol Set Exception.
+ """Exception raised when symbol set is not properly configured.
+
+ This exception is raised when attempting to query a language model for
+ predictions without properly configuring the symbol set.
- Thrown when querying a language model for predictions without setting the symbol set."""
+ Args:
+ message: A descriptive message explaining the symbol set error.
+ errors: Optional additional error information.
+ """
...
class LanguageModelNameInUseException(BciPyCoreException):
- """Language Model Name In Use Exception.
+ """Exception raised when attempting to register a duplicate language model.
+
+ This exception is raised when trying to register a language model type
+ with a name that is already in use.
- Thrown when attempting to register a language model type with a duplicate name."""
+ Args:
+ message: A descriptive message explaining the naming conflict.
+ errors: Optional additional error information.
+ """
...
diff --git a/bcipy/feedback/README.md b/bcipy/feedback/README.md
new file mode 100644
index 000000000..59b2761a8
--- /dev/null
+++ b/bcipy/feedback/README.md
@@ -0,0 +1,134 @@
+# BciPy Feedback Module
+
+The feedback module provides a flexible framework for implementing real-time feedback mechanisms in BCI experiments. It supports both visual and auditory feedback, allowing researchers to create customized feedback paradigms for their specific needs.
+
+## Overview
+
+Feedback in BCI systems is crucial for providing users with information about their brain activity in real-time. This module implements a robust feedback system that can be easily extended and customized for different experimental paradigms.
+
+## Core Components
+
+### Base Classes
+
+- `Feedback`: Abstract base class that defines the interface for all feedback mechanisms
+ - Provides common functionality for feedback administration
+ - Defines abstract methods that must be implemented by subclasses
+ - Manages feedback type registration and logging
+
+### Feedback Types
+
+The module supports two main types of feedback:
+
+1. **Visual Feedback** (`VisualFeedback`)
+ - Displays text or image stimuli on screen
+ - Supports customizable positioning, timing, and appearance
+ - Provides precise timing control for stimulus presentation
+ - Features:
+ - Text and image stimulus support
+ - Configurable font, size, and color
+ - Position control
+ - Timing synchronization
+
+2. **Auditory Feedback** (`AuditoryFeedback`)
+ - Plays sound stimuli through the system's audio output
+ - Supports various audio formats and sampling rates
+ - Provides timing control for audio presentation
+ - Features:
+ - Sound playback
+ - Configurable audio parameters
+ - Timing synchronization
+
+## Usage Examples
+
+### Visual Feedback
+
+```python
+from bcipy.feedback.visual.visual_feedback import StimuliType, VisualFeedback
+from psychopy import visual
+from bcipy.helpers.clock import Clock
+
+# Initialize display window
+window = visual.Window(size=[800, 600])
+
+# Configure parameters
+parameters = {
+ 'feedback_font': 'Arial',
+ 'feedback_stim_height': 0.1,
+ 'feedback_pos_x': 0,
+ 'feedback_pos_y': 0,
+ 'feedback_duration': 1.0,
+ 'feedback_color': 'white'
+}
+
+# Create feedback instance
+clock = Clock()
+feedback = VisualFeedback(window, parameters, clock)
+
+# Administer feedback
+timing = feedback.administer("Hello World", StimuliType.TEXT)
+```
+
+### Auditory Feedback
+
+```python
+from bcipy.feedback.sound.auditory_feedback import AuditoryFeedback
+from psychopy import core
+import numpy as np
+
+# Configure parameters
+parameters = {
+ 'sound_buffer_time': 1.0
+}
+
+# Create feedback instance
+clock = core.Clock()
+feedback = AuditoryFeedback(parameters, clock)
+
+# Generate a simple tone
+fs = 44100 # sampling frequency
+t = np.linspace(0, 1, fs) # 1 second duration
+sound = np.sin(2 * np.pi * 440 * t) # 440 Hz sine wave
+
+# Administer feedback
+timing = feedback.administer(sound, fs)
+```
+
+## Configuration
+
+Both feedback types can be configured through a parameters dictionary. Common parameters include:
+
+- Timing parameters (duration, intervals)
+- Display parameters (position, size, color)
+- Stimulus-specific parameters (font, audio format)
+
+## Timing Control
+
+The feedback module provides precise timing control through:
+
+- Clock synchronization
+- Timestamp recording
+- Buffer time management
+
+## Extending the Module
+
+To create a new feedback type:
+
+1. Create a new class inheriting from `Feedback`
+2. Implement the required abstract method:
+   - `administer()` to deliver the feedback
+   - (`_type()` and `_available_modes()` are provided by the base class)
+3. Register the new feedback type in `FeedbackType` enum
+
+## Best Practices
+
+1. Always use the provided timing mechanisms for synchronization
+2. Handle exceptions appropriately in feedback administration
+3. Clean up resources after feedback presentation
+4. Use appropriate buffer times for smooth presentation
+5. Test feedback timing in your specific experimental setup
+
+## References
+
+- PsychoPy documentation for visual stimulus presentation
+- SoundDevice documentation for audio playback
+- BciPy documentation for integration with other modules
diff --git a/bcipy/feedback/feedback.py b/bcipy/feedback/feedback.py
index e7976a1d3..e21795fb6 100644
--- a/bcipy/feedback/feedback.py
+++ b/bcipy/feedback/feedback.py
@@ -1,26 +1,98 @@
+# mypy: disable-error-code=override
+"""Feedback module.
+
+This module provides the base feedback functionality for BciPy, including abstract
+classes and utilities for creating and managing different types of feedback
+mechanisms (sound, visual, etc.).
+"""
+
import logging
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any, List
from bcipy.config import SESSION_LOG_FILENAME
-REGISTERED_FEEDBACK_TYPES = ['sound', 'visual']
+class FeedbackType(Enum):
+ """Enumeration of feedback types supported by BciPy (Visual, Audio)."""
+
+ VIS = 'Visual'
+ AUD = 'Audio'
+
+ @classmethod
+ def list(cls) -> List[str]:
+ """Return a list of all available feedback types.
+
+ Returns:
+ List[str]: List of feedback type values
+ """
+ return [feedback_type.value for feedback_type in cls]
+
+
+class StimuliType(Enum):
+ """Enumeration of stimuli types supported by BciPy (Text, Image)."""
+
+ TEXT = 'Text'
+ IMAGE = 'Image'
+
+ @classmethod
+ def list(cls) -> List[str]:
+ """Return a list of all available stimuli types.
-class Feedback:
- """Feedback."""
+ Returns:
+ List[str]: List of stimuli type values
+ """
+ return [stimuli_type.value for stimuli_type in cls]
- def __init__(self, feedback_type):
+
+class Feedback(ABC):
+ """Abstract base class for feedback mechanisms.
+
+ This class defines the interface for different types of feedback in BciPy,
+ such as sound and visual feedback. It provides methods for configuration
+ and administration of feedback.
+
+ Attributes:
+ feedback_type (FeedbackType): Type of feedback (e.g., FeedbackType.VIS, FeedbackType.AUD).
+ logger (logging.Logger): Logger instance for feedback-related events.
+ """
+
+ def __init__(self, feedback_type: FeedbackType) -> None:
+ """Initialize Feedback.
+
+ Args:
+ feedback_type (FeedbackType): Type of feedback to be administered.
+ """
super(Feedback, self).__init__()
self.feedback_type = feedback_type
self.logger = logging.getLogger(SESSION_LOG_FILENAME)
- def configure(self):
- raise NotImplementedError()
+ @abstractmethod
+ def administer(self, *args: Any, **kwargs: Any) -> None:
+ """Administer feedback.
+
+ This method should be implemented by subclasses to deliver the actual
+ feedback to the user.
- def administer(self, *args, **kwargs):
- raise NotImplementedError()
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+ """
+ ...
- def _type(self):
+ def _type(self) -> FeedbackType:
+ """Get the feedback type.
+
+ Returns:
+ FeedbackType: The type of feedback being administered.
+ """
return self.feedback_type
- def _available_modes(self):
- return REGISTERED_FEEDBACK_TYPES
+ def _available_modes(self) -> List[str]:
+ """Get available feedback modes.
+
+ Returns:
+ List[str]: List of registered feedback types.
+ """
+ return FeedbackType.list()
diff --git a/bcipy/feedback/sound/auditory_feedback.py b/bcipy/feedback/sound/auditory_feedback.py
index 2b2a7b609..1d74fd404 100644
--- a/bcipy/feedback/sound/auditory_feedback.py
+++ b/bcipy/feedback/sound/auditory_feedback.py
@@ -1,16 +1,41 @@
+# mypy: disable-error-code=override
+"""Auditory feedback module.
+
+This module provides auditory feedback functionality for BciPy, implementing
+sound-based feedback mechanisms using sounddevice for audio playback.
+"""
+
+from typing import Any, Dict, List, Optional, Union
+
import sounddevice as sd
from psychopy import core
-from bcipy.feedback.feedback import Feedback
+from bcipy.feedback.feedback import Feedback, FeedbackType
class AuditoryFeedback(Feedback):
- """Auditory Feedback."""
+ """Auditory feedback implementation.
+
+ This class provides sound-based feedback functionality, allowing for
+ the playback of audio stimuli with precise timing control.
- def __init__(self, parameters, clock):
+ Attributes:
+ feedback_type (FeedbackType): Type of feedback (AUD).
+ parameters (Dict[str, Any]): Configuration parameters for feedback.
+ sound_buffer_time (float): Buffer time for sound playback.
+ feedback_timestamp_label (str): Label for feedback timing.
+ clock (core.Clock): Clock for timing control.
+ """
+ def __init__(self, parameters: Dict[str, Any], clock: core.Clock) -> None:
+ """Initialize Auditory Feedback.
+
+ Args:
+ parameters (Dict[str, Any]): Configuration parameters for feedback.
+ clock (core.Clock): Clock instance for timing control.
+ """
# Register Feedback Type
- self.feedback_type = 'Auditory Feedback'
+ self.feedback_type = FeedbackType.AUD
super(AuditoryFeedback, self).__init__(self.feedback_type)
@@ -23,7 +48,22 @@ def __init__(self, parameters, clock):
# Clock
self.clock = clock
- def administer(self, sound, fs, assertion=None):
+ def administer(self, sound: Union[List[float], List[List[float]]], fs: int,
+ assertion: Optional[Any] = None) -> List[List[Union[str, float]]]:
+ """Administer auditory feedback.
+
+ Plays the provided sound and records the timing of the feedback.
+
+ Args:
+ sound (Union[List[float], List[List[float]]]): Sound data to play.
+ fs (int): Sampling frequency of the sound.
+ assertion (Optional[Any]): Optional assertion to check before playing.
+ Evaluated before the sound is played, if provided.
+
+ Returns:
+ List[List[Union[str, float]]]: List containing timing information
+ in the format [[label, timestamp]].
+ """
timing = []
if assertion:
diff --git a/bcipy/feedback/tests/sound/test_sound_feedback.py b/bcipy/feedback/tests/sound/test_sound_feedback.py
index 5cc418c1d..f4a6bd2b0 100644
--- a/bcipy/feedback/tests/sound/test_sound_feedback.py
+++ b/bcipy/feedback/tests/sound/test_sound_feedback.py
@@ -30,7 +30,7 @@ def tearDown(self):
def test_feedback_type(self):
feedback_type = self.auditory_feedback._type()
- self.assertEqual(feedback_type, 'Auditory Feedback')
+ self.assertEqual(feedback_type.value, 'Audio')
def test_feedback_administer_sound(self):
timestamp = 100
@@ -39,7 +39,8 @@ def test_feedback_administer_sound(self):
self.sound, self.fs)
self.assertTrue(isinstance(resp, list))
- self.assertEqual(resp[0], [self.auditory_feedback.feedback_timestamp_label, timestamp])
+ self.assertEqual(
+ resp[0], [self.auditory_feedback.feedback_timestamp_label, timestamp])
if __name__ == '__main__':
diff --git a/bcipy/feedback/tests/visual/test_visual_feedback.py b/bcipy/feedback/tests/visual/test_visual_feedback.py
index 390896ee4..bd7b1de1f 100644
--- a/bcipy/feedback/tests/visual/test_visual_feedback.py
+++ b/bcipy/feedback/tests/visual/test_visual_feedback.py
@@ -4,7 +4,7 @@
from mockito import (any, mock, unstub, verify, verifyNoUnwantedInteractions,
verifyStubbedInvocationsAreUsed, when)
-from bcipy.feedback.visual.visual_feedback import FeedbackType, VisualFeedback
+from bcipy.feedback.visual.visual_feedback import StimuliType, VisualFeedback
from bcipy.helpers.clock import Clock
@@ -39,7 +39,7 @@ def tearDown(self):
def test_feedback_type(self):
feedback_type = self.visual_feedback._type()
- self.assertEqual(feedback_type, 'Visual Feedback')
+ self.assertEqual(feedback_type.value, 'Visual')
def test_construct_stimulus_image(self):
image_mock = mock()
@@ -52,13 +52,14 @@ def test_construct_stimulus_image(self):
ori=any()
).thenReturn(image_mock)
# mock the resize behavior for the image
- when(self.visual_feedback)._resize_image(any(), any(), any()).thenReturn()
+ when(self.visual_feedback)._resize_image(
+ any(), any(), any()).thenReturn()
response = self.visual_feedback._construct_stimulus(
'test_stim.png',
(0, 0),
None,
- FeedbackType.IMAGE,
+ StimuliType.IMAGE,
)
self.assertEqual(response, image_mock)
@@ -78,7 +79,7 @@ def test_construct_stimulus_text(self):
stimulus,
(0, 0),
None,
- FeedbackType.TEXT,
+ StimuliType.TEXT,
)
self.assertEqual(response, text_mock)
@@ -88,7 +89,8 @@ def test_show_stimuli(self):
when(stimuli_mock).draw().thenReturn(None)
when(self.display).flip().thenReturn(None)
- response = self.visual_feedback._show_stimuli(stimuli_mock) # TODO assertion
+ response = self.visual_feedback._show_stimuli(
+ stimuli_mock) # TODO assertion
verify(stimuli_mock, times=1).draw()
verify(self.display, times=1).flip()
@@ -100,10 +102,12 @@ def test_administer_default(self):
stimulus,
self.visual_feedback.pos_stim,
self.visual_feedback.color,
- FeedbackType.TEXT
+ StimuliType.TEXT
).thenReturn(stimulus)
- when(self.visual_feedback)._show_stimuli(stimulus).thenReturn(timestamp)
- when(psychopy.core).wait(self.visual_feedback.feedback_length).thenReturn()
+ when(self.visual_feedback)._show_stimuli(
+ stimulus).thenReturn(timestamp)
+ when(psychopy.core).wait(
+ self.visual_feedback.feedback_length).thenReturn()
response = self.visual_feedback.administer(stimulus)
expected = [timestamp]
self.assertEqual(response, expected)
diff --git a/bcipy/feedback/visual/visual_feedback.py b/bcipy/feedback/visual/visual_feedback.py
index 0986e8396..ad27fae0c 100644
--- a/bcipy/feedback/visual/visual_feedback.py
+++ b/bcipy/feedback/visual/visual_feedback.py
@@ -1,26 +1,47 @@
-# mypy: disable-error-code="return-value"
-from enum import Enum
-from typing import List, Tuple, Union
+# mypy: disable-error-code="override,return-value"
+"""Visual feedback module.
+
+This module provides visual feedback functionality for BciPy, implementing
+visual-based feedback mechanisms using PsychoPy for stimulus presentation.
+"""
+from typing import Any, Dict, List, Tuple, Union
from psychopy import core, visual
from bcipy.core.stimuli import resize_image
-from bcipy.feedback.feedback import Feedback
+from bcipy.feedback.feedback import Feedback, FeedbackType, StimuliType
from bcipy.helpers.clock import Clock
-class FeedbackType(Enum):
- TEXT = 'TEXT'
- IMAGE = 'IMAGE'
-
-
class VisualFeedback(Feedback):
- """Visual Feedback."""
-
- def __init__(self, display: visual.Window, parameters: dict, clock: Clock) -> None:
-
+ """Visual feedback implementation.
+
+ This class provides visual feedback functionality, allowing for the presentation
+ of text and image stimuli with precise timing control.
+
+ Attributes:
+ feedback_type (FeedbackType): Type of feedback (VIS).
+ display (visual.Window): PsychoPy window for display.
+ parameters (Dict[str, Any]): Configuration parameters for feedback.
+ font_stim (str): Font to use for text stimuli.
+ height_stim (int): Height of stimuli.
+ pos_stim (Tuple[float, float]): Position for stimuli presentation.
+ feedback_length (float): Duration of feedback presentation.
+ color (str): Color of the feedback stimulus.
+ clock (Clock): Clock for timing control.
+ feedback_timestamp_label (str): Label for feedback timing.
+ """
+
+ def __init__(self, display: visual.Window, parameters: Dict[str, Any], clock: Clock) -> None:
+ """Initialize Visual Feedback.
+
+ Args:
+ display (visual.Window): PsychoPy window for display.
+ parameters (Dict[str, Any]): Configuration parameters for feedback.
+ clock (Clock): Clock instance for timing control.
+ """
# Register Feedback Type
- self.feedback_type = 'Visual Feedback'
+ self.feedback_type = FeedbackType.VIS
super(VisualFeedback, self).__init__(self.feedback_type)
@@ -47,13 +68,23 @@ def __init__(self, display: visual.Window, parameters: dict, clock: Clock) -> No
def administer(
self,
stimulus: str,
- stimuli_type=FeedbackType.TEXT) -> List[Tuple[str, float]]:
- """Administer.
+ stimuli_type: StimuliType = StimuliType.TEXT) -> List[Tuple[str, float]]:
+ """Administer visual feedback.
- Administer visual feedback. Timing information from parameters,
- current feedback given by stimulus.
- """
+ Presents visual feedback stimulus and records timing information.
+
+ Args:
+ stimulus (str): The stimulus to present (text or image path).
+ stimuli_type (StimuliType, optional): Type of stimulus to present.
+ Defaults to StimuliType.TEXT.
+
+ Returns:
+ List[Tuple[str, float]]: List containing timing information
+ in the format [(label, timestamp)].
+ Raises:
+ ValueError: If an unsupported stimulus type is provided.
+ """
stim = self._construct_stimulus(
stimulus,
self.pos_stim,
@@ -61,14 +92,21 @@ def administer(
stimuli_type)
time = self._show_stimuli(stim)
-
core.wait(self.feedback_length)
return [time]
def _show_stimuli(self, stimulus: Union[visual.TargetStim, visual.ImageStim]) -> Tuple[str, float]:
+ """Show the stimulus and record timing.
+
+ Args:
+ stimulus (Union[visual.TargetStim, visual.ImageStim]): The stimulus to present.
+
+ Returns:
+ Tuple[str, float]: Timing information in the format (label, timestamp).
+ """
stimulus.draw()
- time = [self.feedback_timestamp_label, self.clock.getTime()] # TODO: use callback for better precision
+ time = [self.feedback_timestamp_label, self.clock.getTime()]
self.display.flip()
return time
@@ -77,6 +115,16 @@ def _resize_image(
stimulus: str,
display_size: Tuple[float, float],
stimuli_height: int) -> Tuple[float, float]:
+ """Resize an image stimulus to fit the display.
+
+ Args:
+ stimulus (str): Path to the image file.
+ display_size (Tuple[float, float]): Size of the display window.
+ stimuli_height (int): Desired height of the stimulus.
+
+ Returns:
+ Tuple[float, float]: New size of the image (width, height).
+ """
return resize_image(
stimulus, display_size, stimuli_height)
@@ -85,8 +133,24 @@ def _construct_stimulus(
stimulus: str,
pos: Tuple[float, float],
fill_color: str,
- stimuli_type: FeedbackType) -> Union[visual.TargetStim, visual.ImageStim]:
- if stimuli_type == FeedbackType.IMAGE:
+ stimuli_type: StimuliType) -> Union[visual.TargetStim, visual.ImageStim]:
+ """Construct a visual stimulus.
+
+ Creates either a text or image stimulus based on the provided type.
+
+ Args:
+ stimulus (str): The stimulus content (text or image path).
+ pos (Tuple[float, float]): Position for the stimulus.
+ fill_color (str): Color for text stimuli.
+ stimuli_type (StimuliType): Type of stimulus to create.
+
+ Returns:
+ Union[visual.TargetStim, visual.ImageStim]: The created stimulus object.
+
+ Raises:
+ ValueError: If an unsupported stimulus type is provided.
+ """
+ if stimuli_type == StimuliType.IMAGE:
image_stim = visual.ImageStim(
win=self.display,
image=stimulus,
@@ -96,7 +160,7 @@ def _construct_stimulus(
image_stim.size = self._resize_image(
stimulus, self.display.size, self.height_stim)
return image_stim
- if stimuli_type == FeedbackType.TEXT:
+ if stimuli_type == StimuliType.TEXT:
return visual.TextStim(
win=self.display,
font=self.font_stim,
@@ -104,3 +168,5 @@ def _construct_stimulus(
height=self.height_stim,
pos=pos,
color=fill_color)
+ raise ValueError(
+ f'VisualFeedback asked to create a stimulus type=[{stimuli_type}] that is not supported.')
diff --git a/bcipy/gui/BCInterface.py b/bcipy/gui/BCInterface.py
index 509157a43..ca097fcd9 100644
--- a/bcipy/gui/BCInterface.py
+++ b/bcipy/gui/BCInterface.py
@@ -1,7 +1,14 @@
+"""BCInterface module.
+
+This module provides the main graphical user interface for BciPy experiments,
+including task execution, parameter management, and offline analysis capabilities.
+"""
+
import logging
import subprocess
import sys
-from typing import List
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple
from bcipy.config import (BCIPY_ROOT, DEFAULT_PARAMETERS_PATH,
PROTOCOL_LOG_FILENAME, STATIC_IMAGES_PATH)
@@ -16,113 +23,191 @@
logger = logging.getLogger(PROTOCOL_LOG_FILENAME)
+@dataclass
+class UIConfig:
+ """Configuration for UI elements.
+
+ Attributes:
+ padding (int): Padding value for UI elements.
+ btn_height (int): Height of buttons.
+ btn_width (int): Width of buttons.
+ font (str): Font family for UI elements.
+ static_font_size (int): Font size for static text.
+ """
+ padding: int = 20
+ btn_height: int = 40
+ btn_width: int = 100
+ font: str = 'Courier New'
+ static_font_size: int = 16
+
+
+@dataclass
+class UserConfig:
+ """Configuration for user-related settings.
+
+ Attributes:
+ max_length (int): Maximum length for user IDs.
+ min_length (int): Minimum length for user IDs.
+ default_text (str): Default text for input fields.
+ """
+ max_length: int = 25
+ min_length: int = 1
+ default_text: str = '...'
+
+
class BCInterface(BCIGui):
"""BCI Interface.
- Main interface for execution of BciPy experiments and tasks. Additionally, quick access to parameter
- editing and loading, and offline analysis execution.
+ Main interface for execution of BciPy experiments and tasks. Provides quick access to parameter
+ editing and loading, and offline analysis execution.
+
+ Attributes:
+ tasks (List[str]): List of available tasks from TaskRegistry.
+ ui_config (UIConfig): UI configuration settings.
+ user_config (UserConfig): User-related configuration settings.
+ parameter_location (str): Path to parameters file.
+ parameters (Dict[str, Any]): Loaded parameters.
+ user_input (Optional[Any]): User input field.
+ experiment_input (Optional[Any]): Experiment input field.
+ task_input (Optional[Any]): Task input field.
+ user (Optional[str]): Selected user.
+ experiment (Optional[str]): Selected experiment.
+ task (Optional[str]): Selected task.
+ users (List[str]): List of available users.
+ disable (bool): Flag to prevent double-clicking.
+ task_start_timeout (int): Timeout for task start.
+ button_timeout (int): Timeout for button actions.
+ autoclose (bool): Whether to auto-close after task completion.
+ alert (bool): Whether to show alerts.
+ user_id_validations (List[Tuple[Callable[[str], bool], str]]): User ID validation rules.
"""
tasks = TaskRegistry().list()
+ ui_config = UIConfig()
+ user_config = UserConfig()
- default_text = '...'
- padding = 20
- btn_height = 40
- btn_width = 100
- max_length = 25
- min_length = 1
- timeout = 3
- font = 'Courier New'
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ """Initialize the BCInterface.
- def __init__(self, *args, **kwargs):
+ Args:
+ *args: Positional arguments passed to BCIGui.
+ **kwargs: Keyword arguments passed to BCIGui.
+ """
super(BCInterface, self).__init__(*args, **kwargs)
self.parameter_location = DEFAULT_PARAMETERS_PATH
+ self.parameters = load_json_parameters(
+ self.parameter_location, value_cast=True)
- self.parameters = load_json_parameters(self.parameter_location,
- value_cast=True)
-
- # These are set in the build_inputs and represent text inputs from the user
+ # Input fields
self.user_input = None
self.experiment_input = None
self.task_input = None
- # These represent the current user, experiment, and task selected in the gui
+ # Selected values
self.user = None
self.experiment = None
self.task = None
+ self.users: List[str] = []
- # user names available in the dropdown menu
- self.users = []
-
- # setup a timer to prevent double clicking in gui
+ # UI state
self.disable = False
self.timer.timeout.connect(self._disable_action)
-
- self.task_start_timeout = self.timeout
- self.button_timeout = self.timeout
-
+ self.task_start_timeout = 3
+ self.button_timeout = 3
self.autoclose = False
self.alert = True
- self.static_font_size = 16
- self.user_id_validations = [
- (invalid_length(min=self.min_length, max=self.max_length),
- f'User ID must contain between {self.min_length} and {self.max_length} alphanumeric characters.'),
+ # Initialize user ID validations
+ self.user_id_validations = self._init_user_validations()
+
+ # Load tasks
+ self.tasks = TaskRegistry().list()
+ if not self.tasks:
+ logger.warning("No tasks found in TaskRegistry")
+
+ def _init_user_validations(self) -> List[Tuple[Callable[[str], bool], str]]:
+ """Initialize user ID validation rules.
+
+ Returns:
+ List[Tuple[Callable[[str], bool], str]]: List of validation rules and error messages.
+ """
+ return [
+ (invalid_length(min=self.user_config.min_length, max=self.user_config.max_length),
+ f'User ID must contain between {self.user_config.min_length} and {self.user_config.max_length} alphanumeric characters.'),
(contains_whitespaces, 'User ID cannot contain white spaces'),
(contains_special_characters, 'User ID cannot contain special characters')
]
def build_buttons(self) -> None:
- """Build Buttons.
-
- Build all buttons necessary for the UI. Define their action on click using the named argument action.
- """
-
+ """Build all buttons for the UI."""
+ self._build_action_buttons()
+ self._build_start_button()
+ self._build_create_experiment_button()
+
+ def _build_action_buttons(self) -> None:
+ """Build the action buttons (Load, Edit, Train)."""
+ # Load button
self.add_button(
- message='Load', position=[self.padding, 450],
- size=[self.btn_width, self.btn_height], background_color='Plum',
+ message='Load',
+ position=[self.ui_config.padding, 450],
+ size=[self.ui_config.btn_width, self.ui_config.btn_height],
+ background_color='Plum',
text_color='black',
- font_family=self.font,
+ font_family=self.ui_config.font,
action=self.select_parameters)
+ # Edit button
self.add_button(
- message='Edit', position=[self.padding + self.btn_width + 10, 450],
- size=[self.btn_width, self.btn_height], background_color='LightCoral',
+ message='Edit',
+ position=[self.ui_config.padding +
+ self.ui_config.btn_width + 10, 450],
+ size=[self.ui_config.btn_width, self.ui_config.btn_height],
+ background_color='LightCoral',
text_color='black',
- font_family=self.font,
+ font_family=self.ui_config.font,
action=self.edit_parameters)
- btn_auc_x = self.padding + (self.btn_width * 2) + 20
+ # Train button
+ btn_auc_x = self.ui_config.padding + \
+ (self.ui_config.btn_width * 2) + 20
self.add_button(
- message='Train', position=(btn_auc_x, 450),
- size=(self.btn_width, self.btn_height), background_color='LightSeaGreen',
+ message='Train',
+ position=(btn_auc_x, 450),
+ size=(self.ui_config.btn_width, self.ui_config.btn_height),
+ background_color='LightSeaGreen',
text_color='black',
- font_family=self.font,
+ font_family=self.ui_config.font,
action=self.offline_analysis)
- btn_start_width = self.btn_width * 2 + 10
- btn_start_x = self.width - (self.padding + btn_start_width)
+ def _build_start_button(self) -> None:
+ """Build the Start Session button."""
+ btn_start_width = self.ui_config.btn_width * 2 + 10
+ btn_start_x = self.width - (self.ui_config.padding + btn_start_width)
self.add_button(
- message='Start Session', position=[btn_start_x, 440],
- size=[btn_start_width, self.btn_height + 10],
+ message='Start Session',
+ position=[btn_start_x, 440],
+ size=[btn_start_width, self.ui_config.btn_height + 10],
background_color='green',
action=self.start_experiment,
text_color='white',
- font_family=self.font)
+ font_family=self.ui_config.font)
+ def _build_create_experiment_button(self) -> None:
+ """Build the Create Experiment button."""
self.add_button(
message='+',
- position=[self.width - self.padding - 200, 260],
- size=[35, self.btn_height - 10],
+ position=[self.width - self.ui_config.padding - 200, 260],
+ size=[35, self.ui_config.btn_height - 10],
background_color='green',
action=self.create_experiment,
text_color='white'
)
def create_experiment(self) -> None:
- """Create Experiment.
+ """Launch the experiment registry.
- Launch the experiment registry which will be used to add new experiments for selection in the GUI.
+ Opens the experiment registry interface for adding new experiments
+ to the GUI selection.
"""
if not self.action_disabled():
subprocess.call(
@@ -131,339 +216,384 @@ def create_experiment(self) -> None:
self.update_experiment_list()
- def update_user_list(self, refresh=True) -> None:
- """Updates the user_input combo box with a list of user ids based on the
- data directory configured in the current parameters."""
- # if refresh is True, then we need to clear the list and add the default text
- if refresh:
- self.user_input.clear()
- self.user_input.addItem(BCInterface.default_text)
-
- # load the users from the data directory and check if they have already been added to the dropdown
- users = load_users(self.parameters['data_save_loc'])
- for user in users:
- if user not in self.users:
- self.user_input.addItem(user)
- self.users.append(user)
-
- def update_experiment_list(self) -> None:
- """Updates the experiment_input combo box with a list of experiments based on the
- data directory configured in the current parameters."""
-
- self.experiment_input.clear()
- self.experiment_input.addItem(BCInterface.default_text)
- self.experiment_input.addItems(self.load_experiments())
-
- def update_task_list(self) -> None:
- """Updates the task_input combo box with a list of tasks ids based on what
- is available in the Task Registry"""
-
- self.task_input.clear()
- self.task_input.addItem(BCInterface.default_text)
- self.task_input.addItems(self.tasks)
-
def build_inputs(self) -> None:
- """Build Inputs.
-
- Build all inputs needed for BCInterface.
- """
+ """Build all input fields for the interface."""
input_x = 170
input_y = 160
- self.user_input = self.add_combobox(
- position=[input_x, input_y],
- size=[280, 25],
- items=[],
- editable=True,
- background_color='white',
- text_color='black')
+ # User input
+ self.user_input = self._build_combobox(
+ position=[input_x, input_y],
+ editable=True)
self.update_user_list()
+ # Experiment input
input_y += 100
- self.experiment_input = self.add_combobox(
+ self.experiment_input = self._build_combobox(
position=[input_x, input_y],
- size=[280, 25],
- items=[],
- editable=False,
- background_color='white',
- text_color='black')
-
+ editable=False)
self.update_experiment_list()
+ # Task input
input_y += 100
- self.task_input = self.add_combobox(
+ self.task_input = self._build_combobox(
position=[input_x, input_y],
+ editable=False)
+ self.update_task_list()
+
+ def _build_combobox(self, position: List[int], editable: bool) -> Any:
+ """Build a combo box with standard styling.
+
+ Args:
+ position (List[int]): Position coordinates [x, y].
+ editable (bool): Whether the combo box is editable.
+
+ Returns:
+ Any: The created combo box.
+ """
+ combo = self.add_combobox(
+ position=position,
size=[280, 25],
- items=[],
- editable=False,
+ items=[self.user_config.default_text],
+ editable=editable,
background_color='white',
text_color='black')
-
- self.update_task_list()
+ return combo
def build_text(self) -> None:
- """Build Text.
-
- Build all static text needed for the UI.
- Positions are relative to the height / width of the UI defined in start_app.
- """
+ """Build all static text elements for the interface."""
+ # Title
self.add_static_textbox(
text='BCInterface',
position=[210, 0],
size=[250, 50],
background_color='black',
text_color='white',
- font_family=self.font,
+ font_family=self.ui_config.font,
font_size=30)
+ # Labels
text_x = 145
- self.add_static_textbox(
- text='User',
- position=[text_x, 105],
- size=[200, 50],
- background_color='black',
- text_color='white',
- font_family=self.font,
- font_size=self.static_font_size)
- self.add_static_textbox(
- text='Experiment',
- position=[text_x, 205],
- size=[300, 50],
- background_color='black',
- font_family=self.font,
- text_color='white',
- font_size=self.static_font_size)
- self.add_static_textbox(
- text='Task',
- position=[text_x, 305],
- size=[300, 50],
- background_color='black',
- text_color='white',
- font_family=self.font,
- font_size=self.static_font_size)
+ labels = [
+ ('User', 105),
+ ('Experiment', 205),
+ ('Task', 305)
+ ]
- def build_images(self) -> None:
- """Build Images.
+ for text, y_pos in labels:
+ self.add_static_textbox(
+ text=text,
+ position=[text_x, y_pos],
+ size=[300, 50],
+ background_color='black',
+ text_color='white',
+ font_family=self.ui_config.font,
+ font_size=self.ui_config.static_font_size)
- Build add images needed for the UI. In this case, the OHSU and NEU logos.
- """
+ def build_images(self) -> None:
+ """Build all image elements for the interface."""
+ # OHSU logo
self.add_image(
- path=f'{STATIC_IMAGES_PATH}/gui/ohsu.png', position=[self.padding, 0], size=100)
+ path=f'{STATIC_IMAGES_PATH}/gui/ohsu.png',
+ position=[self.ui_config.padding, 0],
+ size=100)
+
+ # NEU logo
self.add_image(
- path=f'{STATIC_IMAGES_PATH}/gui/neu.png', position=[self.width - self.padding - 110, 0], size=100)
+ path=f'{STATIC_IMAGES_PATH}/gui/neu.png',
+ position=[self.width - self.ui_config.padding - 110, 0],
+ size=100)
def build_assets(self) -> None:
- """Build Assets.
-
- Define the assets to build in the UI.
- """
+ """Build all UI assets."""
self.build_buttons()
self.build_inputs()
self.build_text()
self.build_images()
+ def update_user_list(self, refresh: bool = True) -> None:
+ """Update the user input combo box with available users.
+
+ Args:
+ refresh (bool): Whether to clear the list before updating.
+ """
+ if refresh and self.user_input:
+ self.user_input.clear()
+ self.user_input.addItem(self.user_config.default_text)
+
+ users = load_users(self.parameters['data_save_loc'])
+ for user in users:
+ if user not in self.users and self.user_input:
+ self.user_input.addItem(user)
+ self.users.append(user)
+
+ def update_experiment_list(self) -> None:
+ """Update the experiment input combo box with available experiments."""
+ if self.experiment_input:
+ self.experiment_input.clear()
+ self.experiment_input.addItem(self.user_config.default_text)
+ experiments = self.load_experiments()
+ if experiments:
+ self.experiment_input.addItems(experiments)
+
+ def update_task_list(self) -> None:
+ """Update the task input combo box with available tasks."""
+ if self.task_input:
+ self.task_input.clear()
+ self.task_input.addItem(self.user_config.default_text)
+ if self.tasks:
+ self.task_input.addItems(self.tasks)
+
def set_parameter_location(self, path: str) -> None:
- """Sets the parameter_location to the given path. Reloads the parameters
- and updates any GUI widgets that are populated based on these params."""
+ """Set the parameter file location and update the interface.
+
+ Args:
+ path (str): Path to the parameters file.
+ """
self.parameter_location = path
- self.parameters = load_json_parameters(self.parameter_location,
- value_cast=True)
+ self.parameters = load_json_parameters(
+ self.parameter_location, value_cast=True)
self.update_user_list(refresh=False)
def select_parameters(self) -> None:
- """Select Parameters.
+ """Open a dialog to select the parameters configuration file."""
+ response = self.get_filename_dialog(
+ message='Select parameters file',
+ file_type='JSON files (*.json)')
+
+ if not response:
+ return
+
+ self.set_parameter_location(response)
+ self._handle_outdated_parameters()
+
+ def _handle_outdated_parameters(self) -> None:
+ """Handle outdated parameter files by prompting for updates."""
+ default_parameters = load_json_parameters(
+ DEFAULT_PARAMETERS_PATH, value_cast=True)
+ if not self.parameters.add_missing_items(default_parameters):
+ return
+
+ save_response = self.throw_alert_message(
+ title='BciPy Alert',
+ message='The selected parameters file is out of date. '
+ 'Would you like to update it with the latest options?',
+ message_type=AlertMessageType.INFO,
+ message_response=AlertMessageResponse.OCE)
+
+ if save_response == AlertResponse.OK.value:
+ self.parameters.save()
- Opens a dialog to select the parameters.json configuration to use.
- """
+ def edit_parameters(self) -> None:
+ """Open the parameter editor."""
+ if self.action_disabled():
+ return
- response = self.get_filename_dialog(message='Select parameters file',
- file_type='JSON files (*.json)')
- if response:
- self.set_parameter_location(response)
- # If outdated, prompt to merge with the current defaults
- default_parameters = load_json_parameters(DEFAULT_PARAMETERS_PATH,
- value_cast=True)
- if self.parameters.add_missing_items(default_parameters):
- save_response = self.throw_alert_message(
- title='BciPy Alert',
- message='The selected parameters file is out of date.'
- 'Would you like to update it with the latest options?',
- message_type=AlertMessageType.INFO,
- message_response=AlertMessageResponse.OCE)
+ if self.parameter_location == DEFAULT_PARAMETERS_PATH:
+ if not self._handle_default_parameters():
+ return
- if save_response == AlertResponse.OK.value:
- self.parameters.save()
+ self._launch_parameter_editor()
- def edit_parameters(self) -> None:
- """Edit Parameters.
+ def _handle_default_parameters(self) -> bool:
+ """Handle editing of default parameters.
- Prompts for a parameters.json file to use. If the default parameters are selected, a copy is used.
- Note that any edits to the parameters file will not be applied to this GUI until the parameters
- are reloaded.
+ Returns:
+ bool: True if should continue with editing, False otherwise.
"""
- if not self.action_disabled():
- if self.parameter_location == DEFAULT_PARAMETERS_PATH:
- # Don't allow the user to overwrite the defaults
- response = self.throw_alert_message(
- title='BciPy Alert',
- message='The default parameters.json cannot be overridden. A copy will be used.',
- message_type=AlertMessageType.INFO,
- message_response=AlertMessageResponse.OCE)
-
- if response == AlertResponse.OK.value:
- self.parameter_location = copy_parameters()
- else:
- return None
+ response = self.throw_alert_message(
+ title='BciPy Alert',
+ message='The default parameters.json cannot be overridden. A copy will be used.',
+ message_type=AlertMessageType.INFO,
+ message_response=AlertMessageResponse.OCE)
+
+ if response == AlertResponse.OK.value:
+ self.parameter_location = copy_parameters()
+ return True
+ return False
- output = subprocess.check_output(
- f'bcipy-params -p "{self.parameter_location}"',
- shell=True)
- if output:
- self.parameter_location = output.decode().strip()
+ def _launch_parameter_editor(self) -> None:
+ """Launch the parameter editor process."""
+ output = subprocess.check_output(
+ f'bcipy-params -p "{self.parameter_location}"',
+ shell=True)
+ if output:
+ self.parameter_location = output.decode().strip()
def check_input(self) -> bool:
- """Check Input.
+ """Check if all required input fields are valid.
- Checks to make sure user has input all required fields.
+ Returns:
+ bool: True if all inputs are valid, False otherwise.
"""
+ self._update_current_selections()
- # Update based on current inputs
- self.user = self.user_input.currentText()
- self.experiment = self.experiment_input.currentText()
- self.task = self.task_input.currentText()
-
- # Check the set values are different than defaults
try:
if not self.check_user_id():
return False
- if self.experiment == BCInterface.default_text and self.task == BCInterface.default_text:
- self.throw_alert_message(
- title='BciPy Alert',
- message='Please select an Experiment or Task for execution',
- message_type=AlertMessageType.INFO,
- message_response=AlertMessageResponse.OTE)
- return False
- if self.experiment != BCInterface.default_text and self.task != BCInterface.default_text:
- self.throw_alert_message(
- title='BciPy Alert',
- message='Please select only an Experiment or Task',
- message_type=AlertMessageType.INFO,
- message_response=AlertMessageResponse.OTE)
+ if not self._validate_experiment_task_selection():
return False
+
except Exception as e:
+ self._show_error_alert(str(e))
+ return False
+
+ return True
+
+ def _update_current_selections(self) -> None:
+ """Update current selections from input fields."""
+ if self.user_input:
+ self.user = self.user_input.currentText()
+ if self.experiment_input:
+ self.experiment = self.experiment_input.currentText()
+ if self.task_input:
+ self.task = self.task_input.currentText()
+
+ def _validate_experiment_task_selection(self) -> bool:
+ """Validate experiment and task selections.
+
+ Returns:
+ bool: True if selections are valid, False otherwise.
+ """
+ if self.experiment == self.user_config.default_text and self.task == self.user_config.default_text:
self.throw_alert_message(
title='BciPy Alert',
- message=f'Error, {e}',
- message_type=AlertMessageType.CRIT,
+ message='Please select an Experiment or Task for execution',
+ message_type=AlertMessageType.INFO,
message_response=AlertMessageResponse.OTE)
return False
+
+ if self.experiment != self.user_config.default_text and self.task != self.user_config.default_text:
+ self.throw_alert_message(
+ title='BciPy Alert',
+ message='Please select only an Experiment or Task',
+ message_type=AlertMessageType.INFO,
+ message_response=AlertMessageResponse.OTE)
+ return False
+
return True
- def check_user_id(self) -> bool:
- """Check User ID
+ def _show_error_alert(self, error_message: str) -> None:
+ """Show an error alert message.
- User ID must meet the following requirements:
+ Args:
+ error_message (str): The error message to display.
+ """
+ self.throw_alert_message(
+ title='BciPy Alert',
+ message=f'Error, {error_message}',
+ message_type=AlertMessageType.CRIT,
+ message_response=AlertMessageResponse.OTE)
- 1. Maximum length of self.max_length alphanumeric characters
- 2. Minimum length of at least self.min_length alphanumeric character
- 3. No special characters
- 4. No spaces
+ def check_user_id(self) -> bool:
+ """Validate the user ID against defined requirements.
+
+ Returns:
+ bool: True if user ID is valid, False otherwise.
"""
- # Check the user id set is different than the default text
- if self.user == BCInterface.default_text:
+ if not self.user or self.user == self.user_config.default_text:
self.throw_alert_message(
title='BciPy Alert',
message='Please input a User ID',
message_type=AlertMessageType.INFO,
message_response=AlertMessageResponse.OTE)
return False
- # Loop over defined user validations and check for error conditions
- for validator in self.user_id_validations:
- (invalid, error_message) = validator
- if invalid(self.user):
+
+ for validator, error_message in self.user_id_validations:
+ if validator(str(self.user)): # Ensure user is treated as string
self.throw_alert_message(
title='BciPy Alert',
message=error_message,
message_type=AlertMessageType.INFO,
- message_response=AlertMessageResponse.OTE
- )
+ message_response=AlertMessageResponse.OTE)
return False
+
return True
def load_experiments(self) -> List[str]:
- """Load experiments
+ """Load available experiments from the default experiment path.
- Loads experiments registered in the DEFAULT_EXPERIMENT_PATH.
+ Returns:
+ List[str]: List of experiment names.
"""
return load_experiments().keys()
def start_experiment(self) -> None:
- """Start Experiment Session.
+ """Start an experiment session."""
+ if not (self.check_input() and not self.action_disabled()):
+ return
+
+ self._show_starting_alert()
+ cmd = self._build_experiment_command()
+ self._execute_experiment(cmd)
+
+ def _show_starting_alert(self) -> None:
+ """Show the task starting alert."""
+ self.throw_alert_message(
+ title='BciPy Alert',
+ message='Task Starting ...',
+ message_type=AlertMessageType.INFO,
+ message_response=AlertMessageResponse.OTE,
+ message_timeout=self.task_start_timeout)
+
+ def _build_experiment_command(self) -> str:
+ """Build the experiment command.
+
+ Returns:
+ str: The command to execute.
+ """
+ if self.task != self.user_config.default_text:
+ cmd = f'bcipy -u "{self.user}" -t "{self.task}" -p "{self.parameter_location}"'
+ else:
+ cmd = f'bcipy -u "{self.user}" -e "{self.experiment}" -p "{self.parameter_location}"'
+
+ if self.alert:
+ cmd += ' -a'
+
+ return cmd
+
+ def _execute_experiment(self, cmd: str) -> None:
+ """Execute the experiment command.
- Using the inputs gathers, check for validity using the check_input method, then launch the experiment using a
- command to bcipy main and subprocess.
+ Args:
+ cmd (str): The command to execute.
"""
- if self.check_input() and not self.action_disabled():
+ output = subprocess.run(cmd, shell=True)
+ if output.returncode != 0:
self.throw_alert_message(
title='BciPy Alert',
- message='Task Starting ...',
- message_type=AlertMessageType.INFO,
- message_response=AlertMessageResponse.OTE,
- message_timeout=self.task_start_timeout)
- if self.task != BCInterface.default_text:
- cmd = (
- f'bcipy '
- f'-u "{self.user}" -t "{self.task}" -p "{self.parameter_location}"'
- )
- else:
- cmd = (
- f'bcipy '
- f'-u "{self.user}" -e "{self.experiment}" -p "{self.parameter_location}"'
- )
- if self.alert:
- cmd += ' -a'
- output = subprocess.run(cmd, shell=True)
- if output.returncode != 0:
- self.throw_alert_message(
- title='BciPy Alert',
- message=f'Error: {output.stderr.decode()}',
- message_type=AlertMessageType.CRIT,
- message_response=AlertMessageResponse.OTE)
+ message=f'Error: {output.stderr.decode()}',
+ message_type=AlertMessageType.CRIT,
+ message_response=AlertMessageResponse.OTE)
- if self.autoclose:
- self.close()
+ if self.autoclose:
+ self.close()
def offline_analysis(self) -> None:
- """Offline Analysis.
-
- Run offline analysis as a script in a new process.
- """
+ """Run offline analysis in a new process."""
if not self.action_disabled():
cmd = f'bcipy-train --alert --p "{self.parameter_location}" -v -s'
subprocess.Popen(cmd, shell=True)
def action_disabled(self) -> bool:
- """Action Disabled.
-
- Method to check whether another action can take place. If not disabled, it will allow the action and
- start a timer that will disable actions until self.timeout (seconds) has occured.
+ """Check if actions are currently disabled.
- Note: the timer is registed with the private method self._disable_action, which when self.timeout has
- been reached, resets self.disable and corresponding timeouts.
+ Returns:
+ bool: True if actions are disabled, False otherwise.
"""
if self.disable:
return True
- else:
- self.disable = True
- # set the update time to every 500ms
- self.timer.start(500)
- return False
+
+ self.disable = True
+ self.timer.start(500)
+ return False
def _disable_action(self) -> bool:
- """Disable Action.
+ """Handle action disabling timer.
- A private method to register with a BCIGui.timer after setting self.button_timeout.
+ Returns:
+ bool: Current disabled state.
"""
if self.button_timeout > 0:
self.disable = True
@@ -472,12 +602,12 @@ def _disable_action(self) -> bool:
self.timer.stop()
self.disable = False
- self.button_timeout = self.timeout
+ self.button_timeout = 3
return self.disable
def start_app() -> None:
- """Start BCIGui."""
+ """Start the BCI interface application."""
bcipy_gui = app(sys.argv)
ex = BCInterface(
title='Brain Computer Interface',
diff --git a/bcipy/gui/README.md b/bcipy/gui/README.md
index bb05a1b32..063e79aff 100644
--- a/bcipy/gui/README.md
+++ b/bcipy/gui/README.md
@@ -1,33 +1,169 @@
-# RSVP Keyboard GUI
-======================================
+# BciPy GUI Module
-This module contains all GUI code used in BciPy. The base window class, BCIGui, is contained in gui_main.py, and contains methods for easily adding widgets to a given window. BCInterface.py launches the main GUI window. There are also interfaces for collecting and editing data (parameters and field data for experiments.)
+The GUI module provides the graphical user interface components for BciPy, enabling users to interact with the BCI system through a user-friendly interface. This module is essential for running experiments, managing parameters, and controlling the BCI workflow.
-## Dependencies
--------------
-This project was written in wxPython version 4.0.4 and PyQt5 5.15.1. We are deprecating the wxPython UIs in future releases.
+## Overview
-## Project structure
----------------
-Name | Description
-------------- | -------------
-BCInterface.py | Defines main GUI window. Selection of user, experiment and task.
-gui_main.py | BCIGui containing methods for adding buttons, images, etc. to GUI window
-parameters/params_form.py | Defines window for setting BCInterface parameters
-experiments/ExperimentRegistry.py | GUI for creating new experiments to select in BCInterface.
-experiments/FieldRegistry.py | GUI for creating new fields for experiment data collection.
-experiments/ExperimentField.py | GUI for collecting a registered experiment's field data.
+The GUI module consists of several key components:
+1. **Main Interface (`BCInterface.py`)**
+ - Primary interface for running BCI experiments
+ - User management and experiment selection
+ - Parameter configuration and task execution
+ - Offline analysis capabilities
-The folder 'bcipy/static/images/gui' contains images for the GUI.
-Parameters loaded by BCInterface parameter definition form can be found in 'bcipy/parameters/parameters.json'.
+2. **Base UI Components (`bciui.py`)**
+ - Core UI building blocks and utilities
+ - Common functionality for all BciPy interfaces
+ - Layout management and styling
+ - Dynamic list and item management
-To run the GUI, do so from the root, as follows:
+3. **Experiment Management**
+ - `ExperimentRegistry.py`: Interface for registering and managing experiments
+ - `ExperimentField.py`: Form for collecting experiment-specific data
+ - Field management and validation
-`python bcipy/gui/BCInterface.py`
+4. **Alert System (`alert.py`)**
+ - User notifications and confirmations
+ - Error handling and system messages
-Contributors:
+5. **Task Transitions (`intertask_gui.py`)**
+ - Progress tracking between tasks
+ - Experiment flow control
+ - User feedback during transitions
-- Tab Memmott
-- Matthew Lawhead
-- Dani Smektala
+## Getting Started
+
+### Running the GUI
+
+To start the BciPy GUI:
+
+```bash
+python bcipy/gui/BCInterface.py
+```
+
+Or using Make (if installed):
+
+```bash
+make bci-gui
+```
+
+### Basic Usage
+
+1. **User Management**
+ - Enter or select a user ID
+ - User IDs must be alphanumeric and meet length requirements
+
+2. **Experiment Selection**
+ - Choose between running a specific task or a complete experiment
+ - Tasks are individual BCI operations (e.g., calibration)
+ - Experiments are predefined sequences of tasks
+
+3. **Parameter Configuration**
+ - Load existing parameter files
+ - Edit parameters through the parameter editor
+ - Save custom configurations
+
+4. **Task Execution**
+ - Start sessions with selected configurations
+ - Monitor progress through the intertask interface
+ - View results and system feedback
+
+5. **Signal Viewer**
+ - Monitor real-time EEG signals during experiments
+ - View and analyze recorded data from previous sessions
+ - Toggle channel visibility and apply montages
+ - Control display duration and filtering options
+ - Pause/resume signal visualization
+ - Support for multiple monitor configurations
+   - See [Signal Viewer Documentation](viewer/README.md) for more details
+
+## Key Features
+
+### Experiment Registry
+
+- Create and manage experiment protocols
+- Define task sequences and parameters
+- Configure experiment-specific fields
+- Save and load experiment configurations
+
+### Parameter Management
+
+- Load default or custom parameter files
+- Edit parameters through a user-friendly interface
+- Save parameter configurations
+- Validate parameter settings
+
+### Task Control
+
+- Start and stop BCI tasks
+- Monitor task progress
+- Handle transitions between tasks
+- Manage experiment flow
+
+### User Interface
+
+- Clean, intuitive design
+- Consistent styling across components
+- Responsive feedback
+- Error handling and notifications
+
+## Development
+
+### Adding New Components
+
+When extending the GUI:
+
+1. Inherit from `BCIUI` for new interfaces
+2. Use the provided UI utilities and components
+3. Follow the established styling guidelines
+4. Implement proper error handling
+5. Add appropriate documentation
+
+### Styling
+
+The GUI uses a consistent styling system:
+
+- CSS-based styling through `bcipy_stylesheet.css`
+- Common UI elements and layouts
+- Responsive design principles
+- Accessibility considerations
+
+## Best Practices
+
+1. **Error Handling**
+ - Use the alert system for user notifications
+ - Validate inputs before processing
+ - Provide clear error messages
+ - Handle edge cases gracefully
+
+2. **User Experience**
+ - Maintain consistent interface behavior
+ - Provide clear feedback for actions
+ - Use appropriate timeouts and delays
+ - Implement proper state management
+
+3. **Performance**
+ - Minimize UI blocking operations
+ - Use appropriate threading for long operations
+ - Optimize resource usage
+ - Handle cleanup properly
+
+## Troubleshooting
+
+Common issues and solutions:
+
+1. **GUI Not Starting**
+ - Check Python and dependency versions
+ - Verify file permissions
+ - Check for conflicting processes
+
+2. **Parameter Issues**
+ - Validate parameter file format
+ - Check file paths and permissions
+ - Verify parameter values
+
+3. **Task Execution Problems**
+ - Check system requirements
+ - Verify device connections
+ - Review error logs
diff --git a/bcipy/gui/alert.py b/bcipy/gui/alert.py
index bc79a167b..23e1b8837 100644
--- a/bcipy/gui/alert.py
+++ b/bcipy/gui/alert.py
@@ -1,6 +1,13 @@
-"""GUI alert messages"""
+"""GUI alert messages module.
+
+This module provides functionality for displaying alert messages and confirmation dialogs
+in the BciPy GUI interface. It includes functions for user interaction through
+standard dialog boxes with customizable options.
+"""
+
# pylint: disable=no-name-in-module
import sys
+from typing import Optional
from PyQt6.QtWidgets import QApplication
@@ -11,12 +18,19 @@
def confirm(message: str) -> bool:
"""Confirmation dialog which allows the user to select between a true and false.
- Parameters
- ----------
- message - alert to display
- Returns
- -------
- users selection : True for selecting Ok, False for Cancel.
+ This function displays a dialog box with OK and Cancel buttons, allowing the user
+ to confirm or cancel an action. The dialog is displayed using the system's native
+ dialog style.
+
+ Args:
+ message (str): The alert message to display in the dialog box.
+
+ Returns:
+ bool: True if the user clicked OK, False if the user clicked Cancel.
+
+ Note:
+ This function creates a new QApplication instance if one doesn't exist,
+ and quits it after the dialog is closed.
"""
app = QApplication(sys.argv).instance()
if not app:
diff --git a/bcipy/gui/bciui.py b/bcipy/gui/bciui.py
index fca09e1d3..c13fca87d 100644
--- a/bcipy/gui/bciui.py
+++ b/bcipy/gui/bciui.py
@@ -1,5 +1,12 @@
+"""BCIUI module.
+
+This module provides the base UI components and utilities for building BciPy's
+graphical user interfaces. It includes base classes for UI elements, dynamic
+lists, and utility functions for common UI operations.
+"""
+
import sys
-from typing import Callable, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type
from PyQt6.QtCore import pyqtSignal
from PyQt6.QtWidgets import (QApplication, QHBoxLayout, QLayout, QMessageBox,
@@ -10,29 +17,55 @@
class BCIUI(QWidget):
+ """Base class for BciPy user interfaces.
+
+ This class provides common functionality for all BciPy UI components,
+ including layout management, styling, and utility methods.
+
+ Attributes:
+ contents (QVBoxLayout): Main vertical layout container.
+ center_content_vertically (bool): Whether to center content vertically.
+ """
+
contents: QVBoxLayout
center_content_vertically: bool = False
def __init__(self, title: str = "BCIUI", default_width: int = 500, default_height: int = 600) -> None:
+ """Initialize the BCIUI base class.
+
+ Args:
+ title (str): Window title. Defaults to "BCIUI".
+ default_width (int): Default window width. Defaults to 500.
+ default_height (int): Default window height. Defaults to 600.
+ """
super().__init__()
self.resize(default_width, default_height)
self.setWindowTitle(title)
self.contents = QVBoxLayout()
self.setLayout(self.contents)
- def app(self):
+ def app(self) -> None:
+ """Initialize the application UI.
+
+ This method should be overridden by subclasses to set up their specific UI elements.
+ """
...
def apply_stylesheet(self) -> None:
+ """Apply the BciPy stylesheet to the UI.
+
+ Loads and applies the CSS stylesheet from the BciPy configuration.
+ """
stylesheet_path = f'{BCIPY_ROOT}/gui/bcipy_stylesheet.css' # TODO: move to config
with open(stylesheet_path, "r") as f:
stylesheet = f.read()
self.setStyleSheet(stylesheet)
def display(self) -> None:
- # Push contents to the top of the window
- """
- Display the UI window and apply the stylesheet.
+ """Display the UI window and apply the stylesheet.
+
+ Initializes the UI, applies vertical centering if configured,
+ and shows the window.
"""
self.app()
if not self.center_content_vertically:
@@ -41,12 +74,13 @@ def display(self) -> None:
self.show()
def show_alert(self, alert_text: str) -> int:
- """
- Shows an alert dialog with the specified text.
+ """Show an alert dialog with the specified text.
+
+ Args:
+ alert_text (str): Text to display in the alert dialog.
- PARAMETERS
- ----------
- :param: alert_text: string text to display in the alert dialog.
+ Returns:
+ int: The result code from the message box.
"""
msg = QMessageBox()
msg.setText(alert_text)
@@ -55,6 +89,14 @@ def show_alert(self, alert_text: str) -> int:
@staticmethod
def centered(widget: QWidget) -> QHBoxLayout:
+ """Create a centered horizontal layout for a widget.
+
+ Args:
+ widget (QWidget): Widget to center.
+
+ Returns:
+ QHBoxLayout: Layout with the widget centered horizontally.
+ """
layout = QHBoxLayout()
layout.addStretch()
layout.addWidget(widget)
@@ -63,6 +105,14 @@ def centered(widget: QWidget) -> QHBoxLayout:
@staticmethod
def make_list_scroll_area(widget: QWidget) -> QScrollArea:
+ """Create a scrollable area for a widget.
+
+ Args:
+ widget (QWidget): Widget to make scrollable.
+
+ Returns:
+ QScrollArea: Scrollable area containing the widget.
+ """
scroll_area = QScrollArea()
scroll_area.setWidget(widget)
scroll_area.setWidgetResizable(True)
@@ -72,28 +122,25 @@ def make_list_scroll_area(widget: QWidget) -> QScrollArea:
def make_toggle(
on_button: QPushButton,
off_button: QPushButton,
- on_action: Optional[Callable] = lambda: None,
- off_action: Optional[Callable] = lambda: None,
+ on_action: Callable = lambda: None,
+ off_action: Callable = lambda: None,
) -> None:
- """
- Connects two buttons to toggle between eachother and call passed methods
-
- PARAMETERS
- ----------
- :param: on_button: QPushButton to toggle on
- :param: off_button: QPushButton to toggle off
- :param: on_action: function to call when on_button is clicked
- :param: off_action: function to call when off_button is clicked
+ """Connect two buttons to toggle between each other and call passed methods.
+ Args:
+ on_button (QPushButton): Button to toggle on.
+ off_button (QPushButton): Button to toggle off.
+        on_action (Callable): Function to call when on_button is clicked.
+        off_action (Callable): Function to call when off_button is clicked.
"""
off_button.hide()
- def toggle_off():
+ def toggle_off() -> None:
on_button.hide()
off_button.show()
off_action()
- def toggle_on():
+ def toggle_on() -> None:
on_button.show()
off_button.hide()
on_action()
@@ -102,127 +149,176 @@ def toggle_on():
off_button.clicked.connect(toggle_on)
def hide(self) -> None:
- """Close the UI window"""
+ """Hide the UI window."""
self.hide()
class SmallButton(QPushButton):
- """A small button with a fixed size"""
+ """A small button with a fixed size.
- def __init__(self, *args, **kwargs):
+ This button is styled with a specific CSS class and fixed size policy.
+ """
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ """Initialize the small button.
+
+ Args:
+ *args: Positional arguments passed to QPushButton.
+ **kwargs: Keyword arguments passed to QPushButton.
+ """
super().__init__(*args, **kwargs)
self.setProperty("class", "small-button")
self.setSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed)
class DynamicItem(QWidget):
- """A widget that can be dynamically added and removed from the ui"""
+ """A widget that can be dynamically added and removed from the UI.
+
+ This widget emits a signal when removed and can store arbitrary data.
+
+ Attributes:
+ on_remove (pyqtSignal): Signal emitted when the item is removed.
+ data (Dict[str, Any]): Dictionary for storing arbitrary data.
+ """
on_remove: pyqtSignal = pyqtSignal()
- data: dict = {}
+ data: Dict[str, Any] = {}
def remove(self) -> None:
- """Remove the widget from it's parent DynamicList, removing it from the UI and deleting it"""
+ """Remove the widget from its parent DynamicList.
+
+ Emits the on_remove signal and triggers widget deletion.
+ """
self.on_remove.emit()
class DynamicList(QWidget):
- """A list of QWidgets that can be dynamically updated"""
+ """A list of QWidgets that can be dynamically updated.
+
+ This widget manages a list of DynamicItems that can be added, removed,
+ and reordered.
+
+ Attributes:
+ widgets (List[QWidget]): List of managed widgets.
+ """
widgets: List[QWidget]
- def __init__(self, layout: Optional[QLayout] = None):
+ def __init__(self, layout: Optional[QLayout] = None) -> None:
+ """Initialize the dynamic list.
+
+ Args:
+ layout (Optional[QLayout]): Layout to use. Defaults to QVBoxLayout.
+ """
super().__init__()
if layout is None:
layout = QVBoxLayout()
self.setLayout(layout)
self.widgets = []
- def __len__(self):
+ def __len__(self) -> int:
+ """Get the number of widgets in the list.
+
+ Returns:
+ int: Number of widgets.
+ """
return len(self.widgets)
def add_item(self, item: DynamicItem) -> None:
- """
- Add a DynamicItem to the list.
+ """Add a DynamicItem to the list.
- PARAMETERS
- ----------
- :param: item: DynamicItem to add to the list.
+ Args:
+ item (DynamicItem): Item to add to the list.
"""
self.widgets.append(item)
item.on_remove.connect(lambda: self.remove_item(item))
- self.layout().addWidget(item)
+ layout = self.layout()
+ if layout:
+ layout.addWidget(item)
def move_item(self, item: DynamicItem, new_index: int) -> None:
- """
- Move a DynamicItem to a new index in the list.
+ """Move a DynamicItem to a new index in the list.
- PARAMETERS
- ----------
- :param: item: A reference to the DynamicItem in the list to be moved.
- :param: new_index: int new index to move the item to.
+ Args:
+ item (DynamicItem): Item to move.
+ new_index (int): New index for the item.
+
+ Raises:
+ IndexError: If new_index is out of range.
"""
if new_index < 0 or new_index >= len(self):
raise IndexError(f"Index out of range for length {len(self)}")
self.widgets.pop(self.widgets.index(item))
self.widgets.insert(new_index, item)
- self.layout().removeWidget(item)
- self.layout().insertWidget(new_index, item)
+ layout = self.layout()
+ if layout:
+ layout.removeWidget(item)
+ layout.insertWidget(new_index, item)
def index(self, item: DynamicItem) -> int:
- """
- Get the index of a DynamicItem in the list.
+ """Get the index of a DynamicItem in the list.
- PARAMETERS
- ----------
- :param: item: A reference to the DynamicItem in the list to get the index of.
+ Args:
+ item (DynamicItem): Item to find the index of.
- Returns
- -------
- The index of the item in the list.
+ Returns:
+ int: Index of the item in the list.
"""
return self.widgets.index(item)
def remove_item(self, item: DynamicItem) -> None:
- """
- Remove a DynamicItem from the list.
+ """Remove a DynamicItem from the list.
- PARAMETERS
- ----------
- :param: item: A reference to the DynamicItem to remove from the list
+ Args:
+ item (DynamicItem): Item to remove.
"""
self.widgets.remove(item)
- self.layout().removeWidget(item)
+ layout = self.layout()
+ if layout:
+ layout.removeWidget(item)
item.deleteLater()
def clear(self) -> None:
- """Remove all items from the list"""
+ """Remove all items from the list."""
for widget in self.widgets:
- self.layout().removeWidget(widget)
+ layout = self.layout()
+ if layout:
+ layout.removeWidget(widget)
widget.deleteLater()
self.widgets = []
- def list(self):
- return [widget.data for widget in self.widgets]
+ def list(self) -> List[Dict[str, Any]]:
+ """Get a list of data dictionaries from all items.
- def list_property(self, prop: str):
+ Returns:
+ List[Dict[str, Any]]: List of data dictionaries.
"""
- Get a list of values for a given property of each DynamicItem's data dictionary.
+ return [widget.data for widget in self.widgets]
+
+ def list_property(self, prop: str) -> List[Any]:
+ """Get a list of values for a given property of each DynamicItem's data dictionary.
- PARAMETERS
- ----------
- :param: prop: string property name to get the values of.
+ Args:
+ prop (str): Property name to get values for.
- Returns
- -------
- A list of values for the given property.
+ Returns:
+ List[Any]: List of values for the given property.
"""
return [widget.data[prop] for widget in self.widgets]
-def run_bciui(ui: Type[BCIUI], *args, **kwargs):
- # add app to kwargs
+def run_bciui(ui: Type[BCIUI], *args: Any, **kwargs: Any) -> int:
+ """Run a BCIUI instance.
+
+ Args:
+ ui (Type[BCIUI]): BCIUI class to instantiate.
+ *args: Positional arguments for the UI class.
+ **kwargs: Keyword arguments for the UI class.
+
+ Returns:
+ int: Application exit code.
+ """
app = QApplication(sys.argv).instance()
if not app:
app = QApplication(sys.argv)
diff --git a/bcipy/gui/experiments/ExperimentField.py b/bcipy/gui/experiments/ExperimentField.py
index fbf8772b5..7cde236b0 100644
--- a/bcipy/gui/experiments/ExperimentField.py
+++ b/bcipy/gui/experiments/ExperimentField.py
@@ -188,7 +188,8 @@ def build_save_data(self) -> None:
)
def write_save_data(self) -> None:
- save_experiment_field_data(self.save_data, self.save_path, self.file_name)
+ save_experiment_field_data(
+ self.save_data, self.save_path, self.file_name)
self.throw_alert_message(
title="Success",
message=(
@@ -219,7 +220,8 @@ def throw_alert_message(
if message_response is AlertMessageResponse.OTE:
msg.setStandardButtons(AlertResponse.OK.value)
elif message_response is AlertMessageResponse.OCE:
- msg.setStandardButtons(AlertResponse.OK.value | AlertResponse.CANCEL.value)
+ msg.setStandardButtons(
+ AlertResponse.OK.value | AlertResponse.CANCEL.value)
return msg.exec()
@@ -261,7 +263,8 @@ def initUI(self):
vbox = QVBoxLayout()
self.form_panel = QScrollArea()
- self.form_panel.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
+ self.form_panel.setVerticalScrollBarPolicy(
+ Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
self.form_panel.setHorizontalScrollBarPolicy(
Qt.ScrollBarPolicy.ScrollBarAlwaysOff
)
@@ -349,7 +352,8 @@ def start_app() -> None:
)
args = parser.parse_args()
- start_experiment_field_collection_gui(args.experiment, args.path, args.filename, args.validate)
+ start_experiment_field_collection_gui(
+ args.experiment, args.path, args.filename, args.validate)
sys.exit()
diff --git a/bcipy/gui/experiments/ExperimentRegistry.py b/bcipy/gui/experiments/ExperimentRegistry.py
index 90c931510..99c122f1f 100644
--- a/bcipy/gui/experiments/ExperimentRegistry.py
+++ b/bcipy/gui/experiments/ExperimentRegistry.py
@@ -201,7 +201,8 @@ def create_experiment(self):
for field in fields
]
task_names = self.protocol_contents.list_property("task_name")
- task_objects = [self.task_registry.get(task_name) for task_name in task_names]
+ task_objects = [self.task_registry.get(
+ task_name) for task_name in task_names]
protocol = serialize_protocol(task_objects)
existing_experiments[experiment_name] = {
@@ -264,7 +265,8 @@ def add_field():
def add_task():
self.protocol_contents.add_item(
- self.make_task_entry(self.experiment_protocol_input.currentText())
+ self.make_task_entry(
+ self.experiment_protocol_input.currentText())
)
self.experiment_protocol_input = QComboBox()
@@ -305,7 +307,8 @@ def add_task():
protocol_scroll_area = QScrollArea()
self.protocol_contents = DynamicList()
- protocol_scroll_area = BCIUI.make_list_scroll_area(self.protocol_contents)
+ protocol_scroll_area = BCIUI.make_list_scroll_area(
+ self.protocol_contents)
label = QLabel("Protocol")
label.setStyleSheet("color: black;")
scroll_area_layout.addWidget(protocol_scroll_area)
diff --git a/bcipy/gui/file_dialog.py b/bcipy/gui/file_dialog.py
index dd39055a4..6e71433db 100644
--- a/bcipy/gui/file_dialog.py
+++ b/bcipy/gui/file_dialog.py
@@ -1,44 +1,83 @@
+"""File dialog module.
+
+This module provides functionality for displaying file and directory selection
+dialogs in the BciPy GUI interface. It includes classes and functions for handling
+file and directory selection with customizable options and filters.
+"""
+
# pylint: disable=no-name-in-module,missing-docstring,too-few-public-methods
import sys
from pathlib import Path
-from typing import Union
+from typing import Optional, Tuple, Union
from PyQt6 import QtGui
+from PyQt6.QtCore import QRect
from PyQt6.QtWidgets import QApplication, QFileDialog, QWidget
from bcipy.exceptions import BciPyCoreException
from bcipy.preferences import preferences
-DEFAULT_FILE_TYPES = "All Files (*)"
+DEFAULT_FILE_TYPES: str = "All Files (*)"
class FileDialog(QWidget):
- """GUI window that prompts the user to select a file."""
+ """GUI window that prompts the user to select a file or directory.
- def __init__(self):
+ This class provides a file dialog interface for selecting files and directories
+ in the BciPy GUI. It supports both file and directory selection with customizable
+ options and filters.
+
+ Attributes:
+ title (str): Window title.
+ window_width (int): Window width in pixels.
+ window_height (int): Window height in pixels.
+ options (QFileDialog.Option): Dialog options.
+ """
+
+ def __init__(self) -> None:
+ """Initialize the file dialog window.
+
+ Sets up the window properties and centers it on the screen.
+ """
super().__init__()
self.title = 'File Dialog'
- self.width = 640
- self.height = 480
+ self.window_width = 640
+ self.window_height = 480
# Center on screen
- self.resize(self.width, self.height)
- frame_geom = self.frameGeometry()
- frame_geom.moveCenter(QtGui.QGuiApplication.primaryScreen().availableGeometry().center())
- self.move(frame_geom.topLeft())
+ self.resize(self.window_width, self.window_height)
+ self._center_window()
# The native dialog may prevent the selection from closing after a
# directory is selected.
self.options = QFileDialog.Option.DontUseNativeDialog
+ def _center_window(self) -> None:
+ """Center the window on the primary screen.
+
+ This method calculates the center position of the primary screen and
+ moves the window to that position.
+ """
+ frame_geom = self.frameGeometry()
+ screen = QtGui.QGuiApplication.primaryScreen()
+ if screen:
+ center_point = screen.availableGeometry().center()
+ frame_geom.moveCenter(center_point)
+ self.move(frame_geom.topLeft())
+
def ask_file(self,
file_types: str = DEFAULT_FILE_TYPES,
directory: str = "",
prompt: str = "Select File") -> str:
- """Opens a file dialog window.
- Returns
- -------
- path or None
+ """Open a file selection dialog window.
+
+ Args:
+ file_types (str, optional): File type filters. Defaults to DEFAULT_FILE_TYPES.
+ directory (str, optional): Initial directory. Defaults to "".
+ prompt (str, optional): Dialog prompt message. Defaults to "Select File".
+
+ Returns:
+ str: Selected file path or empty string if cancelled.
"""
filename, _ = QFileDialog.getOpenFileName(self,
caption=prompt,
@@ -48,11 +87,14 @@ def ask_file(self,
return filename
def ask_directory(self, directory: str = "", prompt: str = "Select Directory") -> str:
- """Opens a dialog window to select a directory.
+ """Open a directory selection dialog window.
+
+ Args:
+ directory (str, optional): Initial directory. Defaults to "".
+ prompt (str, optional): Dialog prompt message. Defaults to "Select Directory".
- Returns
- -------
- path or None
+ Returns:
+ str: Selected directory path or empty string if cancelled.
"""
return QFileDialog.getExistingDirectory(self,
prompt,
@@ -67,18 +109,23 @@ def ask_filename(
strict: bool = False) -> Union[str, BciPyCoreException]:
"""Prompt for a file using a GUI.
- Parameters
- ----------
- - file_types : optional file type filters; Examples: 'Text files (*.txt)'
- or 'Image files (*.jpg *.gif)' or '*.csv;;*.pkl'
- - directory : optional directory
- - prompt : optional prompt message to display to users
- - strict : optional flag to raise an exception if the user cancels the dialog. Default is False.
- If False, an empty string is returned.
-
- Returns
- -------
- path to file or raises an exception if the user cancels the dialog.
+ This function creates a file selection dialog and handles the user's selection.
+ It can optionally raise an exception if no file is selected.
+
+ Args:
+ file_types (str, optional): File type filters. Examples: 'Text files (*.txt)'
+ or 'Image files (*.jpg *.gif)' or '*.csv;;*.pkl'. Defaults to DEFAULT_FILE_TYPES.
+ directory (str, optional): Initial directory. Defaults to "".
+ prompt (str, optional): Dialog prompt message. Defaults to "Select File".
+ strict (bool, optional): If True, raises an exception when no file is selected.
+ If False, returns an empty string. Defaults to False.
+
+ Returns:
+ Union[str, BciPyCoreException]: Selected file path or empty string if cancelled.
+ Raises BciPyCoreException if strict=True and no file is selected.
+
+ Note:
+ Updates the last_directory preference if a file is selected.
"""
app = QApplication(sys.argv)
dialog = FileDialog()
@@ -89,10 +136,7 @@ def ask_filename(
path = Path(filename)
if filename and path.is_file():
preferences.last_directory = str(path.parent)
-
- # Alternatively, we could use `app.closeAllWindows()`
app.quit()
-
return filename
if strict:
@@ -104,18 +148,22 @@ def ask_filename(
def ask_directory(prompt: str = "Select Directory", strict: bool = False) -> Union[str, BciPyCoreException]:
"""Prompt for a directory using a GUI.
- Parameters
- ----------
- prompt : optional prompt message to display to users
- strict : optional flag to raise an exception if the user cancels the dialog. Default is False.
- If False, an empty string is returned.
+ This function creates a directory selection dialog and handles the user's selection.
+ It can optionally raise an exception if no directory is selected.
+
+ Args:
+ prompt (str, optional): Dialog prompt message. Defaults to "Select Directory".
+ strict (bool, optional): If True, raises an exception when no directory is selected.
+ If False, returns an empty string. Defaults to False.
+
+ Returns:
+ Union[str, BciPyCoreException]: Selected directory path or empty string if cancelled.
+ Raises BciPyCoreException if strict=True and no directory is selected.
- Returns
- -------
- path to directory or raises an exception if the user cancels the dialog.
+ Note:
+ Updates the last_directory preference if a directory is selected.
"""
app = QApplication(sys.argv)
-
dialog = FileDialog()
directory = ''
if preferences.last_directory:
@@ -123,10 +171,7 @@ def ask_directory(prompt: str = "Select Directory", strict: bool = False) -> Uni
name = dialog.ask_directory(directory, prompt=prompt)
if name and Path(name).is_dir():
preferences.last_directory = name
-
- # Alternatively, we could use `app.closeAllWindows()`
app.quit()
-
return name
if strict:
diff --git a/bcipy/gui/intertask_gui.py b/bcipy/gui/intertask_gui.py
index 463ca551a..845338b09 100644
--- a/bcipy/gui/intertask_gui.py
+++ b/bcipy/gui/intertask_gui.py
@@ -1,3 +1,9 @@
+"""Intertask GUI module.
+
+This module provides a graphical user interface for managing task transitions
+in BciPy experiments, showing progress and allowing users to control task flow.
+"""
+
import logging
from typing import Callable, List
@@ -11,6 +17,20 @@
class IntertaskGUI(BCIUI):
+ """GUI for managing transitions between tasks in an experiment.
+
+ This class provides a progress interface that shows the current task progress
+ and allows users to proceed to the next task or stop the experiment.
+
+ Attributes:
+ action_name (str): Name of the action type.
+ tasks (List[str]): List of task names in the experiment.
+ current_task_index (int): Index of the current task.
+ next_task_name (str): Name of the next task to be executed.
+ total_tasks (int): Total number of tasks in the experiment.
+ task_progress (int): Current progress through the tasks.
+ callback (Callable): Function to call when stopping tasks.
+ """
action_name = "IntertaskAction"
@@ -18,8 +38,15 @@ def __init__(
self,
next_task_index: int,
tasks: List[str],
- exit_callback: Callable,
- ):
+ exit_callback: Callable[[], None],
+ ) -> None:
+ """Initialize the intertask GUI.
+
+ Args:
+ next_task_index (int): Index of the next task to be executed.
+ tasks (List[str]): List of task names in the experiment.
+ exit_callback (Callable[[], None]): Function to call when stopping tasks.
+ """
self.tasks = tasks
self.current_task_index = next_task_index
self.next_task_name = tasks[self.current_task_index]
@@ -29,7 +56,11 @@ def __init__(
super().__init__("Progress", 800, 150)
self.setProperty("class", "inter-task")
- def app(self):
+ def app(self) -> None:
+ """Initialize and configure the GUI application.
+
+ Sets up the progress display, next task information, and control buttons.
+ """
self.contents.addLayout(BCIUI.centered(QLabel("Experiment Progress")))
progress_container = QHBoxLayout()
@@ -37,7 +68,8 @@ def app(self):
QLabel(f"({self.task_progress}/{self.total_tasks})")
)
self.progress = QProgressBar()
- self.progress.setValue(int(self.task_progress / self.total_tasks * 100))
+ self.progress.setValue(
+ int(self.task_progress / self.total_tasks * 100))
self.progress.setTextVisible(False)
progress_container.addWidget(self.progress)
self.contents.addLayout(progress_container)
@@ -61,22 +93,36 @@ def app(self):
self.next_button.clicked.connect(self.next)
self.stop_button.clicked.connect(self.stop_tasks)
- def stop_tasks(self):
+ def stop_tasks(self) -> None:
+ """Stop the current task execution.
+
+ Calls the exit callback and quits the application.
+ """
# This should exit Task executions
- logger.info(f"Stopping Tasks... user requested. Using callback: {self.callback}")
+ logger.info(
+ f"Stopping Tasks... user requested. Using callback: {self.callback}")
self.callback()
self.quit()
logger.info("Tasks Stopped")
- def next(self):
+ def next(self) -> None:
+ """Proceed to the next task.
+
+ Logs the next task request and quits the application.
+ """
logger.info(f"Next Task=[{self.next_task_name}] requested")
self.quit()
- def quit(self):
- QApplication.instance().quit()
+ def quit(self) -> None:
+ """Quit the application."""
+ instance = QApplication.instance()
+ if instance:
+ instance.quit()
if __name__ == "__main__":
- tasks = ["RSVP Calibration", "IntertaskAction", "Matrix Calibration", "IntertaskAction"]
+ tasks = ["RSVP Calibration", "IntertaskAction",
+ "Matrix Calibration", "IntertaskAction"]
- run_bciui(IntertaskGUI, tasks=tasks, next_task_index=1, exit_callback=lambda: print("Stopping orchestrator"))
+ run_bciui(IntertaskGUI, tasks=tasks, next_task_index=1,
+ exit_callback=lambda: print("Stopping orchestrator"))
diff --git a/bcipy/gui/main.py b/bcipy/gui/main.py
index 31e771475..3cdcfe7b1 100644
--- a/bcipy/gui/main.py
+++ b/bcipy/gui/main.py
@@ -1,3 +1,9 @@
+"""Main GUI module for BciPy.
+
+This module provides the core GUI components and utilities for the BciPy interface,
+including form inputs, alerts, and window management.
+"""
+
# pylint: disable=E0611
import logging
import os
@@ -5,9 +11,10 @@
import sys
from decimal import Decimal
from enum import Enum
-from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union
+from typing import (Any, Callable, List, NamedTuple, Optional, Tuple, Union,
+ cast)
-from PyQt6.QtCore import Qt, QTimer, pyqtSlot
+from PyQt6.QtCore import QObject, Qt, QTimer, pyqtSlot
from PyQt6.QtGui import QFont, QPixmap, QShowEvent, QWheelEvent
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox,
QDoubleSpinBox, QFileDialog, QHBoxLayout, QLabel,
@@ -22,7 +29,7 @@ def font(size: int = 16, font_family: str = 'Helvetica') -> QFont:
return QFont(font_family, size, weight=0)
-def invalid_length(min=1, max=25) -> bool:
+def invalid_length(min=1, max=25) -> Callable[[str], bool]:
"""Invalid Length.
Returns a function, which when passed a string will assert whether a string meets min/max conditions.
@@ -35,7 +42,7 @@ def contains_whitespaces(string: str) -> bool:
Checks for the presence of whitespace in a string.
"""
- return re.match(r'^(?=.*[\s])', string)
+ return bool(re.match(r'^(?=.*[\s])', string))
def contains_special_characters(string: str,
@@ -48,7 +55,7 @@ def contains_special_characters(string: str,
return bool(disallowed_chars.search(string))
-def static_text_control(parent,
+def static_text_control(parent: Optional[QWidget],
label: str,
color: str = 'black',
size: int = 16,
@@ -57,7 +64,7 @@ def static_text_control(parent,
creating labels and help components."""
static_text = QLabel(parent)
static_text.setWordWrap(True)
- static_text.setText(label)
+ static_text.setText(str(label))
static_text.setStyleSheet(f'color: {color};')
static_text.setFont(font(size, font_family))
return static_text
@@ -90,12 +97,11 @@ class PushButton(QPushButton):
Custom Button to store unique identifiers which are required for coordinating
events across multiple buttons."""
- id = None
+ id: Optional[int] = None
- def get_id(self):
+ def get_id(self) -> int:
if not self.id:
raise Exception('No ID set on PushButton')
-
return self.id
@@ -116,7 +122,7 @@ def __init__(self, options: List[str], selected_value: str, **kwargs):
self.setText(selected_value)
self.setEditable(True)
- def setText(self, value: str):
+ def setText(self, value: str) -> None:
"""Sets the current index to the given value. If the value is not in the list of
options it will be added."""
if value not in self.options:
@@ -125,7 +131,7 @@ def setText(self, value: str):
self.addItems(self.options)
self.setCurrentIndex(self.options.index(value))
- def text(self):
+ def text(self) -> str:
"""Gets the currentText."""
return self.currentText()
@@ -136,21 +142,20 @@ class MessageBox(QMessageBox):
A custom QMessageBox implementation to provide timeout functionality to QMessageBoxes.
"""
- def __init__(self, *args, **kwargs):
- QMessageBox.__init__(self, *args, **kwargs)
+ def __init__(self, *args: Any, **kwargs: Any):
+ super().__init__(*args, **kwargs)
+ self.timeout = 0.0
+ self.current = 0.0
- self.timeout = 0
- self.current = 0
-
- def showEvent(self, event: QShowEvent) -> None:
+ def showEvent(self, event: Optional[QShowEvent]) -> None:
"""showEvent.
If a timeout greater than zero is defined, set a QTimer to call self.close after the defined timeout.
"""
if self.timeout > 0:
# timeout is in seconds (multiply by 1000 to get ms)
- QTimer().singleShot(self.timeout * 1000, self.close)
- super(MessageBox, self).showEvent(event)
+ QTimer().singleShot(int(self.timeout * 1000.0), self.close)
+ super().showEvent(event)
def setTimeout(self, timeout: float) -> None:
"""setTimeout.
@@ -228,7 +233,7 @@ def __init__(self,
help_size: int = 12,
help_color: str = 'darkgray',
should_display: bool = True):
- super(FormInput, self).__init__()
+ super().__init__()
self.label = label
self.help_tip = help_tip
@@ -245,9 +250,9 @@ def __init__(self,
if not should_display:
self.hide()
- def eventFilter(self, source, event):
+ def eventFilter(self, source: Optional[QObject], event: Any) -> bool:
"""Event filter that suppresses the scroll wheel event."""
- if (event.type() == QWheelEvent and source is self.control):
+ if (isinstance(event, QWheelEvent) and source is self.control):
return True
return False
@@ -255,7 +260,7 @@ def init_label(self) -> QWidget:
"""Initialize the label widget."""
return static_text_control(None, label=self.label, size=16)
- def init_help(self, font_size: int, color: str) -> QWidget:
+ def init_help(self, font_size: int, color: str) -> Optional[QWidget]:
"""Initialize the help text widget."""
if self.help_tip and self.label != self.help_tip:
return static_text_control(None,
@@ -264,14 +269,14 @@ def init_help(self, font_size: int, color: str) -> QWidget:
color=color)
return None
- def init_control(self, value) -> QWidget:
+ def init_control(self, value: Any) -> QWidget:
"""Initialize the form control widget.
Parameter:
---------
value - initial value
"""
# Default is a text input
- return QLineEdit(value)
+ return QLineEdit(str(value))
def init_editable(self, value: Optional[bool]) -> Optional[QWidget]:
"Override. Another checkbox is needed for editable"
@@ -282,7 +287,7 @@ def init_editable(self, value: Optional[bool]) -> Optional[QWidget]:
editable_checkbox.setFont(font(size=12))
return editable_checkbox
- def init_layout(self):
+ def init_layout(self) -> None:
"""Initialize the layout by adding the label, help, and control widgets."""
self.vbox = QVBoxLayout()
if self.label_widget:
@@ -296,7 +301,7 @@ def init_layout(self):
self.vbox.addWidget(self.separator())
self.setLayout(self.vbox)
- def separator(self):
+ def separator(self) -> QWidget:
"""Creates a separator line."""
line = QLabel()
line.setFixedHeight(1)
@@ -307,16 +312,20 @@ def value(self) -> str:
"""Returns the value associated with the form input."""
if self.control:
return self.control.text()
- return None
+ return ""
def is_editable(self) -> bool:
"""Returns whether the input is editable."""
- return self.editable_widget.isChecked()
+ if self.editable_widget:
+ return self.editable_widget.isChecked()
+ return False
@property
def editable(self) -> bool:
"""Returns whether the input is editable."""
- return True if self.editable_widget.isChecked() else False
+ if self.editable_widget:
+ return self.editable_widget.isChecked()
+ return False
def cast_value(self) -> Any:
"""Returns the value associated with the form input, cast to the correct type.
@@ -332,20 +341,20 @@ def matches(self, term: str) -> bool:
self.help_tip and
text in self.help_tip.lower()) or text in self.value().lower()
- def show(self):
+ def show(self) -> None:
"""Show this widget, and all child widgets."""
if self.should_display:
for widget in self.widgets():
if widget:
widget.setVisible(True)
- def hide(self):
+ def hide(self) -> None:
"""Hide this widget, and all child widgets."""
for widget in self.widgets():
if widget:
widget.setVisible(False)
- def widgets(self) -> List[QWidget]:
+ def widgets(self) -> List[Optional[QWidget]]:
"""Returns a list of self and child widgets. List may contain None values."""
return [self.label_widget, self.help_tip_widget, self.control, self]
@@ -360,20 +369,21 @@ class IntegerInput(FormInput):
value - initial value.
"""
- def __init__(self, **kwargs):
- super(IntegerInput, self).__init__(**kwargs)
+ def __init__(self, **kwargs: Any):
+ super().__init__(**kwargs)
- def init_control(self, value):
+ def init_control(self, value: Any) -> QWidget:
"""Override FormInput to create a spinbox."""
spin_box = QSpinBox()
spin_box.setMinimum(-100000)
spin_box.setMaximum(100000)
- spin_box.wheelEvent = lambda event: None # disable scroll wheel
+ # Disable scroll wheel by overriding the event handler
+ spin_box.wheelEvent = lambda e: None # type: ignore
if value:
spin_box.setValue(int(value))
return spin_box
- def cast_value(self) -> str:
+ def cast_value(self) -> Optional[int]:
"""Override FormInput to return an integer value."""
if self.control:
return int(self.control.text())
@@ -390,13 +400,14 @@ class FloatInputProperties(NamedTuple):
step: float = 0.1
-def float_input_properties(value: str) -> FloatInputProperties:
+def float_input_properties(value: float) -> FloatInputProperties:
"""Given a string representation of a float value, determine suitable
properties for the float component used to input or update this value.
"""
# Determine from the component if there is a reasonable min or max constraint
dec = Decimal(str(value))
_sign, _digits, exponent = dec.as_tuple()
+ exponent = int(exponent)
if exponent > 0:
return FloatInputProperties()
return FloatInputProperties(decimals=abs(exponent), step=10**exponent)
@@ -412,29 +423,30 @@ class FloatInput(FormInput):
value - initial value.
"""
- def __init__(self, **kwargs):
- super(FloatInput, self).__init__(**kwargs)
+ def __init__(self, **kwargs: Any):
+ super().__init__(**kwargs)
- def init_control(self, value):
+ def init_control(self, value: Any) -> QWidget:
"""Override FormInput to create a spinbox."""
spin_box = QDoubleSpinBox()
# Make a reasonable guess about precision and step size based on the initial value.
- props = float_input_properties(value)
+ props = float_input_properties(float(value))
spin_box.setMinimum(props.min)
spin_box.setMaximum(props.max)
spin_box.setDecimals(props.decimals)
spin_box.setSingleStep(props.step)
spin_box.setValue(float(value))
- spin_box.wheelEvent = lambda event: None # disable scroll wheel
+ # Disable scroll wheel by overriding the event handler
+ spin_box.wheelEvent = lambda e: None # type: ignore
return spin_box
def cast_value(self) -> float:
"""Override FormInput to return as a float."""
if self.control:
return float(self.control.text())
- return None
+ return 0.0
class BoolInput(FormInput):
@@ -447,10 +459,10 @@ class BoolInput(FormInput):
value - initial value.
"""
- def __init__(self, **kwargs):
- super(BoolInput, self).__init__(**kwargs)
+ def __init__(self, **kwargs: Any):
+ super().__init__(**kwargs)
- def init_control(self, value):
+ def init_control(self, value: Any) -> QWidget:
"""Override to create a checkbox."""
ctl = QCheckBox(f'Enable {self.label}')
ctl.setChecked(value == 'true')
@@ -468,7 +480,7 @@ class RangeInput(FormInput):
from the starting value and list of recommended if provided.
"""
- def init_control(self, value) -> QWidget:
+ def init_control(self, value: Any) -> QWidget:
"""Initialize the form control widget.
Parameter:
@@ -491,14 +503,16 @@ class SelectionInput(FormInput):
help_font_size - font size for the help text.
help_color - color of the help text."""
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any):
assert isinstance(kwargs['options'],
list), f"options are required for {kwargs['label']}"
- super(SelectionInput, self).__init__(**kwargs)
+ super().__init__(**kwargs)
- def init_control(self, value) -> QWidget:
+ def init_control(self, value: Any) -> QWidget:
"""Override to create a Combobox."""
- return ComboBox(self.options, value)
+ if not self.options:
+ raise ValueError(f"options are required for {self.label}")
+ return ComboBox(self.options, str(value))
class TextInput(FormInput):
@@ -512,8 +526,8 @@ class TextInput(FormInput):
help_font_size - font size for the help text.
help_color - color of the help text."""
- def __init__(self, **kwargs):
- super(TextInput, self).__init__(**kwargs)
+ def __init__(self, **kwargs: Any):
+ super().__init__(**kwargs)
class FileInput(FormInput):
@@ -529,15 +543,15 @@ class FileInput(FormInput):
help_color - color of the help text.
"""
- def __init__(self, **kwargs):
- super(FileInput, self).__init__(**kwargs)
+ def __init__(self, **kwargs: Any):
+ super().__init__(**kwargs)
- def init_control(self, value) -> QWidget:
+ def init_control(self, value: Any) -> QWidget:
"""Override to create either a selection list or text field depending
on whether there are recommended values."""
if isinstance(self.options, list):
- return ComboBox(self.options, value)
- return QLineEdit(value)
+ return ComboBox(self.options, str(value))
+ return QLineEdit(str(value))
def init_button(self) -> QWidget:
"""Creates a Button to initiate the file/directory dialog."""
@@ -577,7 +591,7 @@ def init_layout(self) -> None:
self.vbox.addWidget(self.separator())
self.setLayout(self.vbox)
- def widgets(self) -> List[QWidget]:
+ def widgets(self) -> List[Optional[QWidget]]:
"""Override to include button."""
return super().widgets() + [self.button]
@@ -595,7 +609,7 @@ class DirectoryInput(FileInput):
help_color - color of the help text.
"""
- def prompt_path(self):
+ def prompt_path(self) -> str:
"""Override to prompt for a directory."""
return QFileDialog.getExistingDirectory(caption='Select a path')
@@ -614,7 +628,7 @@ def __init__(self,
label_low: str = "Low:",
label_high="High:",
size=14):
- super(RangeWidget, self).__init__()
+ super().__init__()
self.low, self.high = parse_range(value)
self.input_min = None
@@ -660,7 +674,8 @@ def float_input(self, value: float) -> QWidget:
spin_box.setDecimals(props.decimals)
spin_box.setSingleStep(props.step)
spin_box.setValue(value)
- spin_box.wheelEvent = lambda event: None # disable scroll wheel
+ # Disable scroll wheel by overriding the event handler
+ spin_box.wheelEvent = lambda e: None # type: ignore
return spin_box
def int_input(self, value: int) -> QWidget:
@@ -671,12 +686,13 @@ def int_input(self, value: int) -> QWidget:
-100000 if self.input_min is None else self.input_min)
spin_box.setMaximum(
100000 if self.input_max is None else self.input_max)
- spin_box.wheelEvent = lambda event: None # disable scroll wheel
+ # Disable scroll wheel by overriding the event handler
+ spin_box.wheelEvent = lambda e: None # type: ignore
if value:
spin_box.setValue(value)
return spin_box
- def text(self):
+ def text(self) -> str:
"""Text value"""
low = self.low_input.text() or self.low
high = self.high_input.text() or self.high
@@ -700,8 +716,8 @@ class SearchInput(QWidget):
contents of the text box.
"""
- def __init__(self, on_search, font_size: int = 12):
- super(SearchInput, self).__init__()
+ def __init__(self, on_search: Callable[[str], None], font_size: int = 12):
+ super().__init__()
self.on_search = on_search
@@ -737,21 +753,21 @@ class BCIGui(QWidget):
def __init__(self, title: str, width: int, height: int,
background_color: str):
- super(BCIGui, self).__init__()
+ super().__init__()
logging.basicConfig(level=logging.INFO,
format='%(name)s - %(levelname)s - %(message)s')
self.logger = logging
- self.buttons = []
- self.input_text = []
- self.static_text = []
- self.images = []
- self.comboboxes = []
- self.widgets = []
+ self.buttons: List[PushButton] = []
+ self.input_text: List[QLineEdit] = []
+ self.static_text: List[QLabel] = []
+ self.images: List[QLabel] = []
+ self.comboboxes: List[QComboBox] = []
+ self.widgets: List[QWidget] = []
# set main window properties
self.background_color = background_color
- self.window = QWidget()
+ self._window = QWidget()
self.vbox = QVBoxLayout()
self.setStyleSheet(f'background-color: {self.background_color};')
@@ -760,12 +776,12 @@ def __init__(self, title: str, width: int, height: int,
self.title = title
# determines height/width of window
- self.width = width
- self.height = height
+ self._width = width
+ self._height = height
self.setWindowTitle(self.title)
- self.setFixedWidth(self.width)
- self.setFixedHeight(self.height)
+ self.setFixedWidth(self._width)
+ self.setFixedHeight(self._height)
self.setLayout(self.vbox)
self.create_main_window()
@@ -776,9 +792,9 @@ def create_main_window(self) -> None:
Construct the main window for display of assets.
"""
self.window_layout = QHBoxLayout()
- self.window.setStyleSheet(
+ self._window.setStyleSheet(
f'background-color: {self.background_color};')
- self.window_layout.addWidget(self.window)
+ self.window_layout.addWidget(self._window)
self.vbox.addLayout(self.window_layout)
def add_widget(self, widget: QWidget) -> None:
@@ -848,9 +864,11 @@ def default_button_clicked(self) -> None:
The default action for buttons if none are registed.
"""
- sender = self.sender()
- self.logger.debug(sender.text() + ' was pressed')
- self.logger.debug(sender.get_id())
+ sender = cast(QObject, self.sender())
+ if sender:
+ self.logger.debug(sender.text() + ' was pressed')
+ if isinstance(sender, PushButton):
+ self.logger.debug(sender.get_id())
def add_button(self,
message: str,
@@ -862,7 +880,7 @@ def add_button(self,
font_family: str = 'Times',
action: Optional[Callable] = None) -> PushButton:
"""Add Button."""
- btn = PushButton(message, self.window)
+ btn = PushButton(message, self._window)
btn.id = id
btn.move(position[0], position[1])
btn.resize(size[0], size[1])
@@ -888,7 +906,7 @@ def add_combobox(self,
editable=False) -> QComboBox:
"""Add combobox."""
- combobox = QComboBox(self.window)
+ combobox = QComboBox(self._window)
combobox.move(position[0], position[1])
combobox.resize(size[0], size[1])
@@ -910,7 +928,7 @@ def add_combobox(self,
def add_image(self, path: str, position: list, size: int) -> QLabel:
"""Add Image."""
if os.path.isfile(path):
- labelImage = QLabel(self.window)
+ labelImage = QLabel(self._window)
pixmap = QPixmap(path)
# ensures the new label size will scale the image itself
labelImage.setScaledContents(True)
@@ -943,7 +961,7 @@ def add_static_textbox(self,
wrap_text=False) -> QLabel:
"""Add Static Text."""
- static_text = QLabel(self.window)
+ static_text = QLabel(self._window)
static_text.setText(text)
if wrap_text:
static_text.setWordWrap(True)
@@ -961,7 +979,7 @@ def add_static_textbox(self,
def add_text_input(self, position: list, size: list) -> QLineEdit:
"""Add Text Input."""
- textbox = QLineEdit(self.window)
+ textbox = QLineEdit(self._window)
textbox.move(position[0], position[1])
textbox.resize(size[0], size[1])
@@ -974,7 +992,7 @@ def throw_alert_message(
message: str,
message_type: AlertMessageType = AlertMessageType.INFO,
message_response: AlertMessageResponse = AlertMessageResponse.OTE,
- message_timeout: float = 0) -> MessageBox:
+ message_timeout: float = 0) -> int:
"""Throw Alert Message."""
msg = alert_message(message,
title=title,
@@ -988,7 +1006,7 @@ def get_filename_dialog(self,
file_type: str = 'All Files (*)',
location: str = "") -> str:
"""Get Filename Dialog."""
- file_name, _ = QFileDialog.getOpenFileName(self.window, message,
+ file_name, _ = QFileDialog.getOpenFileName(self._window, message,
location, file_type)
return file_name
@@ -1007,8 +1025,8 @@ def __init__(self,
title: Optional[str] = None):
super().__init__()
- self.height = height
- self.width = width
+ self._height = height
+ self._width = width
self.background_color = background_color
self.setStyleSheet(f'background-color: {self.background_color}')
@@ -1017,11 +1035,13 @@ def __init__(self,
# create the scrollable are
self.frame = QScrollArea()
- self.frame.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
- self.frame.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
+ self.frame.setVerticalScrollBarPolicy(
+ Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
+ self.frame.setHorizontalScrollBarPolicy(
+ Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
self.frame.setWidgetResizable(True)
- self.frame.setFixedWidth(self.width)
- self.setFixedHeight(self.height)
+ self.frame.setFixedWidth(self._width)
+ self.setFixedHeight(self._height)
# if a widget is provided, add it to the scrollable frame
if widget:
@@ -1076,11 +1096,11 @@ class LineItems(QWidget):
def __init__(self, items: List[dict], width: str):
super().__init__()
- self.width = width
+ self._width = int(width)
self.items = items
self.vbox = QVBoxLayout()
- self.setFixedWidth(self.width)
+ self.setFixedWidth(self._width)
# construct the line items as widgets added to the layout
self.construct_line_items()
@@ -1119,13 +1139,12 @@ def construct_line_items(self) -> None:
self.vbox.addLayout(layout)
-def app(args) -> QApplication:
+def app(args: List[str]) -> QApplication:
"""Main app registry.
Passes args from main and initializes the app
"""
-
- bci_app = QApplication(args).instance()
+ bci_app = QApplication.instance()
if not bci_app:
return QApplication(args)
return bci_app
diff --git a/bcipy/gui/parameters/params_form.py b/bcipy/gui/parameters/params_form.py
index 3be18d35e..2b227a5c4 100644
--- a/bcipy/gui/parameters/params_form.py
+++ b/bcipy/gui/parameters/params_form.py
@@ -250,8 +250,10 @@ def __init__(self, json_file: str):
self.layout = QVBoxLayout()
self.changes_area = QScrollArea()
- self.changes_area.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
- self.changes_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
+ self.changes_area.setVerticalScrollBarPolicy(
+ Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
+ self.changes_area.setHorizontalScrollBarPolicy(
+ Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
self.changes_area.setWidgetResizable(True)
self.changes_area.setWidget(self.change_items)
self.changes_area.setVisible(not self.collapsed)
@@ -334,8 +336,10 @@ def initUI(self):
vbox.addLayout(self.changes_panel)
self.form_panel = QScrollArea()
- self.form_panel.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
- self.form_panel.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
+ self.form_panel.setVerticalScrollBarPolicy(
+ Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
+ self.form_panel.setHorizontalScrollBarPolicy(
+ Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
self.form_panel.setWidgetResizable(True)
self.form_panel.setFixedWidth(self.size[0])
self.form_panel.setWidget(self.form)
diff --git a/bcipy/gui/viewer/data_viewer.py b/bcipy/gui/viewer/data_viewer.py
index 83d0e0a21..1c96116d7 100644
--- a/bcipy/gui/viewer/data_viewer.py
+++ b/bcipy/gui/viewer/data_viewer.py
@@ -7,15 +7,13 @@
import matplotlib
import matplotlib.ticker as ticker
import numpy as np
+from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
+from matplotlib.figure import Figure
from PyQt6.QtCore import Qt, QTimer # pylint: disable=no-name-in-module
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QHBoxLayout,
QLabel, QPushButton, QSpinBox, QVBoxLayout,
QWidget)
-matplotlib.use('Qt5Agg')
-from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
-from matplotlib.figure import Figure
-
from bcipy.acquisition.devices import DeviceSpec
from bcipy.acquisition.util import StoppableProcess
from bcipy.core.parameters import DEFAULT_PARAMETERS_PATH, Parameters
@@ -27,6 +25,8 @@
from bcipy.gui.viewer.ring_buffer import RingBuffer
from bcipy.signal.process.transform import Downsample, get_default_transform
+matplotlib.use('Qt5Agg')
+
def filters(
sample_rate_hz: float, parameters: Parameters
@@ -92,14 +92,14 @@ class FixedHeightHBox(QWidget):
def __init__(self, height: int = 30):
super().__init__()
- self.layout = QHBoxLayout()
+ self.layout = QHBoxLayout() # type: ignore
self.layout.setContentsMargins(0, 0, 0, 0)
self.setLayout(self.layout)
self.setFixedHeight(height)
def addWidget(self, widget: QWidget):
"""Add the given widget to the layout"""
- self.layout.addWidget(widget)
+ self.layout.addWidget(widget) # type: ignore
class ChannelControls(QWidget):
@@ -630,7 +630,7 @@ def file_data(path: str
"""
# read metadata
name, freq, channels = settings(path)
- queue = Queue()
+ queue: Queue = Queue()
streamer = FileStreamer(path, queue)
data_source = QueueDataSource(queue)
device_spec = DeviceSpec(name=name, channels=channels, sample_rate=freq)
@@ -639,7 +639,7 @@ def file_data(path: str
return (data_source, device_spec, streamer)
-def main(data_file: str,
+def main(data_file: Optional[str],
seconds: int,
refresh: int,
yscale: int,
@@ -676,7 +676,7 @@ def main(data_file: str,
display_monitor = non_primary_screens[0]
monitor = display_monitor.geometry()
else:
- monitor = app.primaryScreen().geometry()
+ monitor = app.primaryScreen().geometry() # type: ignore
# increase height to 90% of monitor height and preserve aspect ratio.
new_height = int(monitor.height() * 0.9)
diff --git a/bcipy/gui/viewer/ring_buffer.py b/bcipy/gui/viewer/ring_buffer.py
index 3bb5433d4..bac5cea37 100644
--- a/bcipy/gui/viewer/ring_buffer.py
+++ b/bcipy/gui/viewer/ring_buffer.py
@@ -1,35 +1,71 @@
-"""Defines a RingBuffer with a fixed size; when full additional elements
-overwrite the oldest items in the data structure.
+"""Ring buffer implementation for efficient data storage and retrieval.
+
+This module defines a RingBuffer class that implements a fixed-size circular buffer.
+When the buffer is full, new elements overwrite the oldest items in the data structure.
Adapted from Python Cookbook by David Ascher, Alex Martelli
https://www.oreilly.com/library/view/python-cookbook/0596001673/ch05s19.html
"""
+from typing import Any, List, Optional, TypeVar
+
+T = TypeVar('T')
+
class RingBuffer:
- """Data structure with a fixed size; when full additional elements
- overwrite the oldest items in the data structure.
-
- Parameters
- ----------
- size_max - max size of the buffer
- pre_allocated - whether to create all values on initialization
- empty_value - if pre_allocated, empty_value is used to set the
- values with no data.
+ """A fixed-size circular buffer implementation.
+
+ This class implements a ring buffer (circular buffer) with a fixed maximum size.
+ When the buffer is full, new elements overwrite the oldest items in the data structure.
+
+ Attributes:
+ empty_value (Any): Value used to represent empty slots in the buffer.
+ max (int): Maximum size of the buffer.
+ data (List[Any]): Internal storage for buffer elements.
+ cur (int): Current position in the buffer.
+ full (bool): Whether the buffer is full.
+ pre_allocated (bool): Whether the buffer was pre-allocated with empty values.
+
+ Args:
+ size_max (int): Maximum size of the buffer.
+ pre_allocated (bool, optional): Whether to create all values on initialization.
+ Defaults to False.
+ empty_value (Any, optional): If pre_allocated, this value is used to set the
+ values with no data. Defaults to None.
+
+ Raises:
+ AssertionError: If size_max is not greater than 0.
"""
def __init__(self, size_max: int, pre_allocated: bool = False,
- empty_value=None):
+ empty_value: Optional[Any] = None) -> None:
+ """Initialize the ring buffer.
+
+ Args:
+ size_max (int): Maximum size of the buffer.
+ pre_allocated (bool, optional): Whether to create all values on initialization.
+ Defaults to False.
+ empty_value (Any, optional): If pre_allocated, this value is used to set the
+ values with no data. Defaults to None.
+
+ Raises:
+ AssertionError: If size_max is not greater than 0.
+ """
assert size_max > 0
self.empty_value = empty_value
self.max = size_max
- self.data = [empty_value] * size_max if pre_allocated else []
+ self.data: List[Any] = [empty_value] * \
+ size_max if pre_allocated else []
self.cur = 0
self.full = False
self.pre_allocated = pre_allocated
- def append(self, item):
- """Add an element to the buffer, overwriting if full."""
+ def append(self, item: Any) -> None:
+ """Add an element to the buffer, overwriting if full.
+
+ Args:
+ item (Any): The item to add to the buffer.
+ """
if self.full or self.pre_allocated:
# overwrite
self.data[self.cur] = item
@@ -39,11 +75,20 @@ def append(self, item):
self.full = self.cur == self.max - 1
self.cur = (self.cur + 1) % self.max
- def get(self):
- """Return a list of elements from the oldest to the newest."""
+ def get(self) -> List[Any]:
+ """Return a list of elements from the oldest to the newest.
+
+ Returns:
+ List[Any]: List of elements in chronological order.
+ """
if self.full:
return self.data[self.cur:] + self.data[:self.cur]
return self.data
- def is_empty(self):
+ def is_empty(self) -> bool:
+ """Check if the buffer is empty.
+
+ Returns:
+ bool: True if the buffer is empty or contains only empty values.
+ """
return len(self.data) == 0 or self.data[0] == self.empty_value
diff --git a/bcipy/helpers/acquisition.py b/bcipy/helpers/acquisition.py
index 774bee5d7..5851d3faf 100644
--- a/bcipy/helpers/acquisition.py
+++ b/bcipy/helpers/acquisition.py
@@ -66,7 +66,8 @@ def init_acquisition(
device_spec = init_device(content_type, device_name, status)
raw_data_name = raw_data_filename(device_spec)
- client = init_lsl_client(parameters, device_spec, save_folder, raw_data_name)
+ client = init_lsl_client(
+ parameters, device_spec, save_folder, raw_data_name)
manager.add_client(client)
manager.start_acquisition()
@@ -115,7 +116,8 @@ def init_device(content_type: str,
spec = preconfigured_device(device_name, strict=True)
else:
discovered_spec = discover_device_spec(content_type)
- configured_spec = preconfigured_device(discovered_spec.name, strict=False)
+ configured_spec = preconfigured_device(
+ discovered_spec.name, strict=False)
spec = configured_spec or discovered_spec
if status_override is not None:
spec.status = status_override
@@ -335,7 +337,8 @@ def is_stream_type_active(stream_type: StreamType) -> bool:
"""Check if the provided stream type is active.
A stream type's status, if provided, will be used to make the determinition.
- If missing, the status of a matching pre-configured device will be used."""
+ If missing, the status of a matching pre-configured device will be used.
+ """
content_type, device_name, status = stream_type
if status:
return status == DeviceStatus.ACTIVE
diff --git a/bcipy/helpers/copy_phrase_wrapper.py b/bcipy/helpers/copy_phrase_wrapper.py
index c41e4dafa..d30bf4ae6 100644
--- a/bcipy/helpers/copy_phrase_wrapper.py
+++ b/bcipy/helpers/copy_phrase_wrapper.py
@@ -216,7 +216,8 @@ def initialize_series(self) -> Tuple[bool, InquirySchedule]:
prob_dist = self.conjugator.update_and_fuse(
{EvidenceType.LM: np.array(prior)})
except Exception as fusion_error:
- log.exception(f'Error fusing language model evidence!: {fusion_error}')
+ log.exception(
+ f'Error fusing language model evidence!: {fusion_error}')
raise BciPyCoreException(fusion_error) from fusion_error
# Get decision maker to give us back some decisions and stimuli
diff --git a/bcipy/helpers/demo/demo_visualization.py b/bcipy/helpers/demo/demo_visualization.py
index 149e0498d..e4b264479 100644
--- a/bcipy/helpers/demo/demo_visualization.py
+++ b/bcipy/helpers/demo/demo_visualization.py
@@ -86,7 +86,8 @@
trigger_targetness, trigger_timing, trigger_symbols = trigger_decoder(
offset=static_offset,
trigger_path=f"{path}/{TRIGGER_FILENAME}",
- exclusion=[TriggerType.PREVIEW, TriggerType.EVENT, TriggerType.FIXATION],
+ exclusion=[TriggerType.PREVIEW,
+ TriggerType.EVENT, TriggerType.FIXATION],
)
labels = [0 if label == 'nontarget' else 1 for label in trigger_targetness]
diff --git a/bcipy/helpers/language_model.py b/bcipy/helpers/language_model.py
index 810058176..d91154817 100644
--- a/bcipy/helpers/language_model.py
+++ b/bcipy/helpers/language_model.py
@@ -8,17 +8,17 @@
from bcipy.core.symbols import alphabet
from bcipy.exceptions import LanguageModelNameInUseException
from bcipy.language.main import LanguageModel
+from bcipy.language.model.causal import CausalLanguageModelAdapter
+from bcipy.language.model.mixture import MixtureLanguageModelAdapter
+from bcipy.language.model.ngram import NGramLanguageModelAdapter
+from bcipy.language.model.oracle import OracleLanguageModel
+from bcipy.language.model.uniform import UniformLanguageModel
# pylint: disable=unused-import
# flake8: noqa
"""Only imported models will be included in language_models_by_name"""
# flake8: noqa
-from bcipy.language.model.causal import CausalLanguageModelAdapter
-from bcipy.language.model.mixture import MixtureLanguageModelAdapter
-from bcipy.language.model.ngram import NGramLanguageModelAdapter
-from bcipy.language.model.oracle import OracleLanguageModel
-from bcipy.language.model.uniform import UniformLanguageModel
VALID_LANGUAGE_MODELS: Dict[str, Callable[[], LanguageModel]] = {
"CAUSAL": CausalLanguageModelAdapter,
diff --git a/bcipy/helpers/offset.py b/bcipy/helpers/offset.py
index e36a68783..c45ab8404 100644
--- a/bcipy/helpers/offset.py
+++ b/bcipy/helpers/offset.py
@@ -154,7 +154,8 @@ def calculate_latency(raw_data: RawData,
# if it's not normal, take the median
if p_value < 0.05:
- print(f'Non-normal distribution of diffs. p-value=[{p_value}] Consider using median for static offset.')
+ print(
+ f'Non-normal distribution of diffs. p-value=[{p_value}] Consider using median for static offset.')
recommended_static = abs(np.median(diffs))
print(
f'System recommended static offset median=[{recommended_static}]')
@@ -188,7 +189,8 @@ def calculate_latency(raw_data: RawData,
linewidth=0.5,
color='cyan')
- ax.plot(trg_box_x, trg_box_y, label=f'{diode_channel} (photodiode triggers)')
+ ax.plot(trg_box_x, trg_box_y,
+ label=f'{diode_channel} (photodiode triggers)')
# Add labels for TRGs
first_trg = trigger_diodes_timestamps[0]
@@ -260,7 +262,8 @@ def sample_rate_diffs(raw_data: RawData) -> Tuple[int, float]:
# get the count of all the samples and calculate the time recorded from the raw_data
sample_time = raw_data.dataframe.shape[0] / raw_data.sample_rate
- print(f'LSL Timestamp Sample Count: {lsl_sample_diff} EEG Sample Count: {sample_time}')
+ print(
+ f'LSL Timestamp Sample Count: {lsl_sample_diff} EEG Sample Count: {sample_time}')
return lsl_sample_diff, sample_time
@@ -335,10 +338,12 @@ def extract_data_latency_calculation(
args = parser.parse_args()
data_path = args.data_path
if not data_path:
- data_path = ask_directory(prompt="Please select a BciPy time test directory..", strict=True)
+ data_path = ask_directory(
+ prompt="Please select a BciPy time test directory..", strict=True)
# grab the stim length from the data directory parameters
- stim_length = load_json_parameters(f'{data_path}/{DEFAULT_PARAMETERS_FILENAME}', value_cast=True)['stim_length']
+ stim_length = load_json_parameters(
+ f'{data_path}/{DEFAULT_PARAMETERS_FILENAME}', value_cast=True)['stim_length']
raw_data, triggers, static_offset = extract_data_latency_calculation(
data_path,
diff --git a/bcipy/helpers/task.py b/bcipy/helpers/task.py
index 8e6dbd8f4..e7ae64c84 100644
--- a/bcipy/helpers/task.py
+++ b/bcipy/helpers/task.py
@@ -156,7 +156,8 @@ def get_data_for_decision(inquiry_timing: List[Tuple[str, float]],
for text, timing in inquiry_timing]
# Define the amount of data required for any processing to occur.
- data_limit = round((time2 - time1 + poststim) * daq.device_spec.sample_rate)
+ data_limit = round((time2 - time1 + poststim) *
+ daq.device_spec.sample_rate)
log.info(f'Need {data_limit} records for processing')
# Query for raw data
@@ -252,7 +253,8 @@ def relative_triggers(inquiry_timing: List[Tuple[str, float]],
def _float_val(col: Any) -> float:
"""Convert marker data to float values so we can put them in a
typed np.array. The marker column has type float if it has a 0.0
- value, and would only have type str for a marker value."""
+ value, and would only have type str for a marker value.
+ """
if isinstance(col, str):
return 1.0
return float(col)
@@ -350,7 +352,8 @@ def pause_on_wait_screen(window, message, color) -> bool:
elapsed_seconds = time.time() - pause_start
if elapsed_seconds >= MAX_PAUSE_SECONDS:
- log.info(f"Pause exceeded the allowed time ({MAX_PAUSE_SECONDS} seconds). Ending task.")
+ log.info(
+ f"Pause exceeded the allowed time ({MAX_PAUSE_SECONDS} seconds). Ending task.")
return False
if keys[0] == 'escape':
return False
diff --git a/bcipy/helpers/tests/test_offset.py b/bcipy/helpers/tests/test_offset.py
index f9a3d5f6e..4ed0944ae 100644
--- a/bcipy/helpers/tests/test_offset.py
+++ b/bcipy/helpers/tests/test_offset.py
@@ -38,7 +38,8 @@ def setUp(self) -> None:
remove_pre_fixation=False,
exclusion=[TriggerType.FIXATION],
device_type='EEG')
- self.triggers = list(zip(trigger_label, trigger_targetness, trigger_time))
+ self.triggers = list(
+ zip(trigger_label, trigger_targetness, trigger_time))
self.diode_channel = 'TRG'
self.stim_number = 10
@@ -62,7 +63,8 @@ def test_sample_to_seconds_throws_error_on_zero_args(self):
def test_extract_data_latency_calculation(self):
static_offset = 0.0
recommend = False
- resp = extract_data_latency_calculation(self.tmp_dir, recommend, static_offset)
+ resp = extract_data_latency_calculation(
+ self.tmp_dir, recommend, static_offset)
self.assertIsInstance(resp[0], RawData)
self.assertEqual(resp[1], self.triggers)
self.assertEqual(resp[2], static_offset)
@@ -70,7 +72,8 @@ def test_extract_data_latency_calculation(self):
def test_extract_data_latency_calculation_resets_static_offset_on_recommend(self):
static_offset = 1.0
recommend = True
- resp = extract_data_latency_calculation(self.tmp_dir, recommend, static_offset)
+ resp = extract_data_latency_calculation(
+ self.tmp_dir, recommend, static_offset)
self.assertIsInstance(resp[0], RawData)
self.assertEqual(resp[1], self.triggers)
self.assertNotEqual(resp[2], static_offset)
diff --git a/bcipy/helpers/tests/test_system_utils.py b/bcipy/helpers/tests/test_system_utils.py
index 445e47afb..322d086e2 100644
--- a/bcipy/helpers/tests/test_system_utils.py
+++ b/bcipy/helpers/tests/test_system_utils.py
@@ -15,12 +15,14 @@ class TestSystemUtilsAlerts(unittest.TestCase):
def test_is_connected_true(self):
"""Test that a computer connected to the internet returns True."""
mock_conn = mock()
- when(socket).create_connection(address=any, timeout=any).thenReturn(mock_conn)
+ when(socket).create_connection(
+ address=any, timeout=any).thenReturn(mock_conn)
self.assertTrue(is_connected())
def test_is_connected_false(self):
"""Test that a computer not connected to the internet returns False."""
- when(socket).create_connection(address=any, timeout=any).thenRaise(OSError)
+ when(socket).create_connection(
+ address=any, timeout=any).thenRaise(OSError)
self.assertFalse(is_connected())
def test_is_battery_powered_true(self):
diff --git a/bcipy/helpers/tests/test_visualization.py b/bcipy/helpers/tests/test_visualization.py
index aa7ff7a86..5cb137e39 100644
--- a/bcipy/helpers/tests/test_visualization.py
+++ b/bcipy/helpers/tests/test_visualization.py
@@ -17,7 +17,8 @@ class TestVisualizeSessionData(unittest.TestCase):
def setUp(self):
self.tmp_dir = str(Path(tempfile.mkdtemp()))
- self.parameters = load_json_parameters(DEFAULT_PARAMETERS_PATH, value_cast=True)
+ self.parameters = load_json_parameters(
+ DEFAULT_PARAMETERS_PATH, value_cast=True)
self.raw_data_mock = mock()
self.raw_data_mock.daq_type = 'DSI-24'
self.raw_data_mock.sample_rate = 300
@@ -33,7 +34,8 @@ def test_visualize_session_data(self):
trigger_label_mock = ['target', 'nontarget']
show = True
when(RawData).load(any()).thenReturn(self.raw_data_mock)
- when(visualization).analysis_channels(any(), any()).thenReturn(self.channel_map_mock)
+ when(visualization).analysis_channels(
+ any(), any()).thenReturn(self.channel_map_mock)
when(visualization).trigger_decoder(
offset=any(),
trigger_path=any(),
@@ -82,7 +84,8 @@ def test_visualize_session_data_with_no_valid_targets(self):
trigger_label_mock = ['nontarget', 'nontarget']
show = False
when(RawData).load(any()).thenReturn(self.raw_data_mock)
- when(visualization).analysis_channels(any(), any()).thenReturn(self.channel_map_mock)
+ when(visualization).analysis_channels(
+ any(), any()).thenReturn(self.channel_map_mock)
when(visualization).trigger_decoder(
offset=any(),
trigger_path=any(),
@@ -100,7 +103,8 @@ def test_visualize_session_data_with_no_valid_nontargets(self):
trigger_label_mock = ['target', 'target']
show = False
when(RawData).load(any()).thenReturn(self.raw_data_mock)
- when(visualization).analysis_channels(any(), any()).thenReturn(self.channel_map_mock)
+ when(visualization).analysis_channels(
+ any(), any()).thenReturn(self.channel_map_mock)
when(visualization).trigger_decoder(
offset=any(),
trigger_path=any(),
@@ -118,7 +122,8 @@ def test_visualize_session_data_with_invalid_timing(self):
trigger_label_mock = ['target', 'nontarget']
show = False
when(RawData).load(any()).thenReturn(self.raw_data_mock)
- when(visualization).analysis_channels(any(), any()).thenReturn(self.channel_map_mock)
+ when(visualization).analysis_channels(
+ any(), any()).thenReturn(self.channel_map_mock)
when(visualization).trigger_decoder(
offset=any(),
trigger_path=any(),
diff --git a/bcipy/helpers/utils.py b/bcipy/helpers/utils.py
index c423dde4a..98cbd8254 100644
--- a/bcipy/helpers/utils.py
+++ b/bcipy/helpers/utils.py
@@ -1,6 +1,7 @@
# mypy: disable-error-code="import-untyped"
"""Utilities for system information and general functionality that may be
-shared across modules."""
+shared across modules.
+"""
import importlib
import logging
import os
@@ -23,6 +24,7 @@
class ScreenInfo(NamedTuple):
+ """Screen information including width, height, and refresh rate."""
width: int
height: int
rate: float
@@ -52,7 +54,8 @@ def is_battery_powered() -> bool:
-------
True if the computer is currently running on battery power. This can impact
the performance of hardware (ex. GPU) needed for BciPy operation by entering
- power saving operations."""
+ power saving operations.
+ """
return psutil.sensors_battery(
) and not psutil.sensors_battery().power_plugged
@@ -79,7 +82,8 @@ def git_dir() -> Optional[str]:
"""Git Directory.
Returns the root directory with the .git folder. If this source code
- was not checked out from scm, answers None."""
+ was not checked out from scm, answers None.
+ """
# Relative to current file; may need to be modified if method is moved.
git_root = Path(os.path.abspath(__file__)).parent.parent.parent
@@ -143,7 +147,7 @@ def get_screen_info(stim_screen: Optional[int] = None) -> ScreenInfo:
Returns
-------
ScreenInfo(width, height, rate)
- """
+ """
if stim_screen:
screen = pyglet.canvas.get_display().get_screens()[stim_screen]
else:
@@ -285,7 +289,6 @@ def log_to_stdout():
"""Set logging to stdout. Useful for demo scripts.
https://stackoverflow.com/questions/14058453/making-python-loggers-output-all-messages-to-stdout-in-addition-to-log-file
"""
-
root = logging.getLogger()
root.setLevel(logging.DEBUG)
@@ -309,7 +312,8 @@ def wrap(*args, **kwargs):
time1 = time.perf_counter()
response = func(*args, **kwargs)
time2 = time.perf_counter()
- log.info('{:s} method took {:0.4f}s to execute'.format(func.__name__, (time2 - time1)))
+ log.info('{:s} method took {:0.4f}s to execute'.format(
+ func.__name__, (time2 - time1)))
return response
return wrap
diff --git a/bcipy/helpers/validate.py b/bcipy/helpers/validate.py
index 7c8184a89..be71f37e7 100644
--- a/bcipy/helpers/validate.py
+++ b/bcipy/helpers/validate.py
@@ -76,7 +76,8 @@ def _validate_experiment_fields(experiment_fields, fields):
try:
fields[field_name]
except KeyError:
- raise UnregisteredFieldException(f'Field [{field}] is not registered in [{fields}]')
+ raise UnregisteredFieldException(
+ f'Field [{field}] is not registered in [{fields}]')
try:
field[field_name]['required']
@@ -94,7 +95,8 @@ def validate_field_data_written(path: str, file_name: str) -> bool:
experiment_data_path = f'{path}/{file_name}'
if os.path.isfile(experiment_data_path):
return True
- raise InvalidFieldException(f'Experimental field data expected at path=[{experiment_data_path}] but not found.')
+ raise InvalidFieldException(
+ f'Experimental field data expected at path=[{experiment_data_path}] but not found.')
def validate_experiments(experiments, fields) -> bool:
diff --git a/bcipy/helpers/visualization.py b/bcipy/helpers/visualization.py
index b4bd72294..96ff982f4 100644
--- a/bcipy/helpers/visualization.py
+++ b/bcipy/helpers/visualization.py
@@ -60,7 +60,7 @@ def visualize_erp(
plot_topomaps: Optional[bool] = True,
show: Optional[bool] = False,
save_path: Optional[str] = None) -> List[Figure]:
- """ Visualize ERP.
+ """Visualize ERP.
Generates a comparative ERP figure following a task execution. Given a set of trailed data,
and labels describing two classes (Nontarget=0 and Target=1), they are plotted and may be saved
@@ -90,8 +90,10 @@ def visualize_erp(
else:
baseline = None
- mne_data = convert_to_mne(raw_data, channel_map=channel_map, transform=transform)
- epochs = mne_epochs(mne_data, trial_length, trigger_timing, trigger_labels, baseline=baseline)
+ mne_data = convert_to_mne(
+ raw_data, channel_map=channel_map, transform=transform)
+ epochs = mne_epochs(mne_data, trial_length, trigger_timing,
+ trigger_labels, baseline=baseline)
# *Note* We assume, as described above, two trigger classes are defined for use in trigger_labels
# (Nontarget=0 and Target=1). This will map into two corresponding MNE epochs whose indexing starts at 1.
# Therefore, epochs['1'] == Nontarget and epochs['2'] == Target.
@@ -102,10 +104,12 @@ def visualize_erp(
if plot_topomaps:
# make a list of equally spaced times to plot topomaps using the time window
# defined in the task parameters
- times = [round(trial_window[0] + i * (trial_window[1] - trial_window[0]) / 5, 1) for i in range(7)]
+ times = [round(trial_window[0] + i * (trial_window[1] -
+ trial_window[0]) / 5, 1) for i in range(7)]
# clip any times that are out of bounds of the time window or zero
- times = [time for time in times if trial_window[0] <= time <= trial_window[1] and time != 0]
+ times = [time for time in times if trial_window[0]
+ <= time <= trial_window[1] and time != 0]
figs.extend(visualize_joint_average(
epochs, ['Non-Target', 'Target'],
@@ -150,7 +154,6 @@ def visualize_gaze(
heatmap: Optional[bool]: Whether or not to plot the heatmap. Default: False
raw_plot: Optional[bool]: Whether or not to plot the raw gaze data. Default: False
"""
-
title = f'{data.daq_type} '
if heatmap:
title += 'Heatmap '
@@ -162,12 +165,14 @@ def visualize_gaze(
img = plt.imread(img_path)
channels = data.channels
- left_eye_channel_map = [1 if channel in left_keys else 0 for channel in channels]
+ left_eye_channel_map = [
+ 1 if channel in left_keys else 0 for channel in channels]
left_eye_data, _, _ = data.by_channel_map(left_eye_channel_map)
left_eye_x = left_eye_data[0]
left_eye_y = left_eye_data[1]
- right_eye_channel_map = [1 if channel in right_keys else 0 for channel in channels]
+ right_eye_channel_map = [
+ 1 if channel in right_keys else 0 for channel in channels]
right_eye_data, _, _ = data.by_channel_map(right_eye_channel_map)
right_eye_x = right_eye_data[0]
right_eye_y = right_eye_data[1]
@@ -220,7 +225,8 @@ def visualize_gaze(
plt.title(f'{title}Plot')
if save_path is not None:
- plt.savefig(f"{save_path}/{title.lower().replace(' ', '_')}plot.png", dpi=fig.dpi)
+ plt.savefig(
+ f"{save_path}/{title.lower().replace(' ', '_')}plot.png", dpi=fig.dpi)
if show:
plt.show()
@@ -337,7 +343,8 @@ def visualize_gaze_inquiries(
plt.title(f'{title}Plot')
if save_path is not None:
- plt.savefig(f"{save_path}/{title.lower().replace(' ', '_')}plot.png", dpi=fig.dpi)
+ plt.savefig(
+ f"{save_path}/{title.lower().replace(' ', '_')}plot.png", dpi=fig.dpi)
if show:
plt.show()
@@ -419,7 +426,8 @@ def visualize_pupil_size(
plt.title(f'{title}Plot')
if save_path is not None:
- plt.savefig(f"{save_path}/{title.lower().replace(' ', '_')}plot.png", dpi=fig.dpi)
+ plt.savefig(
+ f"{save_path}/{title.lower().replace(' ', '_')}plot.png", dpi=fig.dpi)
if show:
plt.show()
@@ -487,7 +495,8 @@ def visualize_centralized_data(
plt.title(f'{title}Plot')
if save_path is not None:
- plt.savefig(f"{save_path}/{title.lower().replace(' ', '_')}plot.png", dpi=fig.dpi)
+ plt.savefig(
+ f"{save_path}/{title.lower().replace(' ', '_')}plot.png", dpi=fig.dpi)
if show:
plt.show()
@@ -593,7 +602,8 @@ def visualize_results_all_symbols(
# Plot an ellipse to show the Gaussian component
angle = np.arctan(u[1] / u[0])
angle = 180.0 * angle / np.pi # convert to degrees
- ell = Ellipse(mean, v[0], v[1], angle=180.0 + angle, color='navy')
+ ell = Ellipse(mean, v[0], v[1],
+ angle=180.0 + angle, color='navy')
ell.set_clip_box(ax)
ell.set_alpha(0.5)
ax.add_artist(ell)
@@ -606,7 +616,8 @@ def visualize_results_all_symbols(
plt.title(f'{title}Plot')
if save_path is not None:
- plt.savefig(f"{save_path}/{title.lower().replace(' ', '_')}plot.png", dpi=fig.dpi)
+ plt.savefig(
+ f"{save_path}/{title.lower().replace(' ', '_')}plot.png", dpi=fig.dpi)
if show:
plt.show()
@@ -655,7 +666,8 @@ def visualize_csv_eeg_triggers(trigger_col: Optional[int] = None):
def visualize_joint_average(
epochs: Tuple[Epochs],
labels: List[str],
- plot_joint_times: Optional[List[float]] = [-0.1, 0, 0.2, 0.3, 0.35, 0.4, 0.5],
+ plot_joint_times: Optional[List[float]
+ ] = [-0.1, 0, 0.2, 0.3, 0.35, 0.4, 0.5],
save_path: Optional[str] = None,
show: Optional[bool] = False) -> List[Figure]:
"""Visualize Joint Average.
@@ -676,7 +688,8 @@ def visualize_joint_average(
Returns:
List of figures generated
"""
- assert len(epochs) == len(labels), "The number of epochs must match labels in Visualize Joint Average"
+ assert len(epochs) == len(
+ labels), "The number of epochs must match labels in Visualize Joint Average"
figs = []
for i, label in enumerate(labels):
@@ -740,12 +753,14 @@ def visualize_session_data(
# extract all relevant parameters
trial_window = parameters.get("trial_window")
- raw_data = load_raw_data(str(Path(session_path, f'{RAW_DATA_FILENAME}.csv')))
+ raw_data = load_raw_data(
+ str(Path(session_path, f'{RAW_DATA_FILENAME}.csv')))
channels = raw_data.channels
sample_rate = raw_data.sample_rate
daq_type = raw_data.daq_type
- transform_params: ERPTransformParams = parameters.instantiate(ERPTransformParams)
+ transform_params: ERPTransformParams = parameters.instantiate(
+ ERPTransformParams)
devices.load(Path(session_path, DEFAULT_DEVICE_SPEC_FILENAME))
device_spec = devices.preconfigured_device(daq_type)
@@ -763,12 +778,14 @@ def visualize_session_data(
trigger_targetness, trigger_timing, _ = trigger_decoder(
offset=device_spec.static_offset,
trigger_path=f"{session_path}/{TRIGGER_FILENAME}",
- exclusion=[TriggerType.PREVIEW, TriggerType.EVENT, TriggerType.FIXATION],
+ exclusion=[TriggerType.PREVIEW,
+ TriggerType.EVENT, TriggerType.FIXATION],
device_type='EEG',
)
assert "nontarget" in trigger_targetness, "No nontarget triggers found."
assert "target" in trigger_targetness, "No target triggers found."
- assert len(trigger_targetness) == len(trigger_timing), "Trigger targetness and timing must be the same length."
+ assert len(trigger_targetness) == len(
+ trigger_timing), "Trigger targetness and timing must be the same length."
labels = [0 if label == 'nontarget' else 1 for label in trigger_targetness]
channel_map = analysis_channels(channels, device_spec)
@@ -807,7 +824,8 @@ def visualize_gaze_accuracies(accuracy_dict: Dict[str, np.ndarray],
ax.set_title(title + str(round(accuracy, 2)))
if save_path is not None:
- plt.savefig(f"{save_path}/{title.lower().replace(' ', '_').replace(':', '')}plot.png", dpi=fig.dpi)
+ plt.savefig(
+ f"{save_path}/{title.lower().replace(' ', '_').replace(':', '')}plot.png", dpi=fig.dpi)
if show:
plt.show()
@@ -818,6 +836,9 @@ def visualize_gaze_accuracies(accuracy_dict: Dict[str, np.ndarray],
def erp():
+ """ERP Visualization CLI.
+ This function is used to visualize ERP data after a session.
+ """
import argparse
parser = argparse.ArgumentParser(description='Visualize ERP data')
diff --git a/bcipy/io/README.md b/bcipy/io/README.md
index 8b04bd539..0381ec7fa 100644
--- a/bcipy/io/README.md
+++ b/bcipy/io/README.md
@@ -2,6 +2,6 @@
The BciPy IO module contains functionality for loading and saving data in various formats. This includes the ability to convert data to BIDS format, load and save raw data, and load and save triggers.
-- `convert`: functionality for converting the bcipy raw data output to other formats (currently, BrainVision and EDF), and for converting the raw data to BIDS format.
+- `convert`: functionality for converting the BciPy raw data output to other formats (currently, BrainVision and EDF), and for converting the raw data to BIDS format.
- `load`: methods for loading most BciPy data formats, including raw data and triggers.
- `save`: methods for saving BciPy data in supported formats.
diff --git a/bcipy/io/convert.py b/bcipy/io/convert.py
index 9f308fa63..baa88b583 100644
--- a/bcipy/io/convert.py
+++ b/bcipy/io/convert.py
@@ -27,6 +27,7 @@
class ConvertFormat(Enum):
+ """Enumeration of supported data conversion formats for BciPy raw data output."""
BV = 'BrainVision'
EDF = 'EDF'
@@ -34,14 +35,17 @@ class ConvertFormat(Enum):
EEGLAB = 'EEGLAB'
def __str__(self):
+ """Return the string representation of the format."""
return self.value
@staticmethod
def all():
+ """Return a list of all ConvertFormat enum members."""
return [format for format in ConvertFormat]
@staticmethod
def values():
+ """Return a list of all ConvertFormat values as strings."""
return [format.value for format in ConvertFormat]
@@ -89,7 +93,8 @@ def convert_to_bids(
try:
os.mkdir(output_dir)
except OSError as e:
- raise OSError(f"Failed to create output directory={output_dir}") from e
+ raise OSError(
+ f"Failed to create output directory={output_dir}") from e
if format not in ConvertFormat.all():
raise ValueError(f"Unsupported format={format}")
if line_frequency not in [50, 60]:
@@ -190,16 +195,20 @@ def convert_eyetracking_to_bids(
"""
# check that the raw data path exists
if not os.path.exists(raw_data_path):
- raise FileNotFoundError(f"Raw eye tracking data path={raw_data_path} does not exist")
+ raise FileNotFoundError(
+ f"Raw eye tracking data path={raw_data_path} does not exist")
if not os.path.exists(output_dir):
- raise FileNotFoundError(f"Output directory={output_dir} does not exist")
+ raise FileNotFoundError(
+ f"Output directory={output_dir} does not exist")
found_files = glob.glob(f"{raw_data_path}/eyetracker*.csv")
if len(found_files) == 0:
- raise FileNotFoundError(f"No raw eye tracking data found in directory={raw_data_path}")
+ raise FileNotFoundError(
+ f"No raw eye tracking data found in directory={raw_data_path}")
if len(found_files) > 1:
- raise ValueError(f"Multiple raw eye tracking data files found in directory={raw_data_path}")
+ raise ValueError(
+ f"Multiple raw eye tracking data files found in directory={raw_data_path}")
eye_tracking_file = found_files[0]
logger.info(f"Found raw eye tracking data file={eye_tracking_file}")
@@ -405,7 +414,8 @@ def convert_to_mne(
# if remove_system_channels is True, exclude the system and trigger channels (last two channels)
if remove_system_channels:
channel_map = [1] * (len(raw_data.channels) - 2)
- channel_map.extend([0, 0]) # exclude the system and trigger channels
+ # exclude the system and trigger channels
+ channel_map.extend([0, 0])
else:
channel_map = [1] * len(raw_data.channels)
@@ -413,7 +423,8 @@ def convert_to_mne(
# if no channel types provided, assume all channels are eeg
if not channel_types:
- logger.warning("No channel types provided. Assuming all channels are EEG.")
+ logger.warning(
+ "No channel types provided. Assuming all channels are EEG.")
channel_types = ['eeg'] * len(channels)
# check that number of channel types matches number of channels in the case custom channel types are provided
@@ -470,8 +481,10 @@ def norm_to_tobii(norm_units: Tuple[float, float]) -> Tuple[float, float]:
and (1, 1) the lower right corner.
"""
# check that the coordinates are within the bounds of the screen
- assert norm_units[0] >= -1 and norm_units[0] <= 1, "X coordinate must be between -1 and 1"
- assert norm_units[1] >= -1 and norm_units[1] <= 1, "Y coordinate must be between -1 and 1"
+ assert norm_units[0] >= - \
+ 1 and norm_units[0] <= 1, "X coordinate must be between -1 and 1"
+ assert norm_units[1] >= - \
+ 1 and norm_units[1] <= 1, "Y coordinate must be between -1 and 1"
# convert PsychoPy norm units to Tobii units
tobii_x = (norm_units[0] / 2) + 0.5
@@ -505,7 +518,8 @@ def BIDS_to_MNE(
"""
# Check if the BIDS root path exists
if not os.path.exists(bids_root_path):
- raise FileNotFoundError(f"BIDS root path '{bids_root_path}' does not exist.")
+ raise FileNotFoundError(
+ f"BIDS root path '{bids_root_path}' does not exist.")
logger.info(f"Searching for BIDS data in '{bids_root_path}'...")
# Get the sessions from the BIDS root path using the session label (e.g., 'ses' or 'session')
@@ -521,12 +535,14 @@ def BIDS_to_MNE(
tasks=task_name,
)
if not bid_paths:
- raise FileNotFoundError(f"No matching BIDS files found in '{bids_root_path}'.")
+ raise FileNotFoundError(
+ f"No matching BIDS files found in '{bids_root_path}'.")
raw_data = []
for bid_path in bid_paths:
if task_name and bid_path.task != task_name:
- logger.debug(f"Skipping file '{bid_path}' due to task name [{task_name}] mismatch.")
+ logger.debug(
+ f"Skipping file '{bid_path}' due to task name [{task_name}] mismatch.")
continue
logger.info(f"Reading BIDS file: {bid_path}")
diff --git a/bcipy/io/demo/demo_convert.py b/bcipy/io/demo/demo_convert.py
index bce775562..5195b44d6 100644
--- a/bcipy/io/demo/demo_convert.py
+++ b/bcipy/io/demo/demo_convert.py
@@ -54,7 +54,8 @@ def load_historical_bcipy_data(directory: str, experiment_id: str) -> List[BciPy
extracted_task_time = task_run.name.split('_')[7]
# add the tasks to the list with the time as the key
- run_tasks[extracted_task_time] = [task_run, f'{extracted_task_paradigm}{extracted_task_mode}']
+ run_tasks[extracted_task_time] = [
+ task_run, f'{extracted_task_paradigm}{extracted_task_mode}']
# sort the tasks by time
sorted_tasks = sorted(run_tasks.items())
@@ -88,7 +89,8 @@ def convert_experiment_to_bids(
experiment_id=experiment_id)
# Use for data post-2.0rc4
- experiment_data = load_bcipy_data(directory, experiment_id, excluded_tasks=EXCLUDED_TASKS)
+ experiment_data = load_bcipy_data(
+ directory, experiment_id, excluded_tasks=EXCLUDED_TASKS)
if not output_dir:
output_dir = directory
@@ -120,8 +122,10 @@ def convert_experiment_to_bids(
task_name=data.task_name
)
except Exception as e:
- print(f"Error converting eye tracker data for {data.path} - {e}")
- errors.append(f"Error converting eye tracker data for {data.path}")
+ print(
+ f"Error converting eye tracker data for {data.path} - {e}")
+ errors.append(
+ f"Error converting eye tracker data for {data.path}")
except Exception as e:
print(f"Error converting {data.path} - {e}")
@@ -131,7 +135,8 @@ def convert_experiment_to_bids(
if errors:
print(f"Errors converting the following data: {errors}")
- print(f"\nData converted to BIDS format in {output_dir}/bids_{experiment_id}/")
+ print(
+ f"\nData converted to BIDS format in {output_dir}/bids_{experiment_id}/")
print("--------------------")
@@ -162,7 +167,8 @@ def convert_experiment_to_bids(
path = args.directory
if not path:
- path = ask_directory("Select the directory with data to be converted", strict=True)
+ path = ask_directory(
+ "Select the directory with data to be converted", strict=True)
# convert a study to BIDS format
convert_experiment_to_bids(
diff --git a/bcipy/io/demo/demo_load_BIDS.py b/bcipy/io/demo/demo_load_BIDS.py
index 5f4ac5bd1..942e4a747 100644
--- a/bcipy/io/demo/demo_load_BIDS.py
+++ b/bcipy/io/demo/demo_load_BIDS.py
@@ -28,7 +28,8 @@
task_name = None # Set to None to load all tasks.
raw_data_files = BIDS_to_MNE(path_to_bids, task_name=task_name)
- raw_data = raw_data_files[0] # Get the first raw data object from the list.
+ # Get the first raw data object from the list.
+ raw_data = raw_data_files[0]
# to see where the data is stored, you can use the following command:
# print(raw_data.filenames)
@@ -41,11 +42,13 @@
# EPOCH THE DATA / CREATE ERPS
# epoch the data using the events from the raw data object.
# You can specify the event_id, tmin, tmax, and baseline parameters as needed.
- events = mne.events_from_annotations(raw_data, event_id={'nontarget': 0, 'target': 1})
+ events = mne.events_from_annotations(
+ raw_data, event_id={'nontarget': 0, 'target': 1})
tmin = -0.2
tmax = 0.8
baseline = (None, None) # No baseline correction.
- epochs = mne.Epochs(raw_data, events[0], events[1], tmin, tmax, baseline=baseline, preload=True)
+ epochs = mne.Epochs(
+ raw_data, events[0], events[1], tmin, tmax, baseline=baseline, preload=True)
# Grab the epochs for non-target and target events.
non_target_epochs = epochs['nontarget']
@@ -72,4 +75,5 @@
print(f"An error occurred: {e}")
finally:
print("Demo script completed.")
- breakpoint() # This will pause the script execution and allow you to inspect the variables in the debugger.
+ # This will pause the script execution and allow you to inspect the variables in the debugger.
+ breakpoint()
diff --git a/bcipy/io/load.py b/bcipy/io/load.py
index 2c0157c82..8b4069426 100644
--- a/bcipy/io/load.py
+++ b/bcipy/io/load.py
@@ -1,3 +1,9 @@
+"""Module for loading BciPy data and configuration files.
+
+This module provides functions for loading various types of data used in BciPy,
+including parameters, experiments, signal models, and session data.
+"""
+
# mypy: disable-error-code="arg-type, union-attr"
import json
import logging
@@ -24,17 +30,15 @@
def copy_parameters(path: str = DEFAULT_PARAMETERS_PATH,
destination: Optional[str] = None) -> str:
- """Creates a copy of the given configuration (parameters.json) to the
- given directory and returns the path.
-
- Parameters:
- -----------
- path: str - optional path of parameters file to copy; used default if not provided.
- destination: str - optional destination directory; default is the same
- directory as the default parameters.
+ """Creates a copy of the given configuration (parameters.json) to the given directory.
+
+ Args:
+ path: Optional path of parameters file to copy; uses default if not provided.
+ destination: Optional destination directory; default is the same directory
+ as the default parameters.
+
Returns:
- --------
- path to the new file.
+ str: Path to the new file.
"""
default_dir = str(Path(DEFAULT_PARAMETERS_PATH).parent)
@@ -47,73 +51,91 @@ def copy_parameters(path: str = DEFAULT_PARAMETERS_PATH,
def load_experiments(path: str = f'{DEFAULT_EXPERIMENT_PATH}/{EXPERIMENT_FILENAME}') -> dict:
- """Load Experiments.
+ """Load experiment configurations from a JSON file.
- PARAMETERS
- ----------
- :param: path: string path to the experiments file.
-
- Returns
- -------
- A dictionary of experiments, with the following format:
- { name: { fields : {name: '', required: bool, anonymize: bool}, summary: '' } }
+ Args:
+ path: Path to the experiments file.
+ Returns:
+ dict: Dictionary of experiments with format:
+ {
+ name: {
+ fields: {
+ name: str,
+ required: bool,
+ anonymize: bool
+ },
+ summary: str
+ }
+ }
"""
with open(path, 'r', encoding=DEFAULT_ENCODING) as json_file:
return json.load(json_file)
def extract_mode(bcipy_data_directory: str) -> str:
- """Extract Mode.
+ """Extract the task mode from a BciPy data save directory.
This method extracts the task mode from a BciPy data save directory. This is important for
- trigger conversions and extracting targeteness.
+ trigger conversions and extracting targetness.
- *note*: this is not compatible with older versions of BciPy (pre 1.5.0) where
+ Note:
+ Not compatible with older versions of BciPy (pre 1.5.0) where
the tasks and modes were considered together using integers (1, 2, 3).
- PARAMETERS
- ----------
- :param: bcipy_data_directory: string path to the data directory
+ Args:
+ bcipy_data_directory: Path to the data directory.
+
+ Returns:
+ str: The extracted mode ('calibration' or 'copy_phrase').
+
+ Raises:
+ BciPyCoreException: If no valid mode could be extracted.
"""
directory = bcipy_data_directory.lower()
if 'calibration' in directory:
return 'calibration'
elif 'copy' in directory:
return 'copy_phrase'
- raise BciPyCoreException(f'No valid mode could be extracted from [{directory}]')
+ raise BciPyCoreException(
+ f'No valid mode could be extracted from [{directory}]')
def load_fields(path: str = f'{DEFAULT_FIELD_PATH}/{FIELD_FILENAME}') -> dict:
- """Load Fields.
+ """Load field definitions from a JSON file.
- PARAMETERS
- ----------
- :param: path: string path to the fields file.
+ Args:
+ path: Path to the fields file.
- Returns
- -------
- A dictionary of fields, with the following format:
+ Returns:
+ dict: Dictionary of fields with format:
{
"field_name": {
- "help_text": "",
- "type": ""
+ "help_text": str,
+ "type": str
+ }
}
-
"""
with open(path, 'r', encoding=DEFAULT_ENCODING) as json_file:
return json.load(json_file)
def load_experiment_fields(experiment: dict) -> list:
- """Load Experiment Fields.
+ """Extract field names from an experiment configuration.
- {
- 'fields': [{}, {}],
- 'summary': ''
- }
+ Args:
+ experiment: Dictionary containing experiment configuration with format:
+ {
+ 'fields': [{field_dict}, {field_dict}],
+ 'summary': str
+ }
- Using the experiment dictionary, loop over the field keys and put them in a list.
+ Returns:
+ list: List of field names from the experiment configuration.
+
+ Raises:
+ InvalidExperimentException: If experiment format is incorrect.
+ TypeError: If experiment is not a dictionary.
"""
if isinstance(experiment, dict):
try:
@@ -122,55 +144,66 @@ def load_experiment_fields(experiment: dict) -> list:
raise InvalidExperimentException(
'Experiment is not formatted correctly. It should be passed as a dictionary with the fields and'
f' summary keys. Fields is a list of dictionaries. Summary is a string. \n experiment=[{experiment}]')
- raise TypeError('Unsupported experiment type. It should be passed as a dictionary with the fields and summary keys')
+ raise TypeError(
+ 'Unsupported experiment type. It should be passed as a dictionary with the fields and summary keys')
def load_json_parameters(path: str, value_cast: bool = False) -> Parameters:
- """Load JSON Parameters.
-
- Given a path to a json of parameters, convert to a dictionary and optionally
- cast the type.
-
- Expects the following format:
- "fake_data": {
- "value": "true",
- "section": "bci_config",
- "name": "Fake Data Sessions",
- "helpTip": "If true, fake data server used",
- "recommended": "",
- "editable": "true",
- "type": "bool"
- }
+ """Load and parse parameters from a JSON file.
- PARAMETERS
- ----------
- :param: path: string path to the parameters file.
- :param: value_case: True/False cast values to specified type.
+ Args:
+ path: Path to the parameters file.
+ value_cast: Whether to cast values to their specified types.
- Returns
- -------
- a Parameters object that behaves like a dict.
+ Returns:
+ Parameters: A Parameters object containing the loaded configuration.
+
+ Note:
+ Expected JSON format:
+ {
+ "parameter_name": {
+ "value": str,
+ "section": str,
+ "name": str,
+ "helpTip": str,
+ "recommended": str,
+ "editable": str,
+ "type": str
+ }
+ }
"""
return Parameters(source=path, cast_values=value_cast)
def load_experimental_data(message='', strict=False) -> str:
- filename = ask_directory(prompt=message, strict=strict) # show dialog box and return the path
+ """Show a dialog to select an experimental data directory.
+
+ Args:
+ message: Optional prompt message for the dialog.
+ strict: Whether to enforce strict directory selection.
+
+ Returns:
+ str: Path to the selected directory.
+ """
+ filename = ask_directory(prompt=message, strict=strict)
log.info("Loaded Experimental Data From: %s" % filename)
return filename
def load_signal_models(directory: Optional[str] = None) -> List[SignalModel]:
- """Load all signal models in a given directory.
+ """Load all signal models from a directory.
Models are assumed to have been written using bcipy.helpers.save.save_model
- function and should be serialized as pickled files. Note that reading
- pickled files is a potential security concern so only load from trusted
- directories.
+ function and should be serialized as pickled files.
Args:
- dirname (str, optional): Location of pretrained models. If not
- provided the user will be prompted for a location.
+ directory: Location of pretrained models. User will be prompted if not provided.
+
+ Returns:
+ list: List of loaded SignalModel instances.
+
+ Warning:
+ Reading pickled files is a potential security risk. Only load from trusted directories.
"""
if not directory or Path(directory).is_file():
directory = ask_directory()
@@ -191,9 +224,11 @@ def load_signal_models(directory: Optional[str] = None) -> List[SignalModel]:
def choose_signal_models(device_types: List[str]) -> List[SignalModel]:
"""Prompt the user to load a signal model for each provided device.
- Parameters
- ----------
- device_types - list of device content types (ex. 'EEG')
+ Args:
+ device_types: List of device content types (e.g., 'EEG').
+
+ Returns:
+ list: List of selected SignalModel instances.
"""
return [
model for model in map(choose_signal_model, set(device_types)) if model
@@ -203,11 +238,15 @@ def choose_signal_models(device_types: List[str]) -> List[SignalModel]:
def load_signal_model(file_path: str) -> SignalModel:
"""Load signal model from persisted file.
- Models are assumed to have been written using bcipy.io.save.save_model
- function and should be serialized as pickled files. Note that reading
- pickled files is a potential security concern so only load from trusted
- directories."""
+ Args:
+ file_path: Path to the model file.
+
+ Returns:
+ SignalModel: The loaded signal model.
+ Warning:
+ Reading pickled files is a potential security risk. Only load from trusted sources.
+ """
with open(file_path, "rb") as signal_file:
model = pickle.load(signal_file)
log.info(f"Loading model {model}")
@@ -215,15 +254,15 @@ def load_signal_model(file_path: str) -> SignalModel:
def choose_signal_model(device_type: str) -> Optional[SignalModel]:
- """Present a file dialog prompting the user to select a signal model for
- the given device.
+ """Present a file dialog prompting the user to select a signal model.
- Parameters
- ----------
- device_type - ex. 'EEG' or 'Eyetracker'; this should correspond with
- the content_type of the DeviceSpec of the model.
- """
+ Args:
+ device_type: Device type (e.g., 'EEG' or 'Eyetracker') that should correspond
+ with the content_type of the DeviceSpec of the model.
+ Returns:
+ Optional[SignalModel]: The selected signal model, or None if no selection made.
+ """
file_path = ask_filename(file_types=f"*{SIGNAL_MODEL_FILE_SUFFIX}",
directory=preferences.signal_model_directory,
prompt=f"Select the {device_type} signal model")
@@ -237,7 +276,14 @@ def choose_signal_model(device_type: str) -> Optional[SignalModel]:
def choose_model_paths(device_types: List[str]) -> List[Path]:
- """Select a model for each device and return a list of paths."""
+ """Select a model for each device and return a list of paths.
+
+ Args:
+ device_types: List of device types to load models for.
+
+ Returns:
+ list: List of paths to selected model files.
+ """
return [
ask_filename(file_types=f"*{SIGNAL_MODEL_FILE_SUFFIX}",
directory=preferences.signal_model_directory,
@@ -247,15 +293,16 @@ def choose_model_paths(device_types: List[str]) -> List[Path]:
def choose_csv_file(filename: Optional[str] = None) -> Optional[str]:
- """GUI prompt to select a csv file from the file system.
+ """GUI prompt to select a CSV file from the file system.
- Parameters
- ----------
- - filename : optional filename to use; if provided the GUI is not shown.
+ Args:
+ filename: Optional filename to use; if provided the GUI is not shown.
+
+ Returns:
+ Optional[str]: Path to selected file.
- Returns
- -------
- file name of selected file; throws an exception if the file is not a csv.
+ Raises:
+ Exception: If the selected file is not a CSV file.
"""
if not filename:
filename = ask_filename('*.csv')
@@ -271,25 +318,26 @@ def choose_csv_file(filename: Optional[str] = None) -> Optional[str]:
def load_raw_data(filename: Union[Path, str]) -> RawData:
- """Reads the data (.csv) file written by data acquisition.
+ """Read data from a CSV file written by data acquisition.
- Parameters
- ----------
- - filename : path to the serialized data (csv file)
+ Args:
+ filename: Path to the serialized data (CSV file).
- Returns
- -------
- RawData object with data held in memory
+ Returns:
+ RawData: Object containing the loaded data in memory.
"""
return RawData.load(filename)
def load_users(data_save_loc: str) -> List[str]:
- """Load Users.
+ """Load user directory names from the data path.
- Loads user directory names below experiments from the data path defined and returns them as a list.
- If the save data directory is not found, this method returns an empty list assuming no experiments
- have been run yet.
+ Args:
+ data_save_loc: Path to the data directory.
+
+ Returns:
+ list: List of user IDs found in the directory. Returns empty list if
+ directory not found (assuming no experiments have been run).
"""
try:
bcipy_data = BciPyCollection(data_directory=data_save_loc)
@@ -300,11 +348,14 @@ def load_users(data_save_loc: str) -> List[str]:
def fast_scandir(directory_name: str, return_path: bool = True) -> List[str]:
- """Fast Scan Directory.
+ """Quickly scan a directory for subdirectories.
- directory_name: name of the directory to be scanned
- return_path: whether or not to return the scanned directories as a relative path or name.
- False will return the directory name only.
+ Args:
+ directory_name: Name of the directory to scan.
+ return_path: Whether to return full paths (True) or just names (False).
+
+ Returns:
+ list: List of subdirectory paths or names.
"""
if return_path:
return [f.path for f in os.scandir(directory_name) if f.is_dir()]
@@ -313,17 +364,38 @@ def fast_scandir(directory_name: str, return_path: bool = True) -> List[str]:
class BciPySessionTaskData:
- """Session Task Data.
+ """Class representing data from a single BciPy task session.
- This class is used to represent a single task session. It is used to store the
- path to the task data, as well as the parameters and other information about the task.
+ This class is used to store the path to the task data, as well as parameters
+ and other information about the task.
- //
- protocol.json
- /
- parameters.json
- **task_data**
+ Directory structure:
+ //
+ protocol.json
+ /
+ parameters.json
+ **task_data**
+ Args:
+ path: Path to the session data.
+ user_id: ID of the user who performed the task.
+ experiment_id: ID of the experiment the task belongs to.
+ date_time: Optional timestamp of task execution.
+ date: Optional date of task execution.
+ task_name: Optional name of the executed task.
+ session_id: Session identifier number, defaults to 1.
+ run: Run number within the session, defaults to 1.
+
+ Attributes:
+ user_id: ID of the user who performed the task.
+ experiment_id: ID of the experiment (with underscores removed).
+ session_id: Formatted session ID (zero-padded if < 10).
+ date_time: Timestamp of task execution.
+ date: Date of task execution.
+ run: Formatted run number (zero-padded if < 10).
+ path: Path to the session data.
+ task_name: Name of the executed task.
+ info: Dictionary containing all session information.
"""
def __init__(
@@ -339,7 +411,8 @@ def __init__(
self.user_id = user_id
self.experiment_id = experiment_id.replace('_', '')
- self.session_id = f'0{str(session_id)}' if session_id < 10 else str(session_id)
+ self.session_id = f'0{str(session_id)}' if session_id < 10 else str(
+ session_id)
self.date_time = date_time
self.date = date
self.run = f'0{str(run)}' if run < 10 else str(run)
@@ -364,10 +437,34 @@ def __repr__(self):
class BciPyCollection:
- """BciPy Data.
+ """Class for managing collections of BciPy session task data.
+
+ This class is used to collect data from the data directory and filter based
+ on the provided filters.
- This class is used to represent a full BciPy data collection. It is used to collect data from the
- data directory and filter based on the provided filters.
+ Args:
+ data_directory: Root directory containing BciPy data.
+ experiment_id_filter: Optional filter for specific experiments.
+ user_id_filter: Optional filter for specific users.
+ date_filter: Optional filter for specific dates.
+ date_time_filter: Optional filter for specific timestamps.
+ excluded_tasks: Optional list of task names to exclude.
+ anonymize: Whether to anonymize user data.
+
+ Attributes:
+ data_directory: Root directory containing BciPy data.
+ experiment_id_filter: Filter for specific experiments.
+ user_id_filter: Filter for specific users.
+ date_filter: Filter for specific dates.
+ date_time_filter: Filter for specific timestamps.
+ excluded_tasks: List of task names to exclude.
+ anonymize: Whether to anonymize user data.
+ session_task_data: List of collected BciPySessionTaskData objects.
+ user_paths: List of paths to user directories.
+ date_paths: List of paths to date directories.
+ experiment_paths: List of paths to experiment directories.
+ date_time_paths: List of paths to datetime directories.
+ task_paths: List of paths to task directories.
"""
def __init__(
@@ -407,36 +504,64 @@ def __str__(self):
@property
def users(self) -> List[str]:
+ """Get list of users in the collection.
+
+ Returns:
+ list: List of user IDs.
+ """
return [user.split('/')[-1] for user in self.user_paths]
@property
def experiments(self) -> List[str]:
- experiments = [experiment.split('/')[-1] for experiment in self.experiment_paths]
+ """Get list of unique experiments in the collection.
+
+ Returns:
+ list: List of experiment IDs.
+ """
+ experiments = [experiment.split('/')[-1]
+ for experiment in self.experiment_paths]
# remove duplicates from the list
return list(set(experiments))
@property
def dates(self) -> List[str]:
+ """Get list of unique dates in the collection.
+
+ Returns:
+ list: List of dates.
+ """
dates = [date.split('/')[-1] for date in self.date_paths]
# remove duplicates from the list
return list(set(dates))
@property
def date_times(self) -> List[str]:
- date_times = [date_time.split('/')[-1] for date_time in self.date_time_paths]
+ """Get list of unique timestamps in the collection.
+
+ Returns:
+ list: List of timestamps.
+ """
+ date_times = [date_time.split('/')[-1]
+ for date_time in self.date_time_paths]
# remove duplicates from the list
return list(set(date_times))
@property
def tasks(self) -> List[str]:
+ """Get list of unique tasks in the collection.
+
+ Returns:
+ list: List of task names.
+ """
tasks = [task.task_name for task in self.session_task_data]
# remove duplicates from the list
return list(set(tasks))
def collect(self) -> List[BciPySessionTaskData]:
- """Collect.
+ """Collect BciPy data from the data directory.
- Collects the BciPy data from the data directory and returns a list of BciPySessionTaskData objects.
+ Returns:
+ list: List of BciPySessionTaskData objects representing the experiment data.
"""
if not self.session_task_data:
self.load_tasks()
@@ -455,20 +580,21 @@ def collect(self) -> List[BciPySessionTaskData]:
return self.session_task_data
def load_users(self) -> None:
- """Load Users.
+ """Load user paths from the data directory.
- Walks the data directory and sets the user paths. It will filter by the user id if provided.
+ Walks the data directory and sets the user paths. Filters by user ID if provided.
"""
user_paths = fast_scandir(self.data_directory, return_path=True)
if self.user_id_filter:
- self.user_paths = [user for user in user_paths if self.user_id_filter in user]
+ self.user_paths = [
+ user for user in user_paths if self.user_id_filter in user]
else:
self.user_paths = user_paths
def load_dates(self) -> None:
- """Load Dates.
+ """Load date paths from the data directory.
- Walks the data directory and sets the date paths. It will filter by the date if provided.
+ Walks the data directory and sets the date paths. Filters by date if provided.
"""
if not self.user_paths:
self.load_users()
@@ -476,14 +602,15 @@ def load_dates(self) -> None:
for user in self.user_paths:
data_paths = fast_scandir(user, return_path=True)
if self.date_filter:
- self.date_paths.extend([data for data in data_paths if self.date_filter in data])
+ self.date_paths.extend(
+ [data for data in data_paths if self.date_filter in data])
else:
self.date_paths.extend(data_paths)
def load_experiments(self) -> None:
- """Load Experiments.
+ """Load experiment paths from the data directory.
- Walks the data directory and sets the experiment paths. It will filter by the experiment id if provided.
+ Walks the data directory and sets the experiment paths. Filters by experiment ID if provided.
"""
if not self.date_paths:
self.load_dates()
@@ -491,14 +618,15 @@ def load_experiments(self) -> None:
for date in self.date_paths:
experiment_paths = fast_scandir(date, return_path=True)
if self.experiment_id_filter:
- self.experiment_paths.extend([data for data in experiment_paths if self.experiment_id_filter in data])
+ self.experiment_paths.extend(
+ [data for data in experiment_paths if self.experiment_id_filter in data])
else:
self.experiment_paths.extend(experiment_paths)
def load_date_times(self) -> None:
- """Load Date Times.
+ """Load datetime paths from the data directory.
- Walks the data directory and sets the date time paths. It will filter by the date time if provided.
+ Walks the data directory and sets the datetime paths. Filters by datetime if provided.
"""
if not self.experiment_paths:
self.load_experiments()
@@ -506,22 +634,27 @@ def load_date_times(self) -> None:
for experiment in self.experiment_paths:
data_paths = fast_scandir(experiment, return_path=True)
if self.date_time_filter:
- self.date_time_paths.extend([data for data in data_paths if self.date_time_filter in data])
+ self.date_time_paths.extend(
+ [data for data in data_paths if self.date_time_filter in data])
else:
self.date_time_paths.extend(data_paths)
def sort_tasks(self, tasks: List[str]) -> List[str]:
- """Sort Tasks.
+ """Sort tasks by their timestamp.
+
+ Args:
+ tasks: List of task paths to sort.
- Sorts the tasks in the order they were run using the timestamp at the end of the task path.
+ Returns:
+ list: Sorted list of task paths.
"""
return sorted(tasks, key=lambda x: x.split('_')[-1])
def load_tasks(self) -> None:
- """Load Tasks.
+ """Load task data from the data directory.
Walks the data directory and sets the session_task_data representing the experiment data.
- It will exclude tasks that are in the excluded_tasks list.
+ Excludes tasks that are in the excluded_tasks list.
"""
if not self.date_time_paths:
self.load_date_times()
@@ -575,33 +708,32 @@ def load_bcipy_data(
date_time: Optional[str] = None,
excluded_tasks: Optional[List[str]] = None,
anonymize: bool = False) -> List[BciPySessionTaskData]:
- """Load BciPy Data.
-
- Walks a data directory and returns a list of data paths for the given experiment id, user id, and date.
-
- The BciPy data directory is structured as follows:
- data/
- user_ids/
- dates/
- experiment_ids/
- datetimes/
- protocol.json
- logs/
- tasks/
- raw_data.csv
- triggers.txt
-
- data_directory: the bcipy data directory to walk
- experiment_id: the experiment id to filter by
- user_id: the user id to filter by
- date: the date to filter by
- date_time: the date time to filter by
- excluded_tasks: a list of tasks to exclude from the returned list of experiment data
- anonymize: whether or not to anonymize the user ids
+ """Load BciPy data from a directory.
+
+ Args:
+ data_directory: The BciPy data directory to walk.
+ experiment_id: Optional experiment ID to filter by.
+ user_id: Optional user ID to filter by.
+ date: Optional date to filter by.
+ date_time: Optional datetime to filter by.
+ excluded_tasks: Optional list of tasks to exclude.
+ anonymize: Whether to anonymize user IDs.
Returns:
- --------
- a list of BciPySessionTaskData objects representing the experiment data
+ list: List of BciPySessionTaskData objects representing the experiment data.
+
+ Note:
+ The BciPy data directory is structured as follows:
+ data/
+ user_ids/
+ dates/
+ experiment_ids/
+ datetimes/
+ protocol.json
+ logs/
+ tasks/
+ raw_data.csv
+ triggers.txt
"""
if not excluded_tasks:
excluded_tasks = []
diff --git a/bcipy/io/save.py b/bcipy/io/save.py
index 197adbceb..7578e6c01 100644
--- a/bcipy/io/save.py
+++ b/bcipy/io/save.py
@@ -38,6 +38,17 @@ def save_experiment_data(
fields: dict,
location: str,
name: str) -> str:
+ """Save experiment data to a JSON file.
+
+ Args:
+ experiments (dict): Experiment data to save.
+        fields (dict): Additional fields (accepted for interface symmetry; not used by this function).
+ location (str): Directory to save the file.
+ name (str): Name of the file.
+
+ Returns:
+ str: Path to the saved file.
+ """
return save_json_data(experiments, location, name)
@@ -45,6 +56,16 @@ def save_field_data(
fields: dict,
location: str,
name: str) -> str:
+ """Save field data to a JSON file.
+
+ Args:
+ fields (dict): Field data to save.
+ location (str): Directory to save the file.
+ name (str): Name of the file.
+
+ Returns:
+ str: Path to the saved file.
+ """
return save_json_data(fields, location, name)
@@ -52,6 +73,16 @@ def save_experiment_field_data(
data: dict,
location: str,
name: str) -> str:
+ """Save experiment field data to a JSON file.
+
+ Args:
+ data (dict): Data to save.
+ location (str): Directory to save the file.
+ name (str): Name of the file.
+
+ Returns:
+ str: Path to the saved file.
+ """
return save_json_data(data, location, name)
@@ -95,7 +126,8 @@ def init_save_data_structure(data_save_path: str,
copyfile(parameters, Path(save_directory, DEFAULT_PARAMETERS_FILENAME))
- copyfile(DEFAULT_LM_PARAMETERS_PATH, Path(save_directory, DEFAULT_LM_PARAMETERS_FILENAME))
+ copyfile(DEFAULT_LM_PARAMETERS_PATH, Path(
+ save_directory, DEFAULT_LM_PARAMETERS_FILENAME))
return save_directory
@@ -186,8 +218,8 @@ def save_stimuli_position_info(
screen_info: Dict[str, Any]) -> str:
"""Save stimuli positions and screen info to `path`
- stimuli_position_info: {'A': (0, 0)}
- screen_info: {'screen_size_pixels': [1920, 1080], 'screen_refresh': 160}
+ stimuli_position_info: {'A': (0, 0)}
+ screen_info: {'screen_size_pixels': [1920, 1080], 'screen_refresh': 160}
Parameters
----------
diff --git a/bcipy/io/tests/test_convert.py b/bcipy/io/tests/test_convert.py
index 882779b45..ec26682e1 100644
--- a/bcipy/io/tests/test_convert.py
+++ b/bcipy/io/tests/test_convert.py
@@ -47,8 +47,10 @@ def create_bcipy_session_artifacts(
if isinstance(channels, int):
channels = [CHANNEL_NAMES[i] for i in range(channels)]
- data = sample_data(ch_names=channels, daq_type='SampleDevice', sample_rate=sample_rate, rows=samples)
- devices.register(devices.DeviceSpec('SampleDevice', channels=channels, sample_rate=sample_rate))
+ data = sample_data(ch_names=channels, daq_type='SampleDevice',
+ sample_rate=sample_rate, rows=samples)
+ devices.register(devices.DeviceSpec(
+ 'SampleDevice', channels=channels, sample_rate=sample_rate))
with open(Path(write_dir, TRIGGER_FILENAME), 'w', encoding=DEFAULT_ENCODING) as trg_file:
trg_file.write(trg_data)
@@ -75,7 +77,8 @@ class TestBIDSConversion(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
- self.trg_data, self.data, self.params = create_bcipy_session_artifacts(self.temp_dir)
+ self.trg_data, self.data, self.params = create_bcipy_session_artifacts(
+ self.temp_dir)
def tearDown(self):
shutil.rmtree(self.temp_dir)
@@ -240,8 +243,10 @@ def setUp(self):
'down_sampling_rate': 3
}
self.channels = ['timestamp', 'O1', 'O2', 'Pz', 'TRG', 'lsl_timestamp']
- self.raw_data = RawData('SampleDevice', self.sample_rate, self.channels)
- devices.register(devices.DeviceSpec('SampleDevice', channels=self.channels, sample_rate=self.sample_rate))
+ self.raw_data = RawData(
+ 'SampleDevice', self.sample_rate, self.channels)
+ devices.register(devices.DeviceSpec(
+ 'SampleDevice', channels=self.channels, sample_rate=self.sample_rate))
# generate 100 random samples of data
for _ in range(0, 100):
@@ -425,7 +430,8 @@ def test_tobii_to_norm(self):
self.assertEqual(norm_data, excepted_norm_data)
tobii_data = (1, 1) # bottom right of screen in tobii coordinates
- excepted_norm_data = (1, -1) # bottom right of screen in norm coordinates
+ # bottom right of screen in norm coordinates
+ excepted_norm_data = (1, -1)
norm_data = tobii_to_norm(tobii_data)
self.assertEqual(norm_data, excepted_norm_data)
@@ -443,7 +449,8 @@ def test_tobii_to_norm_raises_error_with_invalid_units(self):
def test_norm_to_tobii(self):
"""Test the norm_to_tobii function"""
norm_data = (0, 0) # center of screen in norm coordinates
- excepted_tobii_data = (0.5, 0.5) # center of screen in tobii coordinates
+ # center of screen in tobii coordinates
+ excepted_tobii_data = (0.5, 0.5)
tobii_data = norm_to_tobii(norm_data)
self.assertEqual(tobii_data, excepted_tobii_data)
@@ -453,7 +460,8 @@ def test_norm_to_tobii(self):
self.assertEqual(tobii_data, excepted_tobii_data)
norm_data = (1, -1) # bottom right of screen in norm coordinates
- excepted_tobii_data = (1, 1) # bottom right of screen in tobii coordinates
+ # bottom right of screen in tobii coordinates
+ excepted_tobii_data = (1, 1)
tobii_data = norm_to_tobii(norm_data)
self.assertEqual(tobii_data, excepted_tobii_data)
@@ -472,7 +480,8 @@ class TestConvertETBIDS(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
- self.trg_data, self.data, self.params = create_bcipy_session_artifacts(self.temp_dir, channels=3)
+ self.trg_data, self.data, self.params = create_bcipy_session_artifacts(
+ self.temp_dir, channels=3)
self.eyetracking_data = sample_data(
ch_names=[
'timestamp',
@@ -482,7 +491,8 @@ def setUp(self):
daq_type='Gaze',
sample_rate=60,
rows=5000)
- devices.register(devices.DeviceSpec('Gaze', channels=['timestamp', 'x', 'y', 'pupil'], sample_rate=60))
+ devices.register(devices.DeviceSpec('Gaze', channels=[
+ 'timestamp', 'x', 'y', 'pupil'], sample_rate=60))
write(self.eyetracking_data, Path(self.temp_dir, 'eyetracker.csv'))
@@ -503,7 +513,8 @@ def test_convert_eyetracking_to_bids_generates_bids_strucutre(self):
# Assert the session directory was created with et
self.assertTrue(os.path.exists(f"{self.temp_dir}/et/"))
# Assert the et tsv file was created with the correct name
- self.assertTrue(os.path.exists(f"{self.temp_dir}/et/sub-01_ses-01_task-TestTask_run-01_eyetracking.tsv"))
+ self.assertTrue(os.path.exists(
+ f"{self.temp_dir}/et/sub-01_ses-01_task-TestTask_run-01_eyetracking.tsv"))
def test_convert_eyetracking_to_bids_reflects_participant_id(self):
"""Test the convert_eyetracking_to_bids function with a participant id"""
@@ -517,7 +528,8 @@ def test_convert_eyetracking_to_bids_reflects_participant_id(self):
)
self.assertTrue(os.path.exists(response))
# Assert the et tsv file was created with the correct name
- self.assertTrue(os.path.exists(f"{self.temp_dir}/et/sub-100_ses-01_task-TestTask_run-01_eyetracking.tsv"))
+ self.assertTrue(os.path.exists(
+ f"{self.temp_dir}/et/sub-100_ses-01_task-TestTask_run-01_eyetracking.tsv"))
def test_convert_eyetracking_to_bids_reflects_session_id(self):
"""Test the convert_eyetracking_to_bids function with a session id"""
@@ -531,7 +543,8 @@ def test_convert_eyetracking_to_bids_reflects_session_id(self):
)
self.assertTrue(os.path.exists(response))
# Assert the et tsv file was created with the correct name
- self.assertTrue(os.path.exists(f"{self.temp_dir}/et/sub-01_ses-100_task-TestTask_run-01_eyetracking.tsv"))
+ self.assertTrue(os.path.exists(
+ f"{self.temp_dir}/et/sub-01_ses-100_task-TestTask_run-01_eyetracking.tsv"))
def test_convert_eyetracking_to_bids_reflects_run_id(self):
"""Test the convert_eyetracking_to_bids function with a run id"""
@@ -545,7 +558,8 @@ def test_convert_eyetracking_to_bids_reflects_run_id(self):
)
self.assertTrue(os.path.exists(response))
# Assert the et tsv file was created with the correct name
- self.assertTrue(os.path.exists(f"{self.temp_dir}/et/sub-01_ses-01_task-TestTask_run-100_eyetracking.tsv"))
+ self.assertTrue(os.path.exists(
+ f"{self.temp_dir}/et/sub-01_ses-01_task-TestTask_run-100_eyetracking.tsv"))
def test_convert_eyetracking_to_bids_reflects_task_name(self):
"""Test the convert_eyetracking_to_bids function with a task name"""
@@ -559,7 +573,8 @@ def test_convert_eyetracking_to_bids_reflects_task_name(self):
)
self.assertTrue(os.path.exists(response))
# Assert the et tsv file was created with the correct name
- self.assertTrue(os.path.exists(f"{self.temp_dir}/et/sub-01_ses-01_task-TestTaskEtc_run-01_eyetracking.tsv"))
+ self.assertTrue(os.path.exists(
+ f"{self.temp_dir}/et/sub-01_ses-01_task-TestTaskEtc_run-01_eyetracking.tsv"))
def test_convert_et_raises_error_with_invalid_data_dir(self):
"""Test the convert_eyetracking_to_bids function raises an error with invalid output directory"""
@@ -633,7 +648,8 @@ def test_successful_conversion(self, mock_read_raw_bids, mock_find_matching_path
mock_bids_path1.task = 'RSVPCalibration'
mock_bids_path2 = MagicMock()
mock_bids_path2.task = 'RSVPCalibration'
- mock_find_matching_paths.return_value = [mock_bids_path1, mock_bids_path2]
+ mock_find_matching_paths.return_value = [
+ mock_bids_path1, mock_bids_path2]
# Create mock Raw objects that read_raw_bids will return
mock_raw1 = MagicMock(spec=mne.io.Raw)
@@ -646,7 +662,8 @@ def test_successful_conversion(self, mock_read_raw_bids, mock_find_matching_path
self.assertIs(result[0], mock_raw1)
self.assertIs(result[1], mock_raw2)
mock_path_exists.assert_called_once_with('/fake/bids/path')
- mock_get_entity_vals.assert_called_once_with('/fake/bids/path', 'session')
+ mock_get_entity_vals.assert_called_once_with(
+ '/fake/bids/path', 'session')
mock_find_matching_paths.assert_called_once()
self.assertEqual(mock_read_raw_bids.call_count, 2)
@@ -673,14 +690,16 @@ def test_no_matching_files(self, mock_find_matching_paths, mock_get_entity_vals,
BIDS_to_MNE('/fake/bids/path')
mock_path_exists.assert_called_once_with('/fake/bids/path')
- mock_get_entity_vals.assert_called_once_with('/fake/bids/path', 'session')
+ mock_get_entity_vals.assert_called_once_with(
+ '/fake/bids/path', 'session')
mock_find_matching_paths.assert_called_once()
@patch('bcipy.io.convert.os.path.exists')
@patch('bcipy.io.convert.get_entity_vals')
@patch('bcipy.io.convert.find_matching_paths')
@patch('bcipy.io.convert.read_raw_bids')
- @patch('bcipy.io.convert.logger') # Mock the logger to prevent actual logging during tests
+ # Mock the logger to prevent actual logging during tests
+ @patch('bcipy.io.convert.logger')
def test_task_filtering(self, mock_logger, mock_read_raw_bids, mock_find_matching_paths,
mock_get_entity_vals, mock_path_exists):
"""Test task name filtering."""
@@ -693,7 +712,8 @@ def test_task_filtering(self, mock_logger, mock_read_raw_bids, mock_find_matchin
mock_bids_path1.task = 'RSVPCalibration'
mock_bids_path2 = MagicMock()
mock_bids_path2.task = 'OtherTask'
- mock_find_matching_paths.return_value = [mock_bids_path1, mock_bids_path2]
+ mock_find_matching_paths.return_value = [
+ mock_bids_path1, mock_bids_path2]
# Create mock Raw object that read_raw_bids will return
mock_raw = MagicMock(spec=mne.io.Raw)
@@ -725,7 +745,8 @@ def test_multiple_sessions(self, mock_read_raw_bids, mock_find_matching_paths,
mock_bids_path2 = MagicMock()
mock_bids_path2.task = 'RSVPCalibration'
mock_bids_path2.session = '02'
- mock_find_matching_paths.return_value = [mock_bids_path1, mock_bids_path2]
+ mock_find_matching_paths.return_value = [
+ mock_bids_path1, mock_bids_path2]
# Create mock Raw objects
mock_raw1 = MagicMock(spec=mne.io.Raw)
@@ -735,7 +756,8 @@ def test_multiple_sessions(self, mock_read_raw_bids, mock_find_matching_paths,
result = BIDS_to_MNE('/fake/bids/path')
self.assertEqual(len(result), 2)
- mock_get_entity_vals.assert_called_once_with('/fake/bids/path', 'session')
+ mock_get_entity_vals.assert_called_once_with(
+ '/fake/bids/path', 'session')
self.assertEqual(mock_read_raw_bids.call_count, 2)
@patch('bcipy.io.convert.os.path.exists')
@@ -777,7 +799,8 @@ def test_debug_logging(self, mock_logger, mock_read_raw_bids, mock_find_matching
mock_bids_path1.task = 'RSVPCalibration'
mock_bids_path2 = MagicMock()
mock_bids_path2.task = 'OtherTask'
- mock_find_matching_paths.return_value = [mock_bids_path1, mock_bids_path2]
+ mock_find_matching_paths.return_value = [
+ mock_bids_path1, mock_bids_path2]
# Mock the debug logging method specifically
mock_logger.debug = MagicMock()
@@ -789,7 +812,8 @@ def test_debug_logging(self, mock_logger, mock_read_raw_bids, mock_find_matching
BIDS_to_MNE('/fake/bids/path', task_name='RSVPCalibration')
# Check if debug was called for skipping a file
- self.assertTrue(any('Skipping' in str(args) for args, _ in mock_logger.debug.call_args_list))
+ self.assertTrue(any('Skipping' in str(args)
+ for args, _ in mock_logger.debug.call_args_list))
if __name__ == '__main__':
diff --git a/bcipy/io/tests/test_load.py b/bcipy/io/tests/test_load.py
index d170efbc3..cdc4cd379 100644
--- a/bcipy/io/tests/test_load.py
+++ b/bcipy/io/tests/test_load.py
@@ -89,7 +89,8 @@ def tearDown(self):
def test_load_experiments_calls_open_with_expected_default(self):
with patch('builtins.open', mock_open(read_data='data')) as mock_file:
load_experiments()
- mock_file.assert_called_with(self.experiments_path, 'r', encoding=DEFAULT_ENCODING)
+ mock_file.assert_called_with(
+ self.experiments_path, 'r', encoding=DEFAULT_ENCODING)
def test_load_experiments_throws_file_not_found_exception_with_invalid_path(self):
with self.assertRaises(FileNotFoundError):
@@ -112,7 +113,8 @@ def tearDown(self):
def test_load_fields_calls_open_with_expected_default(self):
with patch('builtins.open', mock_open(read_data='data')) as mock_file:
load_fields()
- mock_file.assert_called_with(self.fields_path, 'r', encoding=DEFAULT_ENCODING)
+ mock_file.assert_called_with(
+ self.fields_path, 'r', encoding=DEFAULT_ENCODING)
def test_load_fields_throws_file_not_found_exception_with_invalid_path(self):
with self.assertRaises(FileNotFoundError):
@@ -375,7 +377,8 @@ def test_load_bcipy_data_with_experiment_id_filter(self):
total_expected_files = len(self.user_ids) * len(self.dates) * (len(self.experiment_ids) - 1) * \
len(self.datetimes) * len(self.tasks)
- response = load_bcipy_data(self.data_dir, experiment_id=desired_experiment_id)
+ response = load_bcipy_data(
+ self.data_dir, experiment_id=desired_experiment_id)
self.assertEqual(len(response), total_expected_files)
experiments = [file.experiment_id for file in response]
diff --git a/bcipy/io/tests/test_save.py b/bcipy/io/tests/test_save.py
index b272497ee..78b63fff5 100644
--- a/bcipy/io/tests/test_save.py
+++ b/bcipy/io/tests/test_save.py
@@ -99,11 +99,14 @@ def tearDown(self):
shutil.rmtree(self.save_directory)
def test_save_stimuli_position_info_writes_json(self):
- save_stimuli_position_info(self.stimuli_positions, self.save_directory, self.screen_info)
- self.assertTrue(os.path.isfile(os.path.join(self.save_directory, self.filename)))
+ save_stimuli_position_info(
+ self.stimuli_positions, self.save_directory, self.screen_info)
+ self.assertTrue(os.path.isfile(
+ os.path.join(self.save_directory, self.filename)))
def test_save_stimuli_position_info_writes_correct_json(self):
- save_stimuli_position_info(self.stimuli_positions, self.save_directory, self.screen_info)
+ save_stimuli_position_info(
+ self.stimuli_positions, self.save_directory, self.screen_info)
# load the json file
with open(os.path.join(self.save_directory, self.filename)) as f:
data = json.load(f)
diff --git a/bcipy/language/README.md b/bcipy/language/README.md
index c54ca8e35..49d5c4703 100644
--- a/bcipy/language/README.md
+++ b/bcipy/language/README.md
@@ -29,22 +29,22 @@ The language module has the following structure:
The UniformLanguageModel provides equal probabilities for all symbols in the symbol set. This model is useful for evaluating other aspects of the system, such as EEG signal quality, without any influence from a language model.
## NGram Model
+
The NGramLanguageModelAdapter utilizes a pretrained n-gram language model to generate probabilities for all symbols in the symbol set. N-gram models use frequencies of different character sequences to generate their predictions. Models trained on AAC-like data can be found [here](https://imagineville.org/software/lm/dec19_char/). For faster load times, it is recommended to use the binary models located at the bottom of the page. The default parameters file utilizes `lm_dec19_char_large_12gram.kenlm`. If you have issues accessing, please reach out to us on GitHub or via email at `cambi_support@googlegroups.com`.
For models that import the kenlm module, this must be manually installed using `pip install kenlm==0.1 --global-option="max_order=12"`.
## Causal Model
+
The CausalLanguageModelAdapter class can use any causal language model from Huggingface, though it has only been tested with gpt2, facebook/opt, and distilgpt2 families of models (including the domain-adapted figmtu/opt-350m-aac). Causal language models predict the next token in a sequence of tokens. For the many of these models, byte-pair encoding (BPE) is used for tokenization. The main idea of BPE is to create a fixed-size vocabulary that contains common English subword units. Then a less common word would be broken down into several subword units in the vocabulary. For example, the tokenization of character sequence `peanut_butter_and_jel` would be:
> *['pe', 'anut', '_butter', '_and', '_j', 'el']*
Therefore, in order to generate a predictive distribution on the next character, we need to examine all the possibilities that could complete the final subword tokens in the input sequences. We must remove at least one token from the end of the context to allow the model the option of extending it, as opposed to only adding a new token. Removing more tokens allows the model more flexibility and may lead to better predictions, but at the cost of a higher prediction time. In this model we remove all of the subword tokens in the current (partially-typed) word to allow it the most flexibility. We then ask the model to estimate the likelihood of the next token and evaluate each token that matches our context. For efficiency, we only track a certain number of hypotheses at a time, known as the beam width, and each hypothesis until it surpasses the context. We can then store the likelihood for each final prediction in a list based on the character that directly follows the context. Once we have no more hypotheses to extend, we can sum the likelihoods stored for each character in our symbol set and normalize so they sum to 1, giving us our final distribution. More details on this process can be found in our paper, [Adapting Large Language Models for Character-based Augmentative and Alternative Communication](https://arxiv.org/abs/2501.10582).
-
## Mixture Model
-The MixtureLanguageModelAdapter class allows for the combination of two or more supported models. The selected models are mixed according to the provided weights, which can be tuned using the Bcipy/scripts/python/mixture_tuning.py script. It is not recommended to use more than one "heavy-weight" model with long prediction times (the CausalLanguageModel) since this model will query each component model and parallelization is not currently supported.
-
-# Contact Information
-For language model related questions, please contact Dylan Gaines (dcgaines [[at](https://en.wikipedia.org/wiki/At_sign)] mtu.edu) or create an issue.
+The MixtureLanguageModelAdapter class allows for the combination of two or more supported models. The selected models are mixed according to the provided weights. It is not recommended to use more than one "heavy-weight" model with long prediction times (the CausalLanguageModel) since this model will query each component model and parallelization is not currently supported.
+## Contact Information
+For language model related questions, please contact Dylan Gaines (dcgaines [[at](https://en.wikipedia.org/wiki/At_sign)] mtu.edu) or create a GitHub issue.
diff --git a/bcipy/language/demo/demo_ngram.py b/bcipy/language/demo/demo_ngram.py
index 54f146460..e02867e3f 100644
--- a/bcipy/language/demo/demo_ngram.py
+++ b/bcipy/language/demo/demo_ngram.py
@@ -57,7 +57,8 @@
# file zebras.txt: 1 sentences, 14 words, 0 OOVs
# 0 zeroprobs, logprob= -15.2391 ppl= 10.374 ppl1= 12.260
sentence = "i l i k e z e b r a s ."
- print(f"Sentence '{sentence}', logprob = {model.score(sentence, bos=True, eos=True):.4f}\n")
+ print(
+ f"Sentence '{sentence}', logprob = {model.score(sentence, bos=True, eos=True):.4f}\n")
# Stateful query going one token at-a-time
# We'll flip flop between two state objects, one is the input and the other is the output
@@ -76,7 +77,8 @@
score = model.BaseScore(state, token, state2)
else:
score = model.BaseScore(state2, token, state)
- print(f"p( {token} | {prev} ...) = {pow(10, score):.6f} [ {score:.6f} ]")
+ print(
+ f"p( {token} | {prev} ...) = {pow(10, score):.6f} [ {score:.6f} ]")
accum += score
prev = token
print(f"sum logprob = {accum:.4f}")
diff --git a/bcipy/language/model/adapter.py b/bcipy/language/model/adapter.py
index 9fb6ff438..20814e6ba 100644
--- a/bcipy/language/model/adapter.py
+++ b/bcipy/language/model/adapter.py
@@ -25,7 +25,8 @@ def predict_character(self, evidence: Union[str, List[str]]) -> List[Tuple]:
"""
if self.symbol_set is None:
- raise InvalidSymbolSetException("symbol set must be set prior to requesting predictions.")
+ raise InvalidSymbolSetException(
+ "symbol set must be set prior to requesting predictions.")
assert self.model is not None, "language model does not exist!"
@@ -61,7 +62,8 @@ def set_symbol_set(self, symbol_set: List[str]) -> None:
self.symbol_set = symbol_set
# LM doesn't care about backspace, needs literal space
- self.model_symbol_set = [' ' if ch is SPACE_CHAR else ch for ch in self.symbol_set]
+ self.model_symbol_set = [
+ ' ' if ch is SPACE_CHAR else ch for ch in self.symbol_set]
self.model_symbol_set.remove(BACKSPACE_CHAR)
self._load_model()
diff --git a/bcipy/language/model/causal.py b/bcipy/language/model/causal.py
index ca493d24c..e76ca7b90 100644
--- a/bcipy/language/model/causal.py
+++ b/bcipy/language/model/causal.py
@@ -38,8 +38,10 @@ def __init__(self,
causal_params = self.parameters['causal']
- self.beam_width = beam_width or int(causal_params['beam_width']['value'])
- self.max_completed = max_completed or int(causal_params['max_completed']['value'])
+ self.beam_width = beam_width or int(
+ causal_params['beam_width']['value'])
+ self.max_completed = max_completed or int(
+ causal_params['max_completed']['value'])
# We optionally load the model from a local directory, but if this is not
# specified, we load a Hugging Face model
diff --git a/bcipy/language/model/mixture.py b/bcipy/language/model/mixture.py
index f337ca677..29748c2a7 100644
--- a/bcipy/language/model/mixture.py
+++ b/bcipy/language/model/mixture.py
@@ -7,9 +7,7 @@
class MixtureLanguageModelAdapter(LanguageModelAdapter):
- """
- Character language model that mixes any combination of other models
- """
+ """Character language model that mixes any combination of other models."""
supported_lm_types = MixtureLanguageModel.supported_lm_types
@@ -25,7 +23,8 @@ def __init__(self,
lm_params - list of dictionaries to pass as parameters for each model's instantiation
"""
- MixtureLanguageModel.validate_parameters(lm_types, lm_weights, lm_params)
+ MixtureLanguageModel.validate_parameters(
+ lm_types, lm_weights, lm_params)
self._load_parameters()
@@ -38,7 +37,8 @@ def __init__(self,
if type == "NGRAM":
params["lm_path"] = f"{LM_PATH}/{params['lm_path']}"
- MixtureLanguageModel.validate_parameters(self.lm_types, self.lm_weights, self.lm_params)
+ MixtureLanguageModel.validate_parameters(
+ self.lm_types, self.lm_weights, self.lm_params)
def _load_model(self) -> None:
"""Load the model itself using stored parameters"""
diff --git a/bcipy/language/model/oracle.py b/bcipy/language/model/oracle.py
index c1f1157ea..e581f460c 100644
--- a/bcipy/language/model/oracle.py
+++ b/bcipy/language/model/oracle.py
@@ -115,8 +115,7 @@ def predict_character(self, evidence: Union[str, List[str]]) -> List[Tuple]:
reverse=True)
def _next_target(self, spelled_text: str) -> Optional[str]:
- """Computes the next target letter based on the currently spelled_text.
- """
+ """Computes the next target letter based on the currently spelled_text."""
len_spelled = len(spelled_text)
len_task = len(self.task_text)
diff --git a/bcipy/language/tests/test_causal.py b/bcipy/language/tests/test_causal.py
index c91beacc8..c645ff6ae 100644
--- a/bcipy/language/tests/test_causal.py
+++ b/bcipy/language/tests/test_causal.py
@@ -20,7 +20,8 @@ def setUpClass(cls):
cls.gpt2_model = CausalLanguageModelAdapter(lang_model_name="gpt2")
cls.gpt2_model.set_symbol_set(DEFAULT_SYMBOL_SET)
- cls.opt_model = CausalLanguageModelAdapter(lang_model_name="facebook/opt-125m")
+ cls.opt_model = CausalLanguageModelAdapter(
+ lang_model_name="facebook/opt-125m")
cls.opt_model.set_symbol_set(DEFAULT_SYMBOL_SET)
@pytest.mark.slow
@@ -59,7 +60,8 @@ def test_invalid_model_name(self):
def test_invalid_model_path(self):
"""Test that the proper exception is thrown if given an invalid lm_path"""
with self.assertRaises(InvalidLanguageModelException):
- lm = CausalLanguageModelAdapter(lang_model_name="gpt2", lm_path="./phonypath/")
+ lm = CausalLanguageModelAdapter(
+ lang_model_name="gpt2", lm_path="./phonypath/")
lm.set_symbol_set(DEFAULT_SYMBOL_SET)
def test_non_mutable_evidence(self):
@@ -173,7 +175,8 @@ def test_opt_predict_middle_of_word(self):
def test_gpt2_phrase(self):
"""Test that a phrase can be used for input with gpt2 model"""
- symbol_probs = self.gpt2_model.predict_character(list("does_it_make_sen"))
+ symbol_probs = self.gpt2_model.predict_character(
+ list("does_it_make_sen"))
most_likely_sym, _prob = sorted(symbol_probs,
key=itemgetter(1),
reverse=True)[0]
@@ -181,7 +184,8 @@ def test_gpt2_phrase(self):
def test_opt_phrase(self):
"""Test that a phrase can be used for input with Facebook opt model"""
- symbol_probs = self.opt_model.predict_character(list("does_it_make_sen"))
+ symbol_probs = self.opt_model.predict_character(
+ list("does_it_make_sen"))
most_likely_sym, _prob = sorted(symbol_probs,
key=itemgetter(1),
reverse=True)[0]
@@ -205,14 +209,18 @@ def test_opt_multiple_spaces(self):
def test_gpt2_nonzero_prob(self):
"""Test that all letters in the alphabet have nonzero probability except for backspace"""
- symbol_probs = self.gpt2_model.predict_character(list("does_it_make_sens"))
- prob_values = [item[1] for item in symbol_probs if item[0] != BACKSPACE_CHAR]
+ symbol_probs = self.gpt2_model.predict_character(
+ list("does_it_make_sens"))
+ prob_values = [item[1]
+ for item in symbol_probs if item[0] != BACKSPACE_CHAR]
for value in prob_values:
self.assertTrue(value > 0)
def test_opt_nonzero_prob(self):
"""Test that all letters in the alphabet have nonzero probability except for backspace"""
- symbol_probs = self.opt_model.predict_character(list("does_it_make_sens"))
- prob_values = [item[1] for item in symbol_probs if item[0] != BACKSPACE_CHAR]
+ symbol_probs = self.opt_model.predict_character(
+ list("does_it_make_sens"))
+ prob_values = [item[1]
+ for item in symbol_probs if item[0] != BACKSPACE_CHAR]
for value in prob_values:
self.assertTrue(value > 0)
diff --git a/bcipy/language/tests/test_mixture.py b/bcipy/language/tests/test_mixture.py
index 107cfc2cc..4fe0ecc7c 100644
--- a/bcipy/language/tests/test_mixture.py
+++ b/bcipy/language/tests/test_mixture.py
@@ -21,7 +21,8 @@ def setUpClass(cls):
dirname = os.path.dirname(__file__) or '.'
cls.kenlm_path = "lm_dec19_char_tiny_12gram.kenlm"
print(cls.kenlm_path)
- cls.lm_params = [{"lm_path": cls.kenlm_path}, {"lang_model_name": "gpt2"}]
+ cls.lm_params = [{"lm_path": cls.kenlm_path},
+ {"lang_model_name": "gpt2"}]
cls.lmodel = MixtureLanguageModelAdapter(lm_types=["NGRAM", "CAUSAL"], lm_weights=[0.5, 0.5],
lm_params=cls.lm_params)
cls.lmodel.set_symbol_set(DEFAULT_SYMBOL_SET)
@@ -154,6 +155,7 @@ def test_multiple_spaces(self):
def test_nonzero_prob(self):
"""Test that all letters in the alphabet have nonzero probability except for backspace"""
symbol_probs = self.lmodel.predict_character(list("does_it_make_sens"))
- prob_values = [item[1] for item in symbol_probs if item[0] != BACKSPACE_CHAR]
+ prob_values = [item[1]
+ for item in symbol_probs if item[0] != BACKSPACE_CHAR]
for value in prob_values:
self.assertTrue(value > 0)
diff --git a/bcipy/main.py b/bcipy/main.py
index ada3a8f6b..310218495 100644
--- a/bcipy/main.py
+++ b/bcipy/main.py
@@ -1,3 +1,8 @@
+"""Main entry point for BciPy application.
+
+This module provides the main function to initialize and run a BCI task or experiment.
+"""
+
import argparse
import logging
import multiprocessing
@@ -22,32 +27,33 @@ def bci_main(
visualize: bool = True,
fake: bool = False,
task: Optional[Type[Task]] = None) -> bool:
- """BCI Main.
+ """Initialize and run a BCI task or experiment.
- The BCI main function will initialize a save folder, construct needed information
- and execute the task. This is the main connection between any UI and
+ The BCI main function initializes a save folder, constructs needed information
+ and executes the task. This is the main connection between any UI and
running the app.
- A Task or Experiment ID must be provided to run the task. If a task is provided, the experiment
- ID will be ignored.
-
- It may also be invoked via tha command line.
- Ex. `bcipy` this will default parameters, mode, user, and type.
+ Args:
+ parameter_location: Location of parameters file to use.
+ user: Name of the user.
+ experiment_id: Name of the experiment. If task is provided, this will be ignored.
+ alert: Whether to alert the user when the task is complete.
+ visualize: Whether to visualize data at the end of a task.
+ fake: Whether to use fake acquisition data during the session. If None, the
+ fake data will be determined by the parameters file.
+ task: Registered bcipy Task to execute. If None, the task will be determined by the
+ experiment protocol.
- You can pass it those attributes with flags, if desired.
- Ex. `bcipy --user "bci_user" --task "RSVP Calibration"
+ Returns:
+ bool: True if the task executed successfully, False otherwise.
+ Raises:
+ BciPyCoreException: If no experiment or task is provided.
- Input:
- parameter_location (str): location of parameters file to use
- user (str): name of the user
- experiment_id (str): Name of the experiment. If task is provided, this will be ignored.
- alert (bool): whether to alert the user when the task is complete
- visualize (bool): whether to visualize data at the end of a task
- fake (bool): whether to use fake acquisition data during the session. If None, the
- fake data will be determined by the parameters file.
- task (Task): registered bcipy Task to execute. If None, the task will be determined by the
- experiment protocol.
+ Examples:
+ Command line usage:
+ `bcipy` - uses default parameters, mode, user, and type
+ `bcipy --user "bci_user" --task "RSVP Calibration"`
"""
logger.info('Starting BciPy...')
logger.info(
@@ -104,11 +110,21 @@ def bci_main(
return True
-def bcipy_main() -> None: # pragma: no cover
- """BciPy Main.
+def bcipy_main() -> None:
+ """Command line interface for running BciPy experiments and tasks.
+
+ This function provides a command line interface for running registered experiment
+ tasks in BciPy. It handles argument parsing and delegates execution to bci_main.
+
+ Args:
+ None
+
+ Returns:
+ None
- Command line interface used for running a registered experiment task in BciPy. To see what
- is available use the --help flag.
+ Note:
+ Use the --help flag to see available options.
+ Windows machines require multiprocessing support which is initialized here.
"""
# Needed for windows machines
multiprocessing.freeze_support()
diff --git a/bcipy/parameters/devices.json b/bcipy/parameters/devices.json
index 66caab5fb..9da3281e0 100644
--- a/bcipy/parameters/devices.json
+++ b/bcipy/parameters/devices.json
@@ -57,6 +57,25 @@
"status": "active",
"static_offset": 0.1
},
+ {
+ "name": "DSI-7",
+ "content_type": "EEG",
+ "channels": [
+ { "name": "Pz", "label": "Pz", "units": "microvolts", "type": "EEG" },
+ { "name": "F4", "label": "F4", "units": "microvolts", "type": "EEG" },
+ { "name": "C4", "label": "C4", "units": "microvolts", "type": "EEG" },
+ { "name": "P4", "label": "P4", "units": "microvolts", "type": "EEG" },
+ { "name": "P3", "label": "P3", "units": "microvolts", "type": "EEG" },
+ { "name": "C3", "label": "C3", "units": "microvolts", "type": "EEG" },
+ { "name": "F3", "label": "F3", "units": "microvolts", "type": "EEG" },
+ { "name": "TRG", "label": "TRG", "units": "microvolts", "type": "EEG" }
+ ],
+ "sample_rate": 300,
+ "description": "Wearable Sensing DSI-7",
+ "excluded_from_analysis": ["TRG"],
+ "status": "active",
+ "static_offset": 0.1
+ },
{
"name": "DSI-Flex",
"content_type": "EEG",
diff --git a/bcipy/preferences.py b/bcipy/preferences.py
index 84b9ade5c..cb80c8840 100644
--- a/bcipy/preferences.py
+++ b/bcipy/preferences.py
@@ -1,83 +1,126 @@
-"""Module for recording and loading application state and user preferences."""
+"""Module for recording and loading application state and user preferences.
+
+This module provides functionality for storing and retrieving user preferences
+and application state between sessions using a JSON-based storage system.
+"""
import json
from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Type
from bcipy.config import BCIPY_ROOT, DEFAULT_ENCODING, PREFERENCES_PATH
class Pref:
- """Preference descriptor. When a class attribute is initialized as a Pref,
- values will be stored and retrieved from an 'entries' dict initialized in
- the instance.
+ """A descriptor class for managing preferences.
+
+ When a class attribute is initialized as a Pref, values will be stored and
+ retrieved from an 'entries' dict initialized in the instance.
+ For more information on descriptors, see:
https://docs.python.org/3/howto/descriptor.html
- Parameters
- ----------
- default - default value assigned to the attribute.
+ Args:
+ default: Default value assigned to the attribute if not found in entries.
"""
- def __init__(self, default: Optional[Any] = None):
+ def __init__(self, default: Optional[Any] = None) -> None:
self.default = default
- self.name = None
+ self.name: Optional[str] = None
+
+ def __set_name__(self, owner: Type[Any], name: str) -> None:
+ """Assign the Pref descriptor to a class attribute.
- def __set_name__(self, owner, name):
- """Called when the class assigns a Pref to a class attribute."""
+ Args:
+ owner: The class that owns this descriptor.
+ name: The name of the attribute this descriptor is assigned to.
+ """
self.name = name
- def __get__(self, instance, owner=None):
- """Retrieve the value from the dict of entries."""
+ def __get__(self, instance: Any, owner: Optional[Type[Any]] = None) -> Any:
+ """Return the value from the dict of entries.
+
+ Args:
+ instance: The instance that this descriptor is accessed from.
+ owner: The class that owns this descriptor.
+
+ Returns:
+ The value stored in entries or the default value.
+ """
+ if instance is None:
+ return self
return instance.entries.get(self.name, self.default)
- def __set__(self, instance, value):
- """Stores the given value in the entries dict keyed on the attribute
- name."""
+ def __set__(self, instance: Any, value: Any) -> None:
+ """Store the given value in the entries dict keyed on the attribute name.
+
+ Args:
+ instance: The instance that this descriptor is accessed from.
+ value: The value to store in the entries dict.
+ """
instance.entries[self.name] = value
instance.save()
class Preferences:
- """User preferences persisted to disk to retain application state between
- work sessions.
+ """User preferences persisted to disk to retain application state between work sessions.
- Parameters
- ----------
- filename - optional file used for persisting entries.
+ This class manages user preferences by storing them in a JSON file on disk. It provides
+ methods for loading, saving, and accessing preference values.
+
+ Attributes:
+ signal_model_directory: Directory containing signal models.
+ last_directory: Last accessed directory, defaults to BCIPY_ROOT.
+
+ Args:
+ filename: Optional file used for persisting entries.
"""
signal_model_directory = Pref()
last_directory = Pref(default=str(BCIPY_ROOT))
def __init__(self, filename: str = PREFERENCES_PATH) -> None:
self.filename = filename
- self.entries: Dict[Any, Any] = {}
+ self.entries: Dict[str, Any] = {}
self.load()
- def load(self):
- """Load preference data from the persisted file."""
+ def load(self) -> None:
+ """Load preference data from the persisted file.
+
+ Reads the JSON file specified by self.filename and populates the entries
+ dictionary with the stored preferences.
+ """
if Path(self.filename).is_file():
with open(self.filename, 'r',
encoding=DEFAULT_ENCODING) as json_file:
for key, val in json.load(json_file).items():
self.entries[key] = val
- def save(self):
- """Write preferences to disk."""
+ def save(self) -> None:
+ """Write preferences to disk.
+
+ Saves the current entries dictionary to the JSON file specified by
+ self.filename.
+ """
with open(self.filename, 'w', encoding=DEFAULT_ENCODING) as json_file:
json.dump(self.entries, json_file, ensure_ascii=False, indent=2)
- def get(self, name: str):
- """Get preference by name"""
+ def get(self, name: str) -> Optional[Any]:
+ """Get preference by name.
+
+ Args:
+ name: Name of the preference to retrieve.
+
+ Returns:
+ The preference value if found, None otherwise.
+ """
return self.entries.get(name, None)
- def set(self, name: str, value: Any, persist: bool = True):
+ def set(self, name: str, value: Any, persist: bool = True) -> None:
"""Set a preference and save the result.
- Parameters
- ----------
- name - name of the preference
- value - value associated with the given name
- persist - flag indicating whether to immediately save the result.
+ Args:
+ name: Name of the preference.
+ value: Value associated with the given name.
+ persist: Flag indicating whether to immediately save the result.
Default is True.
"""
self.entries[name] = value
diff --git a/bcipy/signal/evaluate/artifact.py b/bcipy/signal/evaluate/artifact.py
index c88d56a41..1d8f69a21 100644
--- a/bcipy/signal/evaluate/artifact.py
+++ b/bcipy/signal/evaluate/artifact.py
@@ -7,6 +7,7 @@
from typing import List, Optional, Tuple, Union
import mne
+from mne import Annotations
import bcipy.acquisition.devices as devices
from bcipy.acquisition.devices import DeviceSpec
@@ -25,8 +26,6 @@
mne.set_log_level('WARNING')
log = getLogger(SESSION_LOG_FILENAME)
-from mne import Annotations
-
class DefaultArtifactParameters(Enum):
"""Default Artifact Parameters.
@@ -182,7 +181,8 @@ def __init__(
self.session_triggers = mne.Annotations(
self.trigger_time, [self.trial_duration] * len(self.trigger_time), self.trigger_description)
- assert len(device_spec.channel_specs) > 0, 'DeviceSpec used must have channels. None found.'
+ assert len(
+ device_spec.channel_specs) > 0, 'DeviceSpec used must have channels. None found.'
self.units = device_spec.channel_specs[0].units
log.info(f'Artifact detection using {self.units} units.')
assert self.units in self.supported_units, \
@@ -195,7 +195,8 @@ def __init__(
self.detect_eog = detect_eog
self.semi_automatic = semi_automatic
- log.info(f'Artifact detection with {self.detect_voltage=}, {self.detect_eog=}, {self.semi_automatic=}')
+ log.info(
+ f'Artifact detection with {self.detect_voltage=}, {self.detect_eog=}, {self.semi_automatic=}')
self.save_path = save_path
@@ -252,7 +253,8 @@ def label_artifacts(
log.info(f'Bad channels detected: {bad_channels}')
if voltage_annotations:
- log.info(f'Voltage violation events found: {len(voltage_annotations)}')
+ log.info(
+ f'Voltage violation events found: {len(voltage_annotations)}')
annotations += voltage_annotations
self.voltage_annotations = voltage_annotations
@@ -290,7 +292,8 @@ def save_artifacts(self, overwrite: bool = False) -> None:
f'{self.save_path}/{DefaultArtifactParameters.ARTIFACT_LABELLED_FILENAME.value}',
overwrite=overwrite)
else:
- log.info('Artifact cannot be saved, artifact analysis has been done yet.')
+ log.info(
+ 'Artifact cannot be saved, artifact analysis has been done yet.')
def raw_data_to_mne(self, raw_data: RawData, volts: bool = False) -> mne.io.RawArray:
"""Convert the raw data to an MNE RawArray."""
@@ -345,8 +348,10 @@ def label_eog_events(
log.info('No eye channels provided. Cannot detect EOG artifacts.')
return None
- log.info(f'Using blink threshold of {threshold} for channels {self.eye_channels}.')
- eog_events = mne.preprocessing.find_eog_events(self.mne_data, ch_name=self.eye_channels, thresh=threshold)
+ log.info(
+ f'Using blink threshold of {threshold} for channels {self.eye_channels}.')
+ eog_events = mne.preprocessing.find_eog_events(
+ self.mne_data, ch_name=self.eye_channels, thresh=threshold)
# eog_events = mne.preprocessing.ica_find_eog_events(raw) TODO compare to ICA
if len(eog_events) > 0:
@@ -354,7 +359,8 @@ def label_eog_events(
onsets = eog_events[:, 0] / self.mne_data.info['sfreq'] - preblink
durations = [postblink + preblink] * len(eog_events)
descriptions = [label] * len(eog_events)
- blink_annotations = mne.Annotations(onsets, durations, descriptions)
+ blink_annotations = mne.Annotations(
+ onsets, durations, descriptions)
return blink_annotations, eog_events
return None
@@ -416,7 +422,8 @@ def label_voltage_events(
bad_percent=self.percent_bad,
peak=peak[0])
if len(peak_voltage_annotations) > 0:
- log.info(f'Peak voltage events found: {len(peak_voltage_annotations)}')
+ log.info(
+ f'Peak voltage events found: {len(peak_voltage_annotations)}')
onsets, durations, descriptions = self.concat_annotations(
peak_voltage_annotations,
pre_event,
@@ -430,7 +437,8 @@ def label_voltage_events(
flat_voltage_annotations, bad_channels2 = mne.preprocessing.annotate_amplitude(
self.mne_data, min_duration=flat[1], bad_percent=self.percent_bad, flat=flat[0])
if len(flat_voltage_annotations) > 0:
- log.info(f'Flat voltage events found: {len(flat_voltage_annotations)}')
+ log.info(
+ f'Flat voltage events found: {len(flat_voltage_annotations)}')
onsets, durations, descriptions = self.concat_annotations(
flat_voltage_annotations,
pre_event,
@@ -567,14 +575,16 @@ def write_mne_annotations(
# loop through the sessions, pausing after each one to allow for manual stopping
if session.is_dir():
print(f'Processing {session}')
- prompt = input('Hit enter to continue or type "skip" to skip processing: ')
+ prompt = input(
+ 'Hit enter to continue or type "skip" to skip processing: ')
if prompt != 'skip':
# load the parameters from the data directory
parameters = load_json_parameters(
f'{session}/{DEFAULT_PARAMETERS_FILENAME}', value_cast=True)
# load the raw data from the data directory
- raw_data = load_raw_data(str(Path(session, f'{RAW_DATA_FILENAME}.csv')))
+ raw_data = load_raw_data(
+ str(Path(session, f'{RAW_DATA_FILENAME}.csv')))
type_amp = raw_data.daq_type
# load the triggers
@@ -582,7 +592,8 @@ def write_mne_annotations(
trigger_type, trigger_timing, trigger_label = trigger_decoder(
offset=0.1,
trigger_path=f"{session}/{TRIGGER_FILENAME}",
- exclusion=[TriggerType.PREVIEW, TriggerType.EVENT, TriggerType.FIXATION],
+ exclusion=[TriggerType.PREVIEW,
+ TriggerType.EVENT, TriggerType.FIXATION],
)
triggers = (trigger_type, trigger_timing, trigger_label)
else:
diff --git a/bcipy/signal/evaluate/fusion.py b/bcipy/signal/evaluate/fusion.py
index ab06c01b7..ffeb71688 100644
--- a/bcipy/signal/evaluate/fusion.py
+++ b/bcipy/signal/evaluate/fusion.py
@@ -52,10 +52,12 @@ def calculate_eeg_gaze_fusion_acc(
gaze_acc: accuracy of the gaze model only
fusion_acc: accuracy of the fusion
"""
- logger.info(f"Calculating EEG [{eeg_model.name}] and Gaze [{gaze_model.name}] model fusion accuracy.")
+ logger.info(
+ f"Calculating EEG [{eeg_model.name}] and Gaze [{gaze_model.name}] model fusion accuracy.")
# Extract relevant session information from parameters file
trial_window = parameters.get("trial_window", (0.0, 0.5))
- window_length = trial_window[1] - trial_window[0] # eeg window length, in seconds
+ # eeg window length, in seconds
+ window_length = trial_window[1] - trial_window[0]
prestim_length = parameters.get("prestim_length")
trials_per_inquiry = parameters.get("stim_length")
@@ -64,7 +66,8 @@ def calculate_eeg_gaze_fusion_acc(
buffer = int(parameters.get("task_buffer_length") / 2)
# Get signal filtering information
- transform_params: ERPTransformParams = parameters.instantiate(ERPTransformParams)
+ transform_params: ERPTransformParams = parameters.instantiate(
+ ERPTransformParams)
downsample_rate = transform_params.down_sampling_rate
static_offset = device_spec_eeg.static_offset
@@ -119,10 +122,12 @@ def calculate_eeg_gaze_fusion_acc(
target_symbols = [trigger_symbols[idx]
for idx, targetness in enumerate(trigger_targetness_gaze) if targetness == 'prompt']
total_len = trials_per_inquiry + 1 # inquiry length + the prompt symbol
- inq_start = trigger_timing_gaze[1::total_len] # inquiry start times, exluding prompt and fixation
+ # inquiry start times, excluding prompt and fixation
+ inq_start = trigger_timing_gaze[1::total_len]
# update the trigger timing list to account for the initial trial window
- corrected_trigger_timing = [timing + trial_window[0] for timing in trigger_timing]
+ corrected_trigger_timing = [timing + trial_window[0]
+ for timing in trigger_timing]
erp_data, _fs_eeg = eeg_data.by_channel()
trajectory_data, _fs_eye = gaze_data.by_channel()
@@ -153,18 +158,24 @@ def calculate_eeg_gaze_fusion_acc(
)
# More EEG preprocessing:
- eeg_inquiries, fs = filter_inquiries(eeg_inquiries, default_transform, eeg_sample_rate)
- eeg_inquiry_timing = update_inquiry_timing(eeg_inquiry_timing, downsample_rate)
+ eeg_inquiries, fs = filter_inquiries(
+ eeg_inquiries, default_transform, eeg_sample_rate)
+ eeg_inquiry_timing = update_inquiry_timing(
+ eeg_inquiry_timing, downsample_rate)
trial_duration_samples = int(window_length * fs)
# More gaze preprocessing:
- inquiry_length = gaze_inquiries_list[0].shape[1] # number of time samples in each inquiry
+ # number of time samples in each inquiry
+ inquiry_length = gaze_inquiries_list[0].shape[1]
predefined_dimensions = 4 # left_x, left_y, right_x, right_y
- preprocessed_gaze_data = np.zeros((len(gaze_inquiries_list), predefined_dimensions, inquiry_length))
+ preprocessed_gaze_data = np.zeros(
+ (len(gaze_inquiries_list), predefined_dimensions, inquiry_length))
# Extract left_x, left_y, right_x, right_y for each inquiry
for j in range(len(gaze_inquiries_list)):
- left_eye, right_eye, _, _, _, _ = extract_eye_info(gaze_inquiries_list[j])
- preprocessed_gaze_data[j] = np.concatenate((left_eye.T, right_eye.T,), axis=0)
+ left_eye, right_eye, _, _, _, _ = extract_eye_info(
+ gaze_inquiries_list[j])
+ preprocessed_gaze_data[j] = np.concatenate(
+ (left_eye.T, right_eye.T,), axis=0)
preprocessed_gaze_dict = {i: [] for i in symbol_set}
for i in symbol_set:
@@ -172,8 +183,10 @@ def calculate_eeg_gaze_fusion_acc(
if len(gaze_inquiries_dict[i]) == 0:
continue
for j in range(len(gaze_inquiries_dict[i])):
- left_eye, right_eye, _, _, _, _ = extract_eye_info(gaze_inquiries_dict[i][j])
- preprocessed_gaze_dict[i].append((np.concatenate((left_eye.T, right_eye.T), axis=0)))
+ left_eye, right_eye, _, _, _, _ = extract_eye_info(
+ gaze_inquiries_dict[i][j])
+ preprocessed_gaze_dict[i].append(
+ (np.concatenate((left_eye.T, right_eye.T), axis=0)))
preprocessed_gaze_dict[i] = np.array(preprocessed_gaze_dict[i])
# Find the time averages for each symbol:
@@ -195,12 +208,14 @@ def calculate_eeg_gaze_fusion_acc(
preprocessed_gaze_dict[sym][j],
temp)) # Delta_t = X_t - mu
centralized_data_dict[sym] = np.array(centralized_data_dict[sym])
- time_average_per_symbol[sym] = np.mean(np.array(time_average_per_symbol[sym]), axis=0)
+ time_average_per_symbol[sym] = np.mean(
+ np.array(time_average_per_symbol[sym]), axis=0)
# Take the time average of the gaze data:
centralized_gaze_data = np.zeros_like(preprocessed_gaze_data)
for i, (_, sym) in enumerate(zip(preprocessed_gaze_data, target_symbols)):
- centralized_gaze_data[i] = gaze_model.subtract_mean(preprocessed_gaze_data[i], time_average_per_symbol[sym])
+ centralized_gaze_data[i] = gaze_model.subtract_mean(
+ preprocessed_gaze_data[i], time_average_per_symbol[sym])
"""
Calculate the accuracy of the fusion of EEG and Gaze models. Use the number of iterations to change bootstraping.
@@ -217,10 +232,13 @@ def calculate_eeg_gaze_fusion_acc(
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [est. {remaining}][ela. {elapsed}]\n",
colour='MAGENTA')
for _progress in progress_bar:
- progress_bar.set_description(f"Running iteration {_progress + 1}/{n_iterations}")
+ progress_bar.set_description(
+ f"Running iteration {_progress + 1}/{n_iterations}")
# Pick a train and test dataset (that consists of non-train elements) until test dataset is not empty:
- train_indices = resample(list(range(selection_length)), replace=True, n_samples=100)
- test_indices = np.array([x for x in list(range(selection_length)) if x not in train_indices])
+ train_indices = resample(
+ list(range(selection_length)), replace=True, n_samples=100)
+ test_indices = np.array(
+ [x for x in list(range(selection_length)) if x not in train_indices])
if len(test_indices) == 0:
break
@@ -254,7 +272,8 @@ def calculate_eeg_gaze_fusion_acc(
# extract train and test indices for gaze data:
centralized_gaze_data_train = centralized_gaze_data[train_indices]
# gaze_train_labels = np.array([target_symbols[i] for i in train_indices])
- gaze_data_test = preprocessed_gaze_data[test_indices] # test set is NOT centralized
+ # test set is NOT centralized
+ gaze_data_test = preprocessed_gaze_data[test_indices]
gaze_test_labels = np.array([target_symbols[i] for i in test_indices])
# generate a tuple that matches the index of the symbol with the symbol itself:
symbol_to_index = {symbol: i for i, symbol in enumerate(symbol_set)}
@@ -283,14 +302,17 @@ def calculate_eeg_gaze_fusion_acc(
except BaseException:
# Singular matrix, using pseudo-inverse instead
eps = 10e-3 # add a small value to the diagonal to make the matrix invertible
- inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(len(cov_matrix)) * eps)
+ inv_cov_matrix = np.linalg.inv(
+ cov_matrix + np.eye(len(cov_matrix)) * eps)
# inv_cov_matrix = np.linalg.pinv(cov_matrix + np.eye(len(cov_matrix))*eps)
denominator_gaze = 0
# Given the test data, compute the log likelihood ratios for each symbol,
# from eeg and gaze models:
- eeg_log_likelihoods = np.zeros((len(gaze_data_test), (len(symbol_set))))
- gaze_log_likelihoods = np.zeros((len(gaze_data_test), (len(symbol_set))))
+ eeg_log_likelihoods = np.zeros(
+ (len(gaze_data_test), (len(symbol_set))))
+ gaze_log_likelihoods = np.zeros(
+ (len(gaze_data_test), (len(symbol_set))))
# Save the max posterior and the second max posterior for each test point:
target_posteriors_gaze = np.zeros((len(gaze_data_test), 2))
@@ -306,23 +328,29 @@ def calculate_eeg_gaze_fusion_acc(
for idx, sym in enumerate(symbol_set):
# skip if there is no training example from the symbol
if time_average_per_symbol[sym] == []:
- gaze_log_likelihoods[test_idx, idx] = -100000 # set a very small value
+ gaze_log_likelihoods[test_idx, idx] = - \
+ 100000 # set a very small value
else:
- central_data = gaze_model.subtract_mean(test_data, time_average_per_symbol[sym])
- flattened_data = central_data.reshape((inquiry_length * predefined_dimensions,))
+ central_data = gaze_model.subtract_mean(
+ test_data, time_average_per_symbol[sym])
+ flattened_data = central_data.reshape(
+ (inquiry_length * predefined_dimensions,))
flattened_data *= units
diff = flattened_data - reshaped_mean
diff_list.append(diff)
- numerator = -np.dot(diff.T, np.dot(inv_cov_matrix, diff)) / 2
+ numerator = - \
+ np.dot(diff.T, np.dot(inv_cov_matrix, diff)) / 2
numerator_gaze_list.append(numerator)
unnormalized_log_likelihood_gaze = numerator - denominator_gaze
- gaze_log_likelihoods[test_idx, idx] = unnormalized_log_likelihood_gaze
+ gaze_log_likelihoods[test_idx,
+ idx] = unnormalized_log_likelihood_gaze
normalized_posterior_gaze_only = np.exp(
gaze_log_likelihoods[test_idx, :]) / np.sum(np.exp(gaze_log_likelihoods[test_idx, :]))
# Find the max likelihood:
max_like_gaze = np.argmax(normalized_posterior_gaze_only)
- posterior_of_true_label_gaze = normalized_posterior_gaze_only[symbol_to_index[gaze_test_labels[test_idx]]]
+ posterior_of_true_label_gaze = normalized_posterior_gaze_only[
+ symbol_to_index[gaze_test_labels[test_idx]]]
top_competitor_gaze = np.sort(normalized_posterior_gaze_only)[-2]
target_posteriors_gaze[test_idx, 0] = posterior_of_true_label_gaze
target_posteriors_gaze[test_idx, 1] = top_competitor_gaze
@@ -335,7 +363,8 @@ def calculate_eeg_gaze_fusion_acc(
end = (test_idx + 1) * trials_per_inquiry
eeg_tst_data = preprocessed_test_eeg[:, start:end, :]
inq_sym = inquiry_symbols_test[start: end]
- eeg_likelihood_ratios = eeg_model.compute_likelihood_ratio(eeg_tst_data, inq_sym, symbol_set)
+ eeg_likelihood_ratios = eeg_model.compute_likelihood_ratio(
+ eeg_tst_data, inq_sym, symbol_set)
unnormalized_log_likelihood_eeg = np.log(eeg_likelihood_ratios)
eeg_log_likelihoods[test_idx, :] = unnormalized_log_likelihood_eeg
normalized_posterior_eeg_only = np.exp(
@@ -343,7 +372,8 @@ def calculate_eeg_gaze_fusion_acc(
max_like_eeg = np.argmax(normalized_posterior_eeg_only)
top_competitor_eeg = np.sort(normalized_posterior_eeg_only)[-2]
- posterior_of_true_label_eeg = normalized_posterior_eeg_only[symbol_to_index[gaze_test_labels[test_idx]]]
+ posterior_of_true_label_eeg = normalized_posterior_eeg_only[
+ symbol_to_index[gaze_test_labels[test_idx]]]
target_posteriors_eeg[test_idx, 0] = posterior_of_true_label_eeg
target_posteriors_eeg[test_idx, 1] = top_competitor_eeg
@@ -351,7 +381,8 @@ def calculate_eeg_gaze_fusion_acc(
counter_eeg += 1
# Bayesian fusion update and decision making:
- log_unnormalized_posterior = np.log(eeg_likelihood_ratios) + gaze_log_likelihoods[test_idx, :]
+ log_unnormalized_posterior = np.log(
+ eeg_likelihood_ratios) + gaze_log_likelihoods[test_idx, :]
unnormalized_posterior = np.exp(log_unnormalized_posterior)
denominator = np.sum(unnormalized_posterior)
posterior = unnormalized_posterior / denominator # normalized posterior
@@ -360,7 +391,8 @@ def calculate_eeg_gaze_fusion_acc(
top_competitor_fusion = np.sort(log_posterior)[-2]
posterior_of_true_label_fusion = posterior[symbol_to_index[gaze_test_labels[test_idx]]]
- target_posteriors_fusion[test_idx, 0] = posterior_of_true_label_fusion
+ target_posteriors_fusion[test_idx,
+ 0] = posterior_of_true_label_fusion
target_posteriors_fusion[test_idx, 1] = top_competitor_fusion
if symbol_set[max_posterior] == gaze_test_labels[test_idx]:
counter_fusion += 1
@@ -369,9 +401,12 @@ def calculate_eeg_gaze_fusion_acc(
if posterior.any() == np.nan:
break
- eeg_acc_in_iteration = float("{:.3f}".format(counter_eeg / len(test_indices)))
- gaze_acc_in_iteration = float("{:.3f}".format(counter_gaze / len(test_indices)))
- fusion_acc_in_iteration = float("{:.3f}".format(counter_fusion / len(test_indices)))
+ eeg_acc_in_iteration = float(
+ "{:.3f}".format(counter_eeg / len(test_indices)))
+ gaze_acc_in_iteration = float(
+ "{:.3f}".format(counter_gaze / len(test_indices)))
+ fusion_acc_in_iteration = float(
+ "{:.3f}".format(counter_fusion / len(test_indices)))
eeg_acc.append(eeg_acc_in_iteration)
gaze_acc.append(gaze_acc_in_iteration)
fusion_acc.append(fusion_acc_in_iteration)
diff --git a/bcipy/signal/model/classifier.py b/bcipy/signal/model/classifier.py
index 61b629b10..c9df576d8 100644
--- a/bcipy/signal/model/classifier.py
+++ b/bcipy/signal/model/classifier.py
@@ -71,16 +71,20 @@ def fit(self, x, y, p=[]):
# in order to make the ndarray readable from MATLAB side. There are
# two arrays, [0] for the correctness, choose it
# Class means
- self.mean_i = [np.mean(x[np.where(y == i)[0]], axis=0) for i in self.class_i]
+ self.mean_i = [np.mean(x[np.where(y == i)[0]], axis=0)
+ for i in self.class_i]
# Normalized x
- norm_vec = [x[np.where(y == self.class_i[i])[0]] - self.mean_i[i] for i in range(len(self.class_i))]
+ norm_vec = [x[np.where(y == self.class_i[i])[0]] - self.mean_i[i]
+ for i in range(len(self.class_i))]
# Outer product of data matrix, Xi'Xi for each class
- self.S_i = [np.dot(np.transpose(norm_vec[i]), norm_vec[i]) for i in range(len(self.class_i))]
+ self.S_i = [np.dot(np.transpose(norm_vec[i]), norm_vec[i])
+ for i in range(len(self.class_i))]
# Sample covariances are calculated Si/Ni for each class
- self.cov_i = [self.S_i[i] / self.N_i[i] for i in range(len(self.class_i))]
+ self.cov_i = [self.S_i[i] / self.N_i[i]
+ for i in range(len(self.class_i))]
# Sample covariance of total data
self.S = np.zeros((self.k, self.k))
@@ -90,7 +94,8 @@ def fit(self, x, y, p=[]):
# Set priors
if len(p) == 0:
- prior = np.asarray([np.sum(y == self.class_i[i]) for i in range(len(self.class_i))], dtype=float)
+ prior = np.asarray([np.sum(y == self.class_i[i])
+ for i in range(len(self.class_i))], dtype=float)
self.prior_i = np.divide(prior, np.sum(prior))
else:
self.prior_i = p
@@ -110,13 +115,15 @@ def regularize(self, param): # TODO: what if no param passed?
# Shrinked class covariances
shr_cov_i = [
- ((1 - self.lam) * self.S_i[i] + self.lam * self.S) / ((1 - self.lam) * self.N_i[i] + self.lam * self.N)
+ ((1 - self.lam) * self.S_i[i] + self.lam * self.S) /
+ ((1 - self.lam) * self.N_i[i] + self.lam * self.N)
for i in range(len(self.class_i))
]
# Regularized class covariances
reg_cov_i = [
- ((1 - self.gam) * shr_cov_i[i] + self.gam / self.k * np.trace(shr_cov_i[i]) * np.eye(self.k))
+ ((1 - self.gam) * shr_cov_i[i] + self.gam /
+ self.k * np.trace(shr_cov_i[i]) * np.eye(self.k))
for i in range(len(self.class_i))
]
@@ -156,7 +163,8 @@ def get_prob(self, x):
# Every constant at the end of score calculation is omitted.
# This is why we omit log det of class regularized covariances.
- evidence = np.dot(zero_mean, np.dot(self.inv_reg_cov_i[i], zero_mean))
+ evidence = np.dot(zero_mean, np.dot(
+ self.inv_reg_cov_i[i], zero_mean))
neg_log_l[s][i] = -0.5 * evidence + np.log(self.prior_i[i])
diff --git a/bcipy/signal/model/cross_validation.py b/bcipy/signal/model/cross_validation.py
index 65587a40d..d29145113 100644
--- a/bcipy/signal/model/cross_validation.py
+++ b/bcipy/signal/model/cross_validation.py
@@ -11,7 +11,7 @@
def cost_cross_validation_auc(model, opt_el, x, y, param, k_folds=10,
split='uniform'):
- """ Minimize cost of the overall -AUC.
+ """Minimize cost of the overall -AUC.
Cost function: given a particular architecture (model). Fits the
parameters to the folds with leave one fold out procedure. Calculates
scores for the validation fold. Concatenates all calculated scores
@@ -31,7 +31,8 @@ def cost_cross_validation_auc(model, opt_el, x, y, param, k_folds=10,
-auc(float): negative AUC value for current setup
sc_h(ndarray[float]): scores computed for each validation fold
y_valid_h(ndarray[int]): labels of the scores for each validation fold
- y_valid_h[i] is basically the label for sc_h[i] """
+ y_valid_h[i] is basically the label for sc_h[i]
+ """
num_samples = x.shape[1]
fold_len = np.floor(float(num_samples) / k_folds)
diff --git a/bcipy/signal/model/density_estimation.py b/bcipy/signal/model/density_estimation.py
index 7e88442b9..28aa119be 100644
--- a/bcipy/signal/model/density_estimation.py
+++ b/bcipy/signal/model/density_estimation.py
@@ -17,11 +17,13 @@ class KernelDensityEstimate:
"""
def __init__(self, scores: Optional[np.array] = None, kernel="gaussian", num_cls=2):
- bandwidth = 1.0 if scores is None else self._compute_bandwidth(scores, scores.shape[0])
+ bandwidth = 1.0 if scores is None else self._compute_bandwidth(
+ scores, scores.shape[0])
self.logger = logging.getLogger(SESSION_LOG_FILENAME)
self.logger.info(f"KDE. bandwidth={bandwidth}, kernel={kernel}")
self.num_cls = num_cls
- self.list_den_est = [KernelDensity(bandwidth=bandwidth, kernel=kernel) for _ in range(self.num_cls)]
+ self.list_den_est = [KernelDensity(
+ bandwidth=bandwidth, kernel=kernel) for _ in range(self.num_cls)]
def _compute_bandwidth(self, scores: np.array, num_items: int):
"""Estimate bandwidth parameter using Silverman's rule of thumb.
@@ -34,7 +36,8 @@ def _compute_bandwidth(self, scores: np.array, num_items: int):
Returns:
float: rule-of-thumb bandwidth parameter for KDE
"""
- bandwidth = 0.9 * min(np.std(scores), iqr(scores) / 1.34) * np.power(num_items, -0.2)
+ bandwidth = 0.9 * min(np.std(scores), iqr(scores) /
+ 1.34) * np.power(num_items, -0.2)
return bandwidth
def fit(self, x, y):
@@ -61,7 +64,8 @@ def transform(self, x):
Where N and c denotes number of samples and classes
Returns:
val(ndarray[float]): N x c log-likelihood array
- respectively."""
+ respectively.
+ """
# Calculate likelihoods for each density estimate
val = []
diff --git a/bcipy/signal/model/dimensionality_reduction.py b/bcipy/signal/model/dimensionality_reduction.py
index c2bff96e3..8ec52b868 100644
--- a/bcipy/signal/model/dimensionality_reduction.py
+++ b/bcipy/signal/model/dimensionality_reduction.py
@@ -28,9 +28,11 @@ class ChannelWisePrincipalComponentAnalysis:
def __init__(self, n_components: Optional[float] = None, random_state: Optional[int] = None, num_ch: int = 1):
self.num_ch = num_ch
- self.list_pca = [PCA(n_components=n_components, random_state=random_state) for _ in range(self.num_ch)]
+ self.list_pca = [PCA(n_components=n_components,
+ random_state=random_state) for _ in range(self.num_ch)]
self.logger = logging.getLogger(SESSION_LOG_FILENAME)
- self.logger.info(f"PCA. n_components={n_components}, random_state={random_state}, num_ch={num_ch}")
+ self.logger.info(
+ f"PCA. n_components={n_components}, random_state={random_state}, num_ch={num_ch}")
def fit(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> None:
"""Fit PCA to each channel of data.
diff --git a/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py b/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py
index f438e6df7..a585309c7 100644
--- a/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py
+++ b/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py
@@ -85,16 +85,19 @@ def evaluate(self, test_data: np.ndarray, test_labels: np.ndarray):
def evaluate_likelihood(self, data: np.ndarray, symbols: List[str],
symbol_set: List[str]) -> np.ndarray:
if not self.ready_to_predict:
- raise SignalException("must use model.fit() before model.evaluate_likelihood()")
+ raise SignalException(
+ "must use model.fit() before model.evaluate_likelihood()")
gaze_log_likelihoods = np.zeros((len(symbol_set)))
# Clip the pre-saved centralized data to the length of our test data
cent_data = self.centralized_data[:, :, :data.shape[1]]
- reshaped_data = cent_data.reshape((len(cent_data), data.shape[0] * data.shape[1]))
+ reshaped_data = cent_data.reshape(
+ (len(cent_data), data.shape[0] * data.shape[1]))
cov_matrix = np.cov(reshaped_data, rowvar=False)
reshaped_mean = np.mean(reshaped_data, axis=0)
eps = 10e-1 # add a small value to the diagonal to make the cov matrix invertible
- inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(len(cov_matrix)) * eps)
+ inv_cov_matrix = np.linalg.inv(
+ cov_matrix + np.eye(len(cov_matrix)) * eps)
for idx, sym in enumerate(symbol_set):
if self.time_average[sym] == []:
@@ -179,7 +182,8 @@ def __init__(self, num_components=4, random_state=0, *args, **kwargs):
self.ready_to_predict = False
def fit(self, train_data: np.ndarray):
- model = GaussianMixture(n_components=self.num_components, random_state=self.random_state, init_params='kmeans')
+ model = GaussianMixture(n_components=self.num_components,
+ random_state=self.random_state, init_params='kmeans')
model.fit(train_data)
self.model = model
@@ -202,7 +206,8 @@ def evaluate(self, predictions, true_labels) -> np.ndarray:
--------
accuracy_per_symbol: accuracy per symbol
'''
- accuracy_per_symbol = np.sum(predictions == true_labels) / len(predictions) * 100
+ accuracy_per_symbol = np.sum(
+ predictions == true_labels) / len(predictions) * 100
self.acc = accuracy_per_symbol
return accuracy_per_symbol
@@ -242,7 +247,8 @@ def predict_proba(self, test_data: np.ndarray) -> np.ndarray:
'''
data_length, _ = test_data.shape
- likelihoods = np.zeros((data_length, self.num_components), dtype=object)
+ likelihoods = np.zeros(
+ (data_length, self.num_components), dtype=object)
# Find the likelihoods by insterting the test data into the pdf of each component
for i in range(data_length):
@@ -250,17 +256,20 @@ def predict_proba(self, test_data: np.ndarray) -> np.ndarray:
mu = self.means[k]
sigma = self.covs[k]
- likelihoods[i, k] = stats.multivariate_normal.pdf(test_data[i], mu, sigma)
+ likelihoods[i, k] = stats.multivariate_normal.pdf(
+ test_data[i], mu, sigma)
return likelihoods
def evaluate_likelihood(self, data: np.ndarray, symbols: List[str],
symbol_set: List[str]) -> np.ndarray:
if not self.ready_to_predict:
- raise SignalException("must use model.fit() before model.evaluate_likelihood()")
+ raise SignalException(
+ "must use model.fit() before model.evaluate_likelihood()")
data_length, _ = data.shape
- likelihoods = np.zeros((data_length, self.num_components), dtype=object)
+ likelihoods = np.zeros(
+ (data_length, self.num_components), dtype=object)
# Find the likelihoods by insterting the test data into the pdf of each component
for i in range(data_length):
@@ -268,7 +277,8 @@ def evaluate_likelihood(self, data: np.ndarray, symbols: List[str],
mu = self.means[k]
sigma = self.covs[k]
- likelihoods[i, k] = stats.multivariate_normal.pdf(data[i], mu, sigma)
+ likelihoods[i, k] = stats.multivariate_normal.pdf(
+ data[i], mu, sigma)
return likelihoods
diff --git a/bcipy/signal/model/offline_analysis.py b/bcipy/signal/model/offline_analysis.py
index e82b706a3..2f6474547 100644
--- a/bcipy/signal/model/offline_analysis.py
+++ b/bcipy/signal/model/offline_analysis.py
@@ -33,7 +33,8 @@
filter_inquiries, get_default_transform)
log = logging.getLogger(SESSION_LOG_FILENAME)
-logging.basicConfig(level=logging.INFO, format="[%(threadName)-9s][%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
+logging.basicConfig(level=logging.INFO,
+ format="[%(threadName)-9s][%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
def subset_data(data: np.ndarray, labels: np.ndarray, test_size: float, random_state: int = 0, swap_axes: bool = True):
@@ -150,12 +151,14 @@ def analyze_erp(
)
# update the trigger timing list to account for the initial trial window
- corrected_trigger_timing = [timing + trial_window[0] for timing in trigger_timing]
+ corrected_trigger_timing = [timing + trial_window[0]
+ for timing in trigger_timing]
# Channel map can be checked from raw_data.csv file or the devices.json located in the acquisition module
# The timestamp column [0] is already excluded.
channel_map = analysis_channels(channels, device_spec)
- channels_used = [channels[i] for i, keep in enumerate(channel_map) if keep == 1]
+ channels_used = [channels[i]
+ for i, keep in enumerate(channel_map) if keep == 1]
log.info(f'Channels used in analysis: {channels_used}')
data, fs = erp_data.by_channel()
@@ -175,7 +178,8 @@ def analyze_erp(
inquiries, fs = filter_inquiries(inquiries, default_transform, sample_rate)
inquiry_timing = update_inquiry_timing(inquiry_timing, downsample_rate)
trial_duration_samples = int(window_length * fs)
- data = model.reshaper.extract_trials(inquiries, trial_duration_samples, inquiry_timing)
+ data = model.reshaper.extract_trials(
+ inquiries, trial_duration_samples, inquiry_timing)
# define the training classes using integers, where 0=nontargets/1=targets
labels = inquiry_labels.flatten().tolist()
@@ -196,7 +200,8 @@ def analyze_erp(
try:
# Using an 80/20 split, report on balanced accuracy
if estimate_balanced_acc:
- train_data, test_data, train_labels, test_labels = subset_data(data, labels, test_size=0.2)
+ train_data, test_data, train_labels, test_labels = subset_data(
+ data, labels, test_size=0.2)
dummy_model = PcaRdaKdeModel(k_folds=k_folds)
dummy_model.fit(train_data, train_labels)
probs = dummy_model.predict_proba(test_data)
@@ -209,7 +214,8 @@ def analyze_erp(
except Exception as e:
log.error(f"Error calculating balanced accuracy: {e}")
- save_model(model, Path(data_folder, f"model_{device_spec.content_type.lower()}_{model.auc:0.4f}.pkl"))
+ save_model(model, Path(
+ data_folder, f"model_{device_spec.content_type.lower()}_{model.auc:0.4f}.pkl"))
preferences.signal_model_directory = data_folder
if save_figures or show_figures:
@@ -259,13 +265,15 @@ def analyze_gaze(
sample_rate = gaze_data.sample_rate
flash_time = parameters.get("time_flash") # duration of each stimulus
- stim_length = parameters.get("stim_length") # number of stimuli per inquiry
+ # number of stimuli per inquiry
+ stim_length = parameters.get("stim_length")
log.info(f"Channels read from csv: {channels}")
log.info(f"Device type: {type_amp}, fs={sample_rate}")
channel_map = analysis_channels(channels, device_spec)
- channels_used = [channels[i] for i, keep in enumerate(channel_map) if keep == 1]
+ channels_used = [channels[i]
+ for i, keep in enumerate(channel_map) if keep == 1]
log.info(f'Channels used in analysis: {channels_used}')
data, _fs = gaze_data.by_channel()
@@ -287,10 +295,12 @@ def analyze_gaze(
)
''' Trigger_timing includes PROMPT and excludes FIXATION '''
- target_symbols = trigger_symbols[0::stim_length + 1] # target symbols are the PROMPT triggers
+ # target symbols are the PROMPT triggers
+ target_symbols = trigger_symbols[0::stim_length + 1]
# Use trigger_timing to generate time windows for each letter flashing
# Take every 10th trigger as the start point of timing.
- inq_start = trigger_timing[1::stim_length + 1] # start of each inquiry (here we jump over prompts)
+ # start of each inquiry (here we jump over prompts)
+ inq_start = trigger_timing[1::stim_length + 1]
# Extract the inquiries dictionary with keys as target symbols and values as inquiry windows:
inquiries_dict, inquiries_list, _ = model.reshaper(
@@ -304,13 +314,16 @@ def analyze_gaze(
)
# Apply preprocessing:
- inquiry_length = inquiries_list[0].shape[1] # number of time samples in each inquiry
+ # number of time samples in each inquiry
+ inquiry_length = inquiries_list[0].shape[1]
predefined_dimensions = 4 # left_x, left_y, right_x, right_y
- preprocessed_array = np.zeros((len(inquiries_list), predefined_dimensions, inquiry_length))
+ preprocessed_array = np.zeros(
+ (len(inquiries_list), predefined_dimensions, inquiry_length))
# Extract left_x, left_y, right_x, right_y for each inquiry
for j in range(len(inquiries_list)):
left_eye, right_eye, _, _, _, _ = extract_eye_info(inquiries_list[j])
- preprocessed_array[j] = np.concatenate((left_eye.T, right_eye.T,), axis=0)
+ preprocessed_array[j] = np.concatenate(
+ (left_eye.T, right_eye.T,), axis=0)
preprocessed_data = {i: [] for i in symbol_set}
for i in symbol_set:
@@ -319,8 +332,10 @@ def analyze_gaze(
continue
for j in range(len(inquiries_dict[i])):
- left_eye, right_eye, _, _, _, _ = extract_eye_info(inquiries_dict[i][j])
- preprocessed_data[i].append((np.concatenate((left_eye.T, right_eye.T), axis=0)))
+ left_eye, right_eye, _, _, _, _ = extract_eye_info(
+ inquiries_dict[i][j])
+ preprocessed_data[i].append(
+ (np.concatenate((left_eye.T, right_eye.T), axis=0)))
# Inquiries x All Dimensions (left_x, left_y, right_x, right_y) x Time
preprocessed_data[i] = np.array(preprocessed_data[i])
@@ -362,7 +377,8 @@ def analyze_gaze(
# Split the data into train and test sets & fit the model:
centralized_gaze_data = np.zeros_like(preprocessed_array)
for i, (_, sym) in enumerate(zip(preprocessed_array, target_symbols)):
- centralized_gaze_data[i] = model.subtract_mean(preprocessed_array[i], time_average[sym])
+ centralized_gaze_data[i] = model.subtract_mean(
+ preprocessed_array[i], time_average[sym])
reshaped_data = centralized_gaze_data.reshape(
(len(centralized_gaze_data), inquiry_length * predefined_dimensions))
@@ -451,7 +467,8 @@ def offline_analysis(
if spec.is_active)
active_raw_data_paths = (Path(data_folder, raw_data_filename(device_spec))
for device_spec in active_devices)
- data_file_paths = [str(path) for path in active_raw_data_paths if path.exists()]
+ data_file_paths = [str(path)
+ for path in active_raw_data_paths if path.exists()]
num_devices = len(data_file_paths)
assert num_devices >= 1 and num_devices < 3, (
@@ -488,7 +505,8 @@ def offline_analysis(
n_iterations=n_iterations,
)
- log.info(f"EEG Accuracy: {eeg_acc}, Gaze Accuracy: {gaze_acc}, Fusion Accuracy: {fusion_acc}")
+ log.info(
+ f"EEG Accuracy: {eeg_acc}, Gaze Accuracy: {gaze_acc}, Fusion Accuracy: {fusion_acc}")
# The average gaze model accuracy:
avg_testing_acc_gaze = round(np.mean(gaze_acc), 3)
@@ -556,9 +574,12 @@ def main():
"--parameters_file",
default=DEFAULT_PARAMETERS_PATH,
help="Path to the BciPy parameters file.")
- parser.add_argument("-s", "--save_figures", action="store_true", help="Save figures after training.")
- parser.add_argument("-v", "--show_figures", action="store_true", help="Show figures after training.")
- parser.add_argument("-i", "--iterations", type=int, default=10, help="Number of iterations for fusion analysis.")
+ parser.add_argument("-s", "--save_figures",
+ action="store_true", help="Save figures after training.")
+ parser.add_argument("-v", "--show_figures",
+ action="store_true", help="Show figures after training.")
+ parser.add_argument("-i", "--iterations", type=int, default=10,
+ help="Number of iterations for fusion analysis.")
parser.add_argument(
"--alert",
dest="alert",
diff --git a/bcipy/signal/model/pca_rda_kde/pca_rda_kde.py b/bcipy/signal/model/pca_rda_kde/pca_rda_kde.py
index 66439a427..53815126a 100644
--- a/bcipy/signal/model/pca_rda_kde/pca_rda_kde.py
+++ b/bcipy/signal/model/pca_rda_kde/pca_rda_kde.py
@@ -49,13 +49,15 @@ def fit(self, train_data: np.array, train_labels: np.array) -> SignalModel:
"""
model = Pipeline(
[
- ChannelWisePrincipalComponentAnalysis(n_components=self.pca_n_components, num_ch=train_data.shape[0]),
+ ChannelWisePrincipalComponentAnalysis(
+ n_components=self.pca_n_components, num_ch=train_data.shape[0]),
RegularizedDiscriminantAnalysis(),
]
)
# Find the optimal gamma + lambda values
- arg_cv = cross_validation(train_data, train_labels, model=model, k_folds=self.k_folds)
+ arg_cv = cross_validation(
+ train_data, train_labels, model=model, k_folds=self.k_folds)
# Get the AUC using those optimized gamma + lambda
rda_index = 1 # the index in the pipeline
@@ -102,7 +104,8 @@ def evaluate(self, test_data: np.array, test_labels: np.array) -> ModelEvaluatio
ModelEvaluationReport: stores AUC
"""
if not self.ready_to_predict:
- raise SignalException("must use model.fit() before model.evaluate()")
+ raise SignalException(
+ "must use model.fit() before model.evaluate()")
tmp_model = Pipeline([self.model.pipeline[0], self.model.pipeline[1]])
@@ -134,18 +137,22 @@ def compute_likelihood_ratio(self, data: np.array, inquiry: List[str], symbol_se
"""
if not self.ready_to_predict:
- raise SignalException("must use model.fit() before model.predict()")
+ raise SignalException(
+ "must use model.fit() before model.predict()")
# Evaluate likelihood probabilities for p(e|l=1) and p(e|l=0)
log_likelihoods = self.model.transform(data)
- subset_likelihood_ratios = np.exp(log_likelihoods[:, 1] - log_likelihoods[:, 0])
+ subset_likelihood_ratios = np.exp(
+ log_likelihoods[:, 1] - log_likelihoods[:, 0])
# Restrict multiplicative updates to a reasonable range
- subset_likelihood_ratios = np.clip(subset_likelihood_ratios, self.min, self.max)
+ subset_likelihood_ratios = np.clip(
+ subset_likelihood_ratios, self.min, self.max)
# Apply likelihood ratios to entire symbol set.
likelihood_ratios = np.ones(len(symbol_set))
for idx in range(len(subset_likelihood_ratios)):
- likelihood_ratios[symbol_set.index(inquiry[idx])] *= subset_likelihood_ratios[idx]
+ likelihood_ratios[symbol_set.index(
+ inquiry[idx])] *= subset_likelihood_ratios[idx]
return likelihood_ratios # used in multimodal update
def compute_class_probabilities(self, data: np.ndarray) -> np.ndarray:
@@ -156,7 +163,8 @@ def compute_class_probabilities(self, data: np.ndarray) -> np.ndarray:
probability for the two labels.
"""
if not self.ready_to_predict:
- raise SignalException("must use model.fit() before model.predict_proba()")
+ raise SignalException(
+ "must use model.fit() before model.predict_proba()")
# Model originally produces p(eeg | label). We want p(label | eeg):
#
@@ -178,7 +186,8 @@ def evaluate_likelihood(self, data: np.ndarray) -> np.ndarray:
p(e | l=1), p(e | l=0)
"""
if not self.ready_to_predict:
- raise SignalException("must use model.fit() before model.predict_proba()")
+ raise SignalException(
+ "must use model.fit() before model.predict_proba()")
log_scores_class_0 = self.model.transform(data)[:, 0]
log_scores_class_1 = self.model.transform(data)[:, 1]
@@ -191,7 +200,8 @@ def predict(self, data: np.ndarray) -> np.ndarray:
predictions (np.ndarray): shape (num_items,) - the predicted label for each item.
"""
if not self.ready_to_predict:
- raise SignalException("must use model.fit() before model.predict()")
+ raise SignalException(
+ "must use model.fit() before model.predict()")
posterior = self.compute_class_probabilities(data)
predictions = np.argmax(posterior, axis=1)
@@ -205,7 +215,8 @@ def predict_proba(self, data: np.ndarray) -> np.ndarray:
probability for the two labels.
"""
if not self.ready_to_predict:
- raise SignalException("must use model.fit() before model.predict_proba()")
+ raise SignalException(
+ "must use model.fit() before model.predict_proba()")
return self.compute_class_probabilities(data)
diff --git a/bcipy/signal/model/pipeline.py b/bcipy/signal/model/pipeline.py
index f89243f66..3ac6c79e1 100644
--- a/bcipy/signal/model/pipeline.py
+++ b/bcipy/signal/model/pipeline.py
@@ -36,7 +36,8 @@ def fit(self, x, y):
y(ndarray[int]): of desired shape """
self.line_el = [x]
for i in range(len(self.pipeline) - 1):
- self.line_el.append(self.pipeline[i].fit_transform(self.line_el[i], y))
+ self.line_el.append(
+ self.pipeline[i].fit_transform(self.line_el[i], y))
self.pipeline[-1].fit(self.line_el[-1], y)
@@ -48,16 +49,18 @@ def fit_transform(self, x, y):
self.line_el = [x]
for i in range(len(self.pipeline) - 1):
- self.line_el.append(self.pipeline[i].fit_transform(self.line_el[i], y))
+ self.line_el.append(
+ self.pipeline[i].fit_transform(self.line_el[i], y))
arg = self.pipeline[-1].fit_transform(self.line_el[-1], y)
return arg
def transform(self, x):
- """ Applies transform on all functions. Prior to using transform on
+ """Applies transform on all functions. Prior to using transform on
pipeline, it should be trained.
Args:
- x(ndarray[float]): of desired shape """
+ x(ndarray[float]): of desired shape
+ """
self.line_el = [x]
for i in range(len(self.pipeline)):
self.line_el.append(self.pipeline[i].transform(self.line_el[i]))
diff --git a/bcipy/signal/model/rda_kde/rda_kde.py b/bcipy/signal/model/rda_kde/rda_kde.py
index 785f7c9dc..ab0ff82a6 100644
--- a/bcipy/signal/model/rda_kde/rda_kde.py
+++ b/bcipy/signal/model/rda_kde/rda_kde.py
@@ -43,7 +43,8 @@ def fit(self, train_data: np.array, train_labels: np.array) -> SignalModel:
model = Pipeline([MockPCA(), RegularizedDiscriminantAnalysis()])
# Find the optimal gamma + lambda values
- arg_cv = cross_validation(train_data, train_labels, model=model, k_folds=self.k_folds)
+ arg_cv = cross_validation(
+ train_data, train_labels, model=model, k_folds=self.k_folds)
# Get the AUC using those optimized gamma + lambda
rda_index = 1 # the index in the pipeline
@@ -92,7 +93,8 @@ def evaluate(self, test_data: np.array, test_labels: np.array) -> ModelEvaluatio
ModelEvaluationReport: stores AUC
"""
if not self._ready_to_predict:
- raise SignalException("must use model.fit() before model.evaluate()")
+ raise SignalException(
+ "must use model.fit() before model.evaluate()")
tmp_model = Pipeline([self.model.pipeline[0], self.model.pipeline[1]])
@@ -105,8 +107,7 @@ def evaluate(self, test_data: np.array, test_labels: np.array) -> ModelEvaluatio
return ModelEvaluationReport(auc)
def predict(self, data: np.array, inquiry: List[str], symbol_set: List[str]) -> np.array:
- """
- For each trial in `data`, compute a likelihood ratio to update that symbol's probability.
+ """For each trial in `data`, compute a likelihood ratio to update that symbol's probability.
Rather than just computing an update p(e|l=+) for the seen symbol and p(e|l=-) for all unseen symbols,
we compute a likelihood ratio p(e | l=+) / p(e | l=-) to update the seen symbol, and all other symbols
can receive a multiplicative update of 1.
@@ -124,18 +125,22 @@ def predict(self, data: np.array, inquiry: List[str], symbol_set: List[str]) ->
"""
if not self._ready_to_predict:
- raise SignalException("must use model.fit() before model.predict()")
+ raise SignalException(
+ "must use model.fit() before model.predict()")
# Evaluate likelihood probabilities for p(e|l=1) and p(e|l=0)
log_likelihoods = self.model.transform(data)
- subset_likelihood_ratios = np.exp(log_likelihoods[:, 1] - log_likelihoods[:, 0])
+ subset_likelihood_ratios = np.exp(
+ log_likelihoods[:, 1] - log_likelihoods[:, 0])
# Restrict multiplicative updates to a reasonable range
- subset_likelihood_ratios = np.clip(subset_likelihood_ratios, self.min, self.max)
+ subset_likelihood_ratios = np.clip(
+ subset_likelihood_ratios, self.min, self.max)
# Apply likelihood ratios to entire symbol set.
likelihood_ratios = np.ones(len(symbol_set))
for idx in range(len(subset_likelihood_ratios)):
- likelihood_ratios[symbol_set.index(inquiry[idx])] *= subset_likelihood_ratios[idx]
+ likelihood_ratios[symbol_set.index(
+ inquiry[idx])] *= subset_likelihood_ratios[idx]
return likelihood_ratios
def predict_proba(self, data: np.array) -> np.array:
@@ -146,7 +151,8 @@ def predict_proba(self, data: np.array) -> np.array:
probability for the two labels.
"""
if not self._ready_to_predict:
- raise SignalException("must use model.fit() before model.predict_proba()")
+ raise SignalException(
+ "must use model.fit() before model.predict_proba()")
# Model originally produces p(eeg | label). We want p(label | eeg):
#
diff --git a/bcipy/signal/model/switch_model.py b/bcipy/signal/model/switch_model.py
index d05b58ced..4a2b64bff 100644
--- a/bcipy/signal/model/switch_model.py
+++ b/bcipy/signal/model/switch_model.py
@@ -30,9 +30,7 @@ def __init__(self, error_prob: float = 0.05):
self.error_prob = error_prob
def fit(self, training_data: np.ndarray, training_labels: np.ndarray):
- """
- @override
- """
+ """@override"""
return self
def evaluate(self, test_data: np.ndarray, test_labels: np.ndarray):
@@ -45,8 +43,7 @@ def predict(self, data: np.ndarray, inquiry: List[str],
def compute_likelihood_ratio(self, data: np.array, inquiry: List[str],
symbol_set: List[str]) -> np.array:
- """
- For each trial in `data`, compute a likelihood ratio to update that symbol's probability.
+ """For each trial in `data`, compute a likelihood ratio to update that symbol's probability.
Args:
data (np.array): button press data data a single element of 0 or 1; shape (1,)
diff --git a/bcipy/signal/process/decomposition/cwt.py b/bcipy/signal/process/decomposition/cwt.py
index 920ae5fcf..72279c036 100644
--- a/bcipy/signal/process/decomposition/cwt.py
+++ b/bcipy/signal/process/decomposition/cwt.py
@@ -26,7 +26,8 @@ def continuous_wavelet_transform(
scales = pywt.central_frequency(wavelet) * fs / np.array(freq)
all_coeffs = []
for trial in data:
- coeffs, _ = pywt.cwt(trial, scales, wavelet) # shape == (scales, channels, time)
+ # shape == (scales, channels, time)
+ coeffs, _ = pywt.cwt(trial, scales, wavelet)
all_coeffs.append(coeffs)
final_data = np.stack(all_coeffs)
diff --git a/bcipy/signal/process/extract_gaze.py b/bcipy/signal/process/extract_gaze.py
index 77183565c..8516ce7ad 100644
--- a/bcipy/signal/process/extract_gaze.py
+++ b/bcipy/signal/process/extract_gaze.py
@@ -4,7 +4,7 @@
def extract_eye_info(data):
- """"Rearrange the dimensions of gaze inquiry data and reshape it to num_channels x num_samples
+ """Rearrange the dimensions of gaze inquiry data and reshape it to num_channels x num_samples
Extract Left and Right Eye info from data. Remove all blinks, do necessary preprocessing.
The data is extracted according to the channel map:
['device_ts, 'system_ts', 'left_x', 'left_y', 'left_pupil', 'right_x', 'right_y', 'right_pupil']
@@ -16,7 +16,6 @@ def extract_eye_info(data):
left_eye (np.ndarray), left_pupil (List(float))
right_eye (np.ndarray), right_pupil (List(float))
"""
-
# Extract samples from channels
lx = data[2, :]
ly = data[3, :]
diff --git a/bcipy/signal/process/filter.py b/bcipy/signal/process/filter.py
index daea1c95d..b58ad2100 100644
--- a/bcipy/signal/process/filter.py
+++ b/bcipy/signal/process/filter.py
@@ -21,7 +21,8 @@ class Bandpass:
def __init__(self, lo, hi, sample_rate_hz, order=5):
nyq = 0.5 * sample_rate_hz
lo, hi = lo / nyq, hi / nyq
- self.sos = butter(order, [lo, hi], analog=False, btype="band", output="sos")
+ self.sos = butter(order, [lo, hi], analog=False,
+ btype="band", output="sos")
def __call__(self, data: np.ndarray, fs: int) -> Tuple[np.ndarray, int]:
return sosfiltfilt(self.sos, data), fs
@@ -38,7 +39,9 @@ def filter_inquiries(inquiries: np.ndarray, transform, sample_rate: int) -> Tupl
old_shape = inquiries.shape
# (Channels*Inquiry, Samples)
inq_flatten = inquiries.reshape(-1, old_shape[-1])
- inq_flatten_filtered, transformed_sample_rate = transform(inq_flatten, sample_rate)
+ inq_flatten_filtered, transformed_sample_rate = transform(
+ inq_flatten, sample_rate)
# (Channels, Inquiries, Samples)
- inquiries = inq_flatten_filtered.reshape(*old_shape[:2], inq_flatten_filtered.shape[-1])
+ inquiries = inq_flatten_filtered.reshape(
+ *old_shape[:2], inq_flatten_filtered.shape[-1])
return inquiries, transformed_sample_rate
diff --git a/bcipy/signal/tests/evaluate/test_artifact.py b/bcipy/signal/tests/evaluate/test_artifact.py
index ec201561e..f89236a70 100644
--- a/bcipy/signal/tests/evaluate/test_artifact.py
+++ b/bcipy/signal/tests/evaluate/test_artifact.py
@@ -43,7 +43,8 @@ def tearDown(self) -> None:
def test_artifact_detection_init(self):
"""Test the ArtifactDetection class."""
- ar = ArtifactDetection(raw_data=self.raw_data, parameters=self.parameters, device_spec=self.device_spec)
+ ar = ArtifactDetection(
+ raw_data=self.raw_data, parameters=self.parameters, device_spec=self.device_spec)
self.assertIsInstance(ar, ArtifactDetection)
self.assertFalse(ar.analysis_done)
self.assertIsNone(ar.dropped)
@@ -101,7 +102,8 @@ def test_artifact_detection_detect_artifacts(self):
device_spec=self.device_spec)
labels = [mock()]
expected_label_response = f'{len(labels)} artifacts found in the data.'
- when(ar).label_artifacts(extra_labels=ar.session_triggers).thenReturn(labels)
+ when(ar).label_artifacts(
+ extra_labels=ar.session_triggers).thenReturn(labels)
response_labels, response_dropped = ar.detect_artifacts()
self.assertEqual(response_labels, expected_label_response)
self.assertEqual(response_dropped, 0)
@@ -122,9 +124,12 @@ def test_artifact_type(self):
def test_default_artifact_parameters(self):
"""Test the DefaultArtifactParameters class."""
- self.assertEqual(DefaultArtifactParameters.EOG_THRESHOLD.value, 5.5e-05)
- self.assertEqual(DefaultArtifactParameters.VOlTAGE_LABEL_DURATION.value, 0.25)
- self.assertEqual(DefaultArtifactParameters.ARTIFACT_LABELLED_FILENAME.value, 'artifacts.fif')
+ self.assertEqual(
+ DefaultArtifactParameters.EOG_THRESHOLD.value, 5.5e-05)
+ self.assertEqual(
+ DefaultArtifactParameters.VOlTAGE_LABEL_DURATION.value, 0.25)
+ self.assertEqual(
+ DefaultArtifactParameters.ARTIFACT_LABELLED_FILENAME.value, 'artifacts.fif')
if __name__ == '__main__':
diff --git a/bcipy/signal/tests/model/pca_rda_kde/test_pca_rda_kde.py b/bcipy/signal/tests/model/pca_rda_kde/test_pca_rda_kde.py
index dabca5f2a..0def715e4 100644
--- a/bcipy/signal/tests/model/pca_rda_kde/test_pca_rda_kde.py
+++ b/bcipy/signal/tests/model/pca_rda_kde/test_pca_rda_kde.py
@@ -20,7 +20,8 @@
ChannelWisePrincipalComponentAnalysis
from bcipy.signal.model.pipeline import Pipeline
-expected_output_folder = Path(__file__).absolute().parent.parent / "unit_test_expected_output"
+expected_output_folder = Path(__file__).absolute(
+).parent.parent / "unit_test_expected_output"
class ModelSetup(unittest.TestCase):
@@ -37,8 +38,10 @@ def setUpClass(cls):
# Generate Gaussian random data
cls.pos_mean, cls.pos_std = 0, 0.5
cls.neg_mean, cls.neg_std = 1, 0.5
- x_pos = cls.pos_mean + cls.pos_std * np.random.randn(cls.num_channel, cls.num_x_pos, cls.dim_x)
- x_neg = cls.neg_mean + cls.neg_std * np.random.randn(cls.num_channel, cls.num_x_neg, cls.dim_x)
+ x_pos = cls.pos_mean + cls.pos_std * \
+ np.random.randn(cls.num_channel, cls.num_x_pos, cls.dim_x)
+ x_neg = cls.neg_mean + cls.neg_std * \
+ np.random.randn(cls.num_channel, cls.num_x_neg, cls.dim_x)
y_pos = np.ones(cls.num_x_pos)
y_neg = np.zeros(cls.num_x_neg)
@@ -76,7 +79,8 @@ def setUp(self):
def test_pca(self):
# .fit() then .transform() should match .fit_transform()
- pca = ChannelWisePrincipalComponentAnalysis(n_components=0.9, num_ch=self.num_channel)
+ pca = ChannelWisePrincipalComponentAnalysis(
+ n_components=0.9, num_ch=self.num_channel)
pca.fit(self.x)
x_reduced = pca.transform(self.x)
x_reduced_2 = pca.fit_transform(self.x)
@@ -98,7 +102,8 @@ def test_kde_plot(self):
"""
# generate some dummy data
n = 100
- x = np.concatenate((np.random.normal(0, 1, int(0.3 * n)), np.random.normal(5, 1, int(0.7 * n))))[:, np.newaxis]
+ x = np.concatenate((np.random.normal(0, 1, int(0.3 * n)),
+ np.random.normal(5, 1, int(0.7 * n))))[:, np.newaxis]
# append 0 label to all data as we are interested in a single class case
y = np.zeros(x.shape)
@@ -107,17 +112,20 @@ def test_kde_plot(self):
x_plot = np.linspace(-5, 10, 1000)[:, np.newaxis]
# generate a dummy density function to sample data from
- true_dens = 0.3 * norm(0, 1).pdf(x_plot[:, 0]) + 0.7 * norm(5, 1).pdf(x_plot[:, 0])
+ true_dens = 0.3 * \
+ norm(0, 1).pdf(x_plot[:, 0]) + 0.7 * norm(5, 1).pdf(x_plot[:, 0])
fig, ax = plt.subplots()
- ax.fill(x_plot[:, 0], true_dens, fc="black", alpha=0.2, label="input distribution")
+ ax.fill(x_plot[:, 0], true_dens, fc="black",
+ alpha=0.2, label="input distribution")
# try different kernels and show how the look like
for kernel in ["gaussian", "tophat", "epanechnikov"]:
kde = KernelDensityEstimate(kernel=kernel, scores=x, num_cls=1)
kde.fit(x, y)
log_dens = kde.list_den_est[0].score_samples(x_plot)
- ax.plot(x_plot[:, 0], np.exp(log_dens), "-", label=f"kernel = '{kernel}'")
+ ax.plot(x_plot[:, 0], np.exp(log_dens),
+ "-", label=f"kernel = '{kernel}'")
ax.plot(x[:, 0], -0.005 - 0.01 * np.random.random(x.shape[0]), "+k")
@@ -126,7 +134,8 @@ def test_kde_plot(self):
return fig
def test_kde_values(self):
- pca = ChannelWisePrincipalComponentAnalysis(n_components=0.9, num_ch=self.num_channel)
+ pca = ChannelWisePrincipalComponentAnalysis(
+ n_components=0.9, num_ch=self.num_channel)
rda = RegularizedDiscriminantAnalysis()
kde = KernelDensityEstimate()
@@ -139,7 +148,8 @@ def test_kde_values(self):
self.assertTrue(np.allclose(z, z_2))
# output values should be correct
- expected = np.load(expected_output_folder / "test_kde_values.expected.npy")
+ expected = np.load(expected_output_folder /
+ "test_kde_values.expected.npy")
self.assertTrue(np.allclose(z, expected))
def test_cv(self):
@@ -150,7 +160,8 @@ def test_cv(self):
before fitting it - it is not clear how sensitive this test is to changes in the code
or input data, so this may be a weak test of cross_validation().
"""
- pca = ChannelWisePrincipalComponentAnalysis(n_components=0.9, num_ch=self.num_channel)
+ pca = ChannelWisePrincipalComponentAnalysis(
+ n_components=0.9, num_ch=self.num_channel)
rda = RegularizedDiscriminantAnalysis()
pipeline = Pipeline([pca, rda])
@@ -160,7 +171,8 @@ def test_cv(self):
self.assertAlmostEqual(gam, 0.1)
def test_rda(self):
- pca = ChannelWisePrincipalComponentAnalysis(n_components=0.9, num_ch=self.num_channel)
+ pca = ChannelWisePrincipalComponentAnalysis(
+ n_components=0.9, num_ch=self.num_channel)
rda = RegularizedDiscriminantAnalysis()
pipeline = Pipeline([pca, rda])
@@ -205,13 +217,17 @@ def test_fit_compute_likelihood_ratio(self):
num_x_p = 1
num_x_n = 9
- x_test_pos = self.pos_mean + self.pos_std * np.random.randn(self.num_channel, num_x_p, self.dim_x)
- x_test_neg = self.neg_mean + self.neg_std * np.random.randn(self.num_channel, num_x_n, self.dim_x)
- x_test = np.concatenate((x_test_pos, x_test_neg), 1) # Target letter is first
+ x_test_pos = self.pos_mean + self.pos_std * \
+ np.random.randn(self.num_channel, num_x_p, self.dim_x)
+ x_test_neg = self.neg_mean + self.neg_std * \
+ np.random.randn(self.num_channel, num_x_n, self.dim_x)
+ # Target letter is first
+ x_test = np.concatenate((x_test_pos, x_test_neg), 1)
letters = alp[10: 10 + num_x_p + num_x_n] # Target letter is K
- lik_r = self.model.compute_likelihood_ratio(data=x_test, inquiry=letters, symbol_set=alp)
+ lik_r = self.model.compute_likelihood_ratio(
+ data=x_test, inquiry=letters, symbol_set=alp)
fig, ax = plt.subplots()
ax.plot(np.arange(len(alp)), lik_r, "ro")
ax.set_xticks(np.arange(len(alp)))
@@ -228,7 +244,8 @@ def test_save_load(self):
symbol_set = alphabet()
inquiry = symbol_set[:n_trial]
data = np.random.randn(self.num_channel, n_trial, self.dim_x)
- output_before = self.model.compute_likelihood_ratio(data=data, inquiry=inquiry, symbol_set=symbol_set)
+ output_before = self.model.compute_likelihood_ratio(
+ data=data, inquiry=inquiry, symbol_set=symbol_set)
checkpoint_path = self.tmp_dir / "model.pkl"
save_model(self.model, checkpoint_path)
@@ -237,19 +254,22 @@ def test_save_load(self):
self.assertEqual(1, len(loaded_models))
other_model = loaded_models[0]
self.assertEqual(self.model.k_folds, other_model.k_folds)
- output_after = other_model.compute_likelihood_ratio(data=data, inquiry=inquiry, symbol_set=symbol_set)
+ output_after = other_model.compute_likelihood_ratio(
+ data=data, inquiry=inquiry, symbol_set=symbol_set)
self.assertTrue(np.allclose(output_before, output_after))
try:
other_model.predict_proba(self.x)
except Exception:
- pytest.fail("Should be able to compute predict_proba after loading a model")
+ pytest.fail(
+ "Should be able to compute predict_proba after loading a model")
def test_predict_before_fit(self):
model = PcaRdaKdeModel(k_folds=10)
with self.assertRaises(SignalException):
- model.compute_likelihood_ratio(self.x, inquiry=["A"], symbol_set=alphabet())
+ model.compute_likelihood_ratio(
+ self.x, inquiry=["A"], symbol_set=alphabet())
def test_evaluate_before_fit(self):
model = PcaRdaKdeModel(k_folds=10)
diff --git a/bcipy/signal/tests/model/rda_kde/test_rda_kde.py b/bcipy/signal/tests/model/rda_kde/test_rda_kde.py
index b06cf974e..c2b7b6177 100644
--- a/bcipy/signal/tests/model/rda_kde/test_rda_kde.py
+++ b/bcipy/signal/tests/model/rda_kde/test_rda_kde.py
@@ -17,7 +17,8 @@
from bcipy.signal.model.dimensionality_reduction import MockPCA
from bcipy.signal.model.pipeline import Pipeline
-expected_output_folder = Path(__file__).absolute().parent.parent / "unit_test_expected_output"
+expected_output_folder = Path(__file__).absolute(
+).parent.parent / "unit_test_expected_output"
class ModelSetup(unittest.TestCase):
@@ -34,8 +35,10 @@ def setUpClass(cls):
# Generate Gaussian random data
cls.pos_mean, cls.pos_std = 0, 0.5
cls.neg_mean, cls.neg_std = 1, 0.5
- x_pos = cls.pos_mean + cls.pos_std * np.random.randn(cls.num_channel, cls.num_x_pos, cls.dim_x)
- x_neg = cls.neg_mean + cls.neg_std * np.random.randn(cls.num_channel, cls.num_x_neg, cls.dim_x)
+ x_pos = cls.pos_mean + cls.pos_std * \
+ np.random.randn(cls.num_channel, cls.num_x_pos, cls.dim_x)
+ x_neg = cls.neg_mean + cls.neg_std * \
+ np.random.randn(cls.num_channel, cls.num_x_neg, cls.dim_x)
y_pos = np.ones(cls.num_x_pos)
y_neg = np.zeros(cls.num_x_neg)
@@ -78,7 +81,8 @@ def setUp(self):
def test_kde_plot(self):
# generate some dummy data
n = 100
- x = np.concatenate((np.random.normal(0, 1, int(0.3 * n)), np.random.normal(5, 1, int(0.7 * n))))[:, np.newaxis]
+ x = np.concatenate((np.random.normal(0, 1, int(0.3 * n)),
+ np.random.normal(5, 1, int(0.7 * n))))[:, np.newaxis]
# append 0 label to all data as we are interested in a single class case
y = np.zeros(x.shape)
@@ -87,17 +91,20 @@ def test_kde_plot(self):
x_plot = np.linspace(-5, 10, 1000)[:, np.newaxis]
# generate a dummy density function to sample data from
- true_dens = 0.3 * norm(0, 1).pdf(x_plot[:, 0]) + 0.7 * norm(5, 1).pdf(x_plot[:, 0])
+ true_dens = 0.3 * \
+ norm(0, 1).pdf(x_plot[:, 0]) + 0.7 * norm(5, 1).pdf(x_plot[:, 0])
fig, ax = plt.subplots()
- ax.fill(x_plot[:, 0], true_dens, fc="black", alpha=0.2, label="input distribution")
+ ax.fill(x_plot[:, 0], true_dens, fc="black",
+ alpha=0.2, label="input distribution")
# try different kernels and show how the look like
for kernel in ["gaussian", "tophat", "epanechnikov"]:
kde = KernelDensityEstimate(kernel=kernel, scores=x, num_cls=1)
kde.fit(x, y)
log_dens = kde.list_den_est[0].score_samples(x_plot)
- ax.plot(x_plot[:, 0], np.exp(log_dens), "-", label=f"kernel = '{kernel}'")
+ ax.plot(x_plot[:, 0], np.exp(log_dens),
+ "-", label=f"kernel = '{kernel}'")
ax.plot(x[:, 0], -0.005 - 0.01 * np.random.random(x.shape[0]), "+k")
@@ -177,13 +184,17 @@ def test_fit_predict(self):
num_x_p = 1
num_x_n = 9
- x_test_pos = self.pos_mean + self.pos_std * np.random.randn(self.num_channel, num_x_p, self.dim_x)
- x_test_neg = self.neg_mean + self.neg_std * np.random.randn(self.num_channel, num_x_n, self.dim_x)
- x_test = np.concatenate((x_test_pos, x_test_neg), 1) # Target letter is first
+ x_test_pos = self.pos_mean + self.pos_std * \
+ np.random.randn(self.num_channel, num_x_p, self.dim_x)
+ x_test_neg = self.neg_mean + self.neg_std * \
+ np.random.randn(self.num_channel, num_x_n, self.dim_x)
+ # Target letter is first
+ x_test = np.concatenate((x_test_pos, x_test_neg), 1)
letters = alp[10: 10 + num_x_p + num_x_n] # Target letter is K
- lik_r = self.model.predict(data=x_test, inquiry=letters, symbol_set=alp)
+ lik_r = self.model.predict(
+ data=x_test, inquiry=letters, symbol_set=alp)
fig, ax = plt.subplots()
ax.plot(np.arange(len(alp)), lik_r, "ro")
ax.set_xticks(np.arange(len(alp)))
@@ -200,13 +211,15 @@ def test_save_load(self):
symbol_set = alphabet()
inquiry = symbol_set[:n_trial]
data = np.random.randn(self.num_channel, n_trial, self.dim_x)
- output_before = self.model.predict(data=data, inquiry=inquiry, symbol_set=symbol_set)
+ output_before = self.model.predict(
+ data=data, inquiry=inquiry, symbol_set=symbol_set)
checkpoint_path = self.tmp_dir / "model.pkl"
self.model.save(checkpoint_path)
other_model = RdaKdeModel(k_folds=self.model.k_folds)
other_model.load(checkpoint_path)
- output_after = other_model.predict(data=data, inquiry=inquiry, symbol_set=symbol_set)
+ output_after = other_model.predict(
+ data=data, inquiry=inquiry, symbol_set=symbol_set)
self.assertTrue(np.allclose(output_before, output_after))
diff --git a/bcipy/signal/tests/model/test_offline_analysis.py b/bcipy/signal/tests/model/test_offline_analysis.py
index 3cc21d78f..bad8ad5cf 100644
--- a/bcipy/signal/tests/model/test_offline_analysis.py
+++ b/bcipy/signal/tests/model/test_offline_analysis.py
@@ -18,7 +18,8 @@
pwd = Path(__file__).absolute().parent
input_folder = pwd / "integration_test_input"
-expected_output_folder = pwd / "integration_test_expected_output" # global for the purpose of pytest-mpl decorator
+# global for the purpose of pytest-mpl decorator
+expected_output_folder = pwd / "integration_test_expected_output"
@pytest.mark.slow
@@ -48,10 +49,13 @@ def setUpClass(cls):
shutil.copyfileobj(f_source, f_dest)
# copy the other required inputs into tmp_dir
- shutil.copyfile(eeg_input_folder / TRIGGER_FILENAME, cls.tmp_dir / TRIGGER_FILENAME)
- shutil.copyfile(eeg_input_folder / DEFAULT_DEVICE_SPEC_FILENAME, cls.tmp_dir / DEFAULT_DEVICE_SPEC_FILENAME)
+ shutil.copyfile(eeg_input_folder / TRIGGER_FILENAME,
+ cls.tmp_dir / TRIGGER_FILENAME)
+ shutil.copyfile(eeg_input_folder / DEFAULT_DEVICE_SPEC_FILENAME,
+ cls.tmp_dir / DEFAULT_DEVICE_SPEC_FILENAME)
- params_path = pwd.parent.parent.parent / "parameters" / DEFAULT_PARAMETERS_FILENAME
+ params_path = pwd.parent.parent.parent / \
+ "parameters" / DEFAULT_PARAMETERS_FILENAME
cls.parameters = load_json_parameters(params_path, value_cast=True)
models = offline_analysis(
str(cls.tmp_dir),
@@ -74,8 +78,10 @@ def get_auc(model_filename):
return float(match[1])
def test_model_auc(self):
- expected_auc = self.get_auc(list(expected_output_folder.glob("model_eeg_*.pkl"))[0].name)
- found_auc = self.get_auc(list(self.tmp_dir.glob("model_eeg_*.pkl"))[0].name)
+ expected_auc = self.get_auc(
+ list(expected_output_folder.glob("model_eeg_*.pkl"))[0].name)
+ found_auc = self.get_auc(
+ list(self.tmp_dir.glob("model_eeg_*.pkl"))[0].name)
self.assertAlmostEqual(expected_auc, found_auc, delta=0.005)
def test_model_metadata_loads(self):
@@ -111,14 +117,16 @@ def setUpClass(cls):
shutil.copyfileobj(f_source, f_dest)
# copy the other required inputs into tmp_dir
- shutil.copyfile(eye_tracking_input_folder / TRIGGER_FILENAME, cls.tmp_dir / TRIGGER_FILENAME)
+ shutil.copyfile(eye_tracking_input_folder /
+ TRIGGER_FILENAME, cls.tmp_dir / TRIGGER_FILENAME)
shutil.copyfile(
eye_tracking_input_folder /
DEFAULT_DEVICE_SPEC_FILENAME,
cls.tmp_dir /
DEFAULT_DEVICE_SPEC_FILENAME)
- params_path = pwd.parent.parent.parent / "parameters" / DEFAULT_PARAMETERS_FILENAME
+ params_path = pwd.parent.parent.parent / \
+ "parameters" / DEFAULT_PARAMETERS_FILENAME
cls.parameters = load_json_parameters(params_path, value_cast=True)
models = offline_analysis(
str(cls.tmp_dir),
@@ -144,8 +152,10 @@ def get_acc(model_filename):
return float(match[1])
def test_model_acc(self):
- expected_auc = self.get_acc(list(expected_output_folder.glob("model_eyetracker_*.pkl"))[0].name)
- found_auc = self.get_acc(list(self.tmp_dir.glob("model_eyetracker_*.pkl"))[0].name)
+ expected_auc = self.get_acc(
+ list(expected_output_folder.glob("model_eyetracker_*.pkl"))[0].name)
+ found_auc = self.get_acc(
+ list(self.tmp_dir.glob("model_eyetracker_*.pkl"))[0].name)
self.assertAlmostEqual(expected_auc, found_auc, delta=0.005)
@@ -184,10 +194,13 @@ def setUpClass(cls):
shutil.copyfileobj(f_source, f_dest)
# copy the other required inputs into tmp_dir
- shutil.copyfile(et_input_folder / TRIGGER_FILENAME, cls.tmp_dir / TRIGGER_FILENAME)
- shutil.copyfile(fusion_input_folder / DEFAULT_DEVICE_SPEC_FILENAME, cls.tmp_dir / DEFAULT_DEVICE_SPEC_FILENAME)
+ shutil.copyfile(et_input_folder / TRIGGER_FILENAME,
+ cls.tmp_dir / TRIGGER_FILENAME)
+ shutil.copyfile(fusion_input_folder / DEFAULT_DEVICE_SPEC_FILENAME,
+ cls.tmp_dir / DEFAULT_DEVICE_SPEC_FILENAME)
- params_path = pwd.parent.parent.parent / "parameters" / DEFAULT_PARAMETERS_FILENAME
+ params_path = pwd.parent.parent.parent / \
+ "parameters" / DEFAULT_PARAMETERS_FILENAME
cls.parameters = load_json_parameters(params_path, value_cast=True)
models = offline_analysis(
str(cls.tmp_dir),
@@ -222,13 +235,17 @@ def get_auc(model_filename):
return float(match[1])
def test_model_acc(self):
- expected_auc = self.get_acc(list(self.output_folder.glob("model_eyetracker_*.pkl"))[0].name)
- found_auc = self.get_acc(list(self.tmp_dir.glob("model_eyetracker_*.pkl"))[0].name)
+ expected_auc = self.get_acc(
+ list(self.output_folder.glob("model_eyetracker_*.pkl"))[0].name)
+ found_auc = self.get_acc(
+ list(self.tmp_dir.glob("model_eyetracker_*.pkl"))[0].name)
self.assertAlmostEqual(expected_auc, found_auc, delta=0.005)
def test_model_auc(self):
- expected_auc = self.get_auc(list(self.output_folder.glob("model_eeg_*.pkl"))[0].name)
- found_auc = self.get_auc(list(self.tmp_dir.glob("model_eeg_*.pkl"))[0].name)
+ expected_auc = self.get_auc(
+ list(self.output_folder.glob("model_eeg_*.pkl"))[0].name)
+ found_auc = self.get_auc(
+ list(self.tmp_dir.glob("model_eeg_*.pkl"))[0].name)
self.assertAlmostEqual(expected_auc, found_auc, delta=0.005)
diff --git a/bcipy/simulator/README.md b/bcipy/simulator/README.md
index 452ef1f9d..fb5c5ed24 100644
--- a/bcipy/simulator/README.md
+++ b/bcipy/simulator/README.md
@@ -1,4 +1,4 @@
-# RSVP Simulator
+# BciPy Simulator
## Overview
@@ -33,14 +33,15 @@ optional arguments:
```
For example,
-`$ python bcipy/simulator -d my_data_folder/ -p my_parameters.json -m my_models/ -n 5`
+
+`$ bcipy-sim -d my_data_folder/ -p my_parameters.json -m my_models/ -n 5`
#### Program Args
- `i` : Interactive command line interface. Provide this flag by itself to be prompted for each parameter.
- `gui`: A graphical user interface for configuring a simulation. This mode will output the command line arguments which can be used to repeat the simulation.
-- `d` : Raw data folders to be processed. One ore more values can be provided. Each session data folder should contain
- _raw_data.csv_, _triggers.txt_, _parameters.json_. These files will be used to construct a data pool from which simulator will sample EEG and other device responses. The parameters file in each data folder will be used to check compatibility with the simulation/model parameters.
+- `d` : Raw data folders to be processed. Data folders should contain EEG responses to Copy Phrase tasks. Each session data folder should contain
+ _raw_data.csv_, _triggers.txt_, _parameters.json_. These files will be used to construct a data pool from which simulator will sample EEG. The parameters file in each data folder will be used to check compatibility with the simulation/model parameters.
- `p` : path to the parameters.json file used to run the simulation. These parameters will be applied to
all raw_data files when loading. This file can specify various aspects of the simulation, including the language model to be used, the text to be spelled, etc. Timing-related parameters should generally match the parameters file used for training the signal model(s).
- `m`: Path to a pickled (.pkl) signal model. One or more models can be provided.
@@ -52,10 +53,12 @@ For example,
#### Sim Output Details
-Output folders are generally located in the `data/simulator` directory, but can be configured per simulation. Each simulation will create a new directory. The directory name will be prefixed with `SIM` and will include the current date and time.
+Output folders are generally located in the `data/simulator` directory, but can be configured per simulation. Each simulation will create a new directory. The directory name will be prefixed with `SIM` and will include the current date and time (e.g., "SIM_%m-%d-%Y_%H_%M_%S").
+
+At the top level of the output directory, the following files are created:
- `parameters.json` captures params used for the simulation.
-- `sim.log` is a log file for the simulation; metrics will be output here.
+- `sim.log` is a log file for the overall simulation; metrics will be output here.
- `summary_data.json` summarizes session data from each of the runs into a single data structure.
- `metrics.png` boxplots for several metrics summarizing all simulation runs.
@@ -126,7 +129,6 @@ optional arguments:
Sim output path
```
-
## Current Limitations
- Only one sampler maybe provided for all devices. Ideally we should support a different sampling strategy for each device.
@@ -148,47 +150,49 @@ The `switch_data_processor` and `switch_model` are used to demonstrate a multimo
1. Ensure that the devices.json file has an entry for a switch
- ```
- {
- "name": "Switch",
- "content_type": "MARKERS",
- "channels": [
- { "name": "Marker", "label": "Marker" }
- ],
- "sample_rate": 0.0,
- "description": "Switch used for button press inputs",
- "excluded_from_analysis": [],
- "status": "active",
- "static_offset": 0.0
- }
- ```
+```json
+{
+ "name": "Switch",
+ "content_type": "MARKERS",
+ "channels": [
+ { "name": "Marker", "label": "Marker" }
+ ],
+ "sample_rate": 0.0,
+ "description": "Switch used for button press inputs",
+ "excluded_from_analysis": [],
+ "status": "active",
+ "static_offset": 0.0
+}
+```
-2. Ensure that the switch signal model can be loaded or create a switch signal model. To create a new one:
+1. Ensure that the switch signal model can be loaded or create a switch signal model. To create a new one:
- ```
- from pathlib import Path
- from bcipy.acquisition.devices import preconfigured_device
- from bcipy.io.save import save_model
- from bcipy.signal.model.base_model import SignalModelMetadata
- from bcipy.signal.model.switch_model import SwitchModel
-
- dirname = "" # TODO: enter the directory
- model = SwitchModel()
- # name should match devices.json spec. Alternatively, use bcipy.acquisition.datastream.mock.switch.switch_device()
- device = preconfigured_device("Switch")
- model.metadata = SignalModelMetadata(device_spec=device, evidence_type="BTN", transform=None)
- save_model(model, Path(dirname, "switch_model.pkl"))
- ```
+```python
+from pathlib import Path
+from bcipy.acquisition.devices import preconfigured_device
+from bcipy.io.save import save_model
+from bcipy.signal.model.base_model import SignalModelMetadata
+from bcipy.signal.model.switch_model import SwitchModel
+
+dirname = "" # TODO: enter the directory
+model = SwitchModel()
-3. Set the appropriate simulation parameters in the parameters.json file.
+# name should match devices.json spec. Alternatively, use bcipy.acquisition.datastream.mock.switch.switch_device()
+
+device = preconfigured_device("Switch")
+model.metadata = SignalModelMetadata(device_spec=device, evidence_type="BTN", transform=None)
+save_model(model, Path(dirname, "switch_model.pkl"))
+```
+
+1. Set the appropriate simulation parameters in the parameters.json file.
- set the `acq_mode` parameter to 'EEG+MARKERS'.
- ensure that `preview_inquiry_progress_method` parameter is set to '1' or '2'.
- You may also want to set the `summarize_session` parameter to `true` to see how the evidences get combined during decision-making.
-4. Ensure that the data directories have a raw data file (csv) for markers in addition to the EEG data. If your data does not have marker data, you can extract this from the triggers.txt file using the script `bcipy.simulator.util.generate_marker_data`. If the task was run with Inquiry Preview, the script can use the button press events recoreded in the trigger file. Otherwise you can use the `--mock` flag along with a parameters file to mock what a raw data file would look like if the user pressed the button according to the configured button press mode.
+2. Ensure that the data directories have a raw data file (csv) for markers in addition to the EEG data. If your data does not have marker data, you can extract this from the triggers.txt file using the script `bcipy.simulator.util.generate_marker_data`. If the task was run with Inquiry Preview, the script can use the button press events recorded in the trigger file. Otherwise you can use the `--mock` flag along with a parameters file to mock what a raw data file would look like if the user pressed the button according to the configured button press mode.
- ```
+ ```bash
$ python -m bcipy.simulator.util.generate_marker_data -h
usage: generate_marker_data.py [-h] [-m] [-p PARAMETERS] data_folder
@@ -204,7 +208,7 @@ The `switch_data_processor` and `switch_model` are used to demonstrate a multimo
Optional parameters file to use when mocking data.
```
-5. Run a simulation.
+3. Run a simulation.
- Set the simulation parameters for both the EEG and the Button models (.pkl files).
- Use the InquirySampler
@@ -230,4 +234,3 @@ Note that the progress method (`preview_inquiry_progress_method` parameter) does
- A `preview_inquiry_progress_method` of 0 is currently not supported and an exception will be thrown. Ideally, all inquiries should get an evidence value of 1.0 (no change) with this mode.
- Button evidence only works correctly with the InquirySampler. This is due to all trials in the same inquiry receiving the same value.
-
diff --git a/bcipy/simulator/data/data_engine.py b/bcipy/simulator/data/data_engine.py
index a9635f243..23ab816cc 100644
--- a/bcipy/simulator/data/data_engine.py
+++ b/bcipy/simulator/data/data_engine.py
@@ -31,7 +31,8 @@ def is_valid(self) -> bool:
origin = get_origin(field_type)
if origin:
options = get_args(field_type)
- is_correct_type = any(isinstance(self.value, ftype) for ftype in options)
+ is_correct_type = any(isinstance(self.value, ftype)
+ for ftype in options)
else:
is_correct_type = isinstance(self.value, field_type)
@@ -46,7 +47,8 @@ def valid_operators(self) -> List[str]:
class DataEngine(ABC):
"""Abstract class for an object that loads data from one or more sources,
processes the data using a provided processor, and provides an interface
- for querying the processed data."""
+ for querying the processed data.
+ """
def load(self):
"""Load data from sources."""
diff --git a/bcipy/simulator/data/data_process.py b/bcipy/simulator/data/data_process.py
index efee02c72..85bbe2a46 100644
--- a/bcipy/simulator/data/data_process.py
+++ b/bcipy/simulator/data/data_process.py
@@ -1,6 +1,7 @@
"""This module defines functionality related to pre-processing simulation data.
Processed data can be subsequently sampled and provided to a SignalModel
-for classification."""
+for classification.
+"""
import logging as logger
from abc import abstractmethod
@@ -73,7 +74,8 @@ def load_device_data(data_folder: str,
class DecodedTriggers(NamedTuple):
"""Extracted properties after decoding the triggers.txt file and applying
- the necessary offsets and corrections."""
+ the necessary offsets and corrections.
+ """
targetness: List[str] # TriggerType
times: List[float]
symbols: List[str] # symbol
@@ -93,6 +95,7 @@ def triggers(self) -> List[Trigger]:
@dataclass()
class ExtractedExperimentData:
"""Data from an acquisition device after reshaping and filtering."""
+
source_dir: str
inquiries: np.ndarray
trials: np.ndarray
@@ -138,18 +141,17 @@ class TimingParams(NamedTuple):
@property
def trials_per_inquiry(self) -> int:
- """Alias for stim_length"""
+ """Alias for stim_length."""
return self.stim_length
@property
def buffer(self) -> float:
- """The task buffer length defines the min time between two inquiries
- We use half of that time here to buffer during transforms"""
+ """The task buffer length defines the min time between two inquiries. We use half of that time here to buffer during transforms."""
return self.task_buffer_length / 2
@property
def window_length(self) -> float:
- """window (in seconds) of data collection after each stimulus presentation"""
+ """window (in seconds) of data collection after each stimulus presentation."""
start, end = self.trial_window
return end - start
@@ -222,7 +224,7 @@ def consumes(self) -> ContentType:
@property
def produces(self) -> EvidenceType:
- """Type of evidence that is output"""
+ """Type of evidence that is output."""
raise NotImplementedError
@property
@@ -237,12 +239,13 @@ def model_device(self) -> DeviceSpec:
@property
def reshaper(self):
- """data reshaper"""
+ """data reshaper."""
return self.model.reshaper
def check_model_compatibility(self, model: SignalModel) -> None:
"""Check that the given model is compatible with this processor.
- Checked on initialization."""
+ Checked on initialization.
+ """
assert model.metadata, "Metadata missing from signal model."
assert ContentType(
model.metadata.device_spec.content_type
@@ -258,7 +261,6 @@ def check_data_compatibility(self, data_device: DeviceSpec,
data_timing_params):
raise IncompatibleParameters(
"Timing parameters are not compatible")
-
if data_device.static_offset == devices.DEFAULT_STATIC_OFFSET:
log.warning(' '.join([
f"Using the default static offset [{devices.DEFAULT_STATIC_OFFSET}] for {data_device.name}.",
@@ -309,7 +311,7 @@ def process(self, data_folder: str,
def load_device_data(self, data_folder: str,
parameters: Parameters) -> Tuple[RawData, DeviceSpec]:
- """Load the device data"""
+ """Load the device data."""
return load_device_data(data_folder, self.content_type.name)
def decode_triggers(self,
@@ -351,26 +353,26 @@ def extract_trials(self, filtered_reshaped_data: ReshapedData,
raise NotImplementedError
def excluded_triggers(self):
- """Trigger types to exclude when decoding"""
+ """Trigger types to exclude when decoding."""
return [TriggerType.PREVIEW, TriggerType.EVENT, TriggerType.FIXATION]
def devices_compatible(self, model_device: DeviceSpec,
data_device: DeviceSpec) -> bool:
"""Check compatibility between the device on which the model was trained
- and the device used for data collection."""
-
+ and the device used for data collection.
+ """
# TODO: check analysis channels?
return model_device.sample_rate == data_device.sample_rate
def parameters_compatible(self, sim_timing_params: TimingParams,
data_timing_params: TimingParams) -> bool:
- """Check compatibility between the parameters used for simulation and
- those used for data collection."""
+ """Check compatibility between the parameters used for simulation and those used for data collection."""
return sim_timing_params.time_flash == data_timing_params.time_flash
class EEGRawDataProcessor(RawDataProcessor):
"""RawDataProcessor that processes EEG data."""
+
consumes = ContentType.EEG
produces = EvidenceType.ERP
@@ -438,7 +440,7 @@ def extract_trials(self, filtered_reshaped_data: ReshapedData,
def get_transform(self, transform_params: ERPTransformParams,
data_sample_rate: int) -> Composition:
- """"Get the transform used for filtering the data."""
+ """Get the transform used for filtering the data."""
return get_default_transform(
sample_rate_hz=data_sample_rate,
notch_freq_hz=transform_params.notch_filter_frequency,
diff --git a/bcipy/simulator/data/sampler/base_sampler.py b/bcipy/simulator/data/sampler/base_sampler.py
index 18f35b1a1..87af04e86 100644
--- a/bcipy/simulator/data/sampler/base_sampler.py
+++ b/bcipy/simulator/data/sampler/base_sampler.py
@@ -34,7 +34,8 @@ def format_samples(sample_rows: List[Trial]) -> str:
class Sampler(ABC):
"""Represents a strategy for sampling signal model data from a DataEngine
- comprised of signal data from one or more data collection sessions."""
+ comprised of signal data from one or more data collection sessions.
+ """
def __init__(self, data_engine: RawDataEngine):
self.data_engine: RawDataEngine = data_engine
diff --git a/bcipy/simulator/data/sampler/inquiry_sampler.py b/bcipy/simulator/data/sampler/inquiry_sampler.py
index 72b1a1133..d4a7cc214 100644
--- a/bcipy/simulator/data/sampler/inquiry_sampler.py
+++ b/bcipy/simulator/data/sampler/inquiry_sampler.py
@@ -28,9 +28,9 @@ def __init__(self, data_engine: RawDataEngine):
def prepare_data(
self, data: pd.DataFrame
) -> Tuple[Dict[Path, List[int]], Dict[Path, List[int]]]:
- """Partition the data into those inquiries that displayed the target
- and those that did not. The resulting data structures map the data
- source with a list of inquiry_n numbers."""
+ """Partition the data into those inquiries that displayed the target and those that did not.
+ The resulting data structures map the data source with a list of inquiry_n numbers.
+ """
target_inquiries = defaultdict(list)
no_target_inquiries = defaultdict(list)
diff --git a/bcipy/simulator/data/sampler/target_nontarget_sampler.py b/bcipy/simulator/data/sampler/target_nontarget_sampler.py
index 59cce32bd..71e4e69f0 100644
--- a/bcipy/simulator/data/sampler/target_nontarget_sampler.py
+++ b/bcipy/simulator/data/sampler/target_nontarget_sampler.py
@@ -13,6 +13,7 @@ class TargetNontargetSampler(Sampler):
"""Sampler that that queries based on target/non-target label."""
def sample(self, state: SimState) -> List[Trial]:
+ """Sample trials for each symbol in the display alphabet, labeling as target or non-target."""
sample_rows = []
for symbol in state.display_alphabet:
filters = self.query_filters(
diff --git a/bcipy/simulator/data/trial.py b/bcipy/simulator/data/trial.py
index e7480fdd5..58ff9e20d 100644
--- a/bcipy/simulator/data/trial.py
+++ b/bcipy/simulator/data/trial.py
@@ -34,7 +34,8 @@ class Trial(NamedTuple):
inquiry_pos: int
symbol: str
target: int
- eeg: np.ndarray # Channels by Samples ; ndarray.shape = (channel_n, sample_n)
+ # Channels by Samples ; ndarray.shape = (channel_n, sample_n)
+ eeg: np.ndarray
def __str__(self):
fields = [
@@ -51,8 +52,7 @@ def __repr__(self):
def session_series_counts(source_dir: str) -> List[int]:
- """Read the session.json file in the provided directory
- and compute the number of inquiries per series."""
+ """Read the session.json file in the provided directory and compute the number of inquiries per series."""
session_path = Path(source_dir, SESSION_DATA_FILENAME)
if session_path.exists():
session = read_session(str(session_path))
diff --git a/bcipy/simulator/demo/demo_group_simulation.py b/bcipy/simulator/demo/demo_group_simulation.py
index 592798e4f..5b6e84b64 100644
--- a/bcipy/simulator/demo/demo_group_simulation.py
+++ b/bcipy/simulator/demo/demo_group_simulation.py
@@ -102,7 +102,8 @@ def run_simulation(
-------
The simulation results will be saved in the output directory defined by OUTPUT_DIR.
"""
- print(f"Running simulation for {user} with phrase {phrase} and language model {language_model}")
+ print(
+ f"Running simulation for {user} with phrase {phrase} and language model {language_model}")
model_path = None
params_path = None
@@ -120,7 +121,8 @@ def run_simulation(
model_path = pkl_file
break
if not model_path:
- raise FileNotFoundError(f"Could not find a model file in {mode_calibration_dir}")
+ raise FileNotFoundError(
+ f"Could not find a model file in {mode_calibration_dir}")
else:
model_path = signal_model_path
@@ -132,7 +134,8 @@ def run_simulation(
# update the parameters with the new phrase, starting index, and language model
phrase_length = len(phrase) - starting_index
- print(f"Processing {user}:{phrase}:{language_model} of phrase length: {phrase_length}")
+ print(
+ f"Processing {user}:{phrase}:{language_model} of phrase length: {phrase_length}")
parameters["task_text"] = phrase
parameters["spelled_letters_count"] = starting_index
parameters["lang_model_type"] = language_model
@@ -140,7 +143,8 @@ def run_simulation(
# Below are task constraints that impact typing speed and letter selection. Here we use criteria based on phrase length
# and some sensible defaults.
parameters["max_inq_len"] = phrase_length * 8
- parameters["max_selections"] = phrase_length * 2 # This should be 2 * the length of the phrase to type
+ # This should be 2 * the length of the phrase to type
+ parameters["max_selections"] = phrase_length * 2
parameters["min_inq_per_series"] = 1
parameters["max_inq_per_series"] = 8
parameters["backspace_always_shown"] = True
@@ -148,11 +152,14 @@ def run_simulation(
parameters["lm_backspace_prob"] = 0.03571
# signal model decision threshold for letter selection
parameters["decision_threshold"] = 0.8
- parameters["max_minutes"] = 120 # This is not used in the simulation. But we set it high to avoid any issues.
- parameters["max_incorrect"] = int(phrase_length / 2) # This should be half the length of the phrase to type
+ # This is not used in the simulation. But we set it high to avoid any issues.
+ parameters["max_minutes"] = 120
+ # This should be half the length of the phrase to type
+ parameters["max_incorrect"] = int(phrase_length / 2)
# get the correct list of source directories from the data directory
- source_dirs = [str(file) for file in data_dir.iterdir() if file.is_dir() and DATA_PATTERN in file.name]
+ source_dirs = [str(file) for file in data_dir.iterdir()
+ if file.is_dir() and DATA_PATTERN in file.name]
if not source_dirs:
raise FileNotFoundError(f"Could not find a data directory for {user}")
@@ -171,7 +178,8 @@ def run_simulation(
runner.run()
metrics.report(sim_dir)
except Exception as e:
- print(f"Error running simulation for {user} with phrase {phrase} and language model {language_model}")
+ print(
+ f"Error running simulation for {user} with phrase {phrase} and language model {language_model}")
print(e)
@@ -201,6 +209,7 @@ def run_simulation(
progress_bar.set_description(f"Processing {user.name}")
for phrase, starting_index in PHRASES:
for language_model in LANGUAGE_MODELS:
- run_simulation(user, user.name, phrase, starting_index, language_model)
+ run_simulation(user, user.name, phrase,
+ starting_index, language_model)
progress_bar.close()
diff --git a/bcipy/simulator/exceptions.py b/bcipy/simulator/exceptions.py
index f95f5f1a6..80e69a27d 100644
--- a/bcipy/simulator/exceptions.py
+++ b/bcipy/simulator/exceptions.py
@@ -1,6 +1,5 @@
class DeviceSpecNotFoundError(Exception):
- """Thrown when a suitable DeviceSpec was not found in the devices.json file
- """
+ """Thrown when a suitable DeviceSpec was not found in the devices.json file."""
class IncompatibleData(Exception):
@@ -15,7 +14,8 @@ class IncompatibleDeviceSpec(IncompatibleData):
class IncompatibleParameters(IncompatibleData):
"""Thrown when the timing parameters used for data collection are
- incompatible with the timing parameters of the simulation."""
+ incompatible with the timing parameters of the simulation.
+ """
class IncompatibleSampler(Exception):
diff --git a/bcipy/simulator/task/copy_phrase.py b/bcipy/simulator/task/copy_phrase.py
index c9129e1d0..fa8a1fb76 100644
--- a/bcipy/simulator/task/copy_phrase.py
+++ b/bcipy/simulator/task/copy_phrase.py
@@ -1,4 +1,4 @@
-# mypy: disable-error-code="union-attr"
+# mypy: disable-error-code="union-attr, override"
"""Simulates the Copy Phrase task"""
import logging
from pathlib import Path
@@ -37,12 +37,14 @@ def get_evidence_type(model: SignalModel) -> EvidenceType:
try:
return EvidenceType(evidence_type)
except ValueError:
- raise ValueError(f"Unsupported evidence type: {evidence_type}. Supported types: {EvidenceType.list()}")
+ raise ValueError(
+ f"Unsupported evidence type: {evidence_type}. Supported types: {EvidenceType.list()}")
class SimulatorCopyPhraseTask(RSVPCopyPhraseTask):
"""CopyPhraseTask that simulates user interactions by sampling data
- from a DataSampler."""
+ from a DataSampler.
+ """
name = "Simulator Copy Phrase"
paradigm = "RSVP"
@@ -83,6 +85,7 @@ def get_language_model(self):
def init_evidence_evaluators(
self, signal_models: List[SignalModel]) -> List[EvidenceEvaluator]:
+ """Initialize evidence evaluators for the simulation (returns empty list)."""
# Evidence will be sampled so we don't need to evaluate raw data.
return []
@@ -90,21 +93,25 @@ def init_evidence_types(
self, signal_models: List[SignalModel],
evidence_evaluators: List[EvidenceEvaluator]
) -> List[EvidenceType]:
+ """Initialize evidence types for the simulation."""
evidence_types = set(
[get_evidence_type(model) for model in self.signal_models])
return [EvidenceType.LM, *evidence_types]
def init_display(self) -> Display:
+ """Initialize the display for the simulation (returns NullDisplay)."""
return NullDisplay()
def init_acquisition(self) -> ClientManager:
- """Override to do nothing"""
+ """Override to do nothing."""
return NullDAQ()
def init_feedback(self) -> Optional[VisualFeedback]:
+ """Initialize feedback for the simulation (returns None)."""
return None
def user_wants_to_continue(self) -> bool:
+ """Always continue for simulation purposes."""
return True
def wait(self, seconds: Optional[float] = None) -> None:
@@ -113,8 +120,7 @@ def wait(self, seconds: Optional[float] = None) -> None:
def present_inquiry(
self, inquiry_schedule: InquirySchedule
) -> Tuple[List[Tuple[str, float]], bool]:
- """Override ; returns empty timing info; always proceed for inquiry
- preview"""
+ """Override; returns empty timing info; always proceed for inquiry preview."""
return [], True
def show_feedback(self, selection: str, correct: bool = True) -> None:
@@ -122,12 +128,14 @@ def show_feedback(self, selection: str, correct: bool = True) -> None:
def compute_button_press_evidence(
self, proceed: bool) -> Optional[Tuple[EvidenceType, List[float]]]:
+ """Compute button press evidence for simulation (returns None)."""
return None
def compute_device_evidence(
self,
stim_times: List[List],
proceed: bool = True) -> List[Tuple[EvidenceType, List[float]]]:
+ """Compute device evidence for simulation."""
current_state = self.get_sim_state()
self.logger.debug("Computing evidence with sim_state:")
@@ -149,6 +157,7 @@ def compute_device_evidence(
return evidences
def cleanup(self):
+ """Clean up after simulation, including saving session data and removing empty trigger files."""
self.save_session_data()
trigger_path = Path(self.trigger_handler.file_path)
self.trigger_handler.close()
@@ -161,6 +170,7 @@ def exit_display(self) -> None:
"""Close the UI and cleanup."""
def elapsed_seconds(self) -> float:
+ """Return elapsed seconds for simulation (always 0.0)."""
return 0.0
def write_offset_trigger(self) -> None:
diff --git a/bcipy/simulator/task/null_display.py b/bcipy/simulator/task/null_display.py
index 4eef7b50a..3db159c7b 100644
--- a/bcipy/simulator/task/null_display.py
+++ b/bcipy/simulator/task/null_display.py
@@ -6,9 +6,11 @@
class NullDisplay(Display):
"""Display that doesn't show anything to the user. Useful in simulated tasks
- that do not have a display component."""
+ that do not have a display component.
+ """
def do_inquiry(self) -> List[Tuple[str, float]]:
+ """Perform an inquiry and return an empty list (no display)."""
return []
def wait_screen(self, *args, **kwargs) -> None:
diff --git a/bcipy/simulator/task/replay_session.py b/bcipy/simulator/task/replay_session.py
index 995f322e9..df60e7fb4 100644
--- a/bcipy/simulator/task/replay_session.py
+++ b/bcipy/simulator/task/replay_session.py
@@ -1,5 +1,7 @@
"""Task that will replay sessions to compare model predictions on that data.
-Used for testing if changes to a model result in more easily differentiated signals."""
+Used for testing if changes to a model result in more easily differentiated signals.
+"""
+
import argparse
import logging
from pathlib import Path
diff --git a/bcipy/simulator/task/task_factory.py b/bcipy/simulator/task/task_factory.py
index 9d661ea5e..aacdc8784 100644
--- a/bcipy/simulator/task/task_factory.py
+++ b/bcipy/simulator/task/task_factory.py
@@ -32,7 +32,6 @@ def update_latest_params(parameters: Parameters) -> None:
class TaskFactory:
"""Constructs the hierarchy of objects necessary for initializing a task.
-
Parameters
----------
parameters : Parameters
@@ -82,7 +81,8 @@ def __init__(
def log_state(self):
"""Log configured objects of interest. This should be done after the
sim directory has been created and TOP_LEVEL_LOGGER has been configured,
- which may happen some time after object construction."""
+ which may happen some time after object construction.
+ """
logger.debug("Language model:")
logger.debug(f"\t{repr(self.language_model)}")
logger.debug("Models -> Samplers:")
diff --git a/bcipy/simulator/task/task_runner.py b/bcipy/simulator/task/task_runner.py
index 69875d74a..b5120e681 100644
--- a/bcipy/simulator/task/task_runner.py
+++ b/bcipy/simulator/task/task_runner.py
@@ -149,7 +149,8 @@ def main():
elif args.interactive:
task_factory = cli.main(sim_args)
else:
- parameters = load_json_parameters(sim_args['parameters'], value_cast=True)
+ parameters = load_json_parameters(
+ sim_args['parameters'], value_cast=True)
task_factory = TaskFactory(
parameters=parameters,
source_dirs=sim_args['data_folder'],
diff --git a/bcipy/simulator/ui/cli.py b/bcipy/simulator/ui/cli.py
index a99cb5ce8..973d0810e 100644
--- a/bcipy/simulator/ui/cli.py
+++ b/bcipy/simulator/ui/cli.py
@@ -31,7 +31,8 @@ def do_directories(parent: Path,
max_depth: int = 3,
current_depth: int = 1) -> None:
"""Recursively walk a tree of directories, calling the provided
- function on each path."""
+ function on each path.
+ """
paths = sorted(
Path(parent).iterdir(),
@@ -108,7 +109,8 @@ def excluded(path: Path) -> bool:
def select_directories(parent: Path) -> List[Path]:
"""Select all directories of interest within the parent path.
- Traverses all directories and prompts for each."""
+ Traverses all directories and prompts for each.
+ """
accum = []
def do_prompt(path: Path):
@@ -155,6 +157,7 @@ class PromptPath(PromptBase[Path]):
response_type = Path
def process_response(self, value: str) -> Path:
+ """Process the response value and return a Path object."""
value = value.strip("'\"")
return super().process_response(value)
@@ -229,8 +232,9 @@ def get_acq_mode(params_path: str):
def command(params: str, models: List[str], source_dirs: List[str], sampler: Type[Sampler]) -> str:
- """Command equivalent to to the result of the interactive selection of
- simulator inputs."""
+ """Command equivalent to the result of the interactive selection of
+ simulator inputs.
+ """
model_args = ' '.join([f"-m {path}" for path in models])
dir_args = ' '.join(f"-d {source}" for source in source_dirs)
sampler_args = f"-s {sampler.__name__}"
diff --git a/bcipy/simulator/ui/gui.py b/bcipy/simulator/ui/gui.py
index edfb42bbb..0d61303f4 100644
--- a/bcipy/simulator/ui/gui.py
+++ b/bcipy/simulator/ui/gui.py
@@ -186,6 +186,7 @@ def __init__(self,
change_event=change_event)
def prompt_path(self):
+ """Prompt the user to select a directory path."""
dialog = FileDialog()
directory = ''
if preferences.last_directory:
@@ -493,13 +494,13 @@ def sim_runs_control(value: int = 1) -> QWidget:
class SimConfigForm(QWidget):
- """The InputForm class is a QWidget that creates controls/inputs for each simulation parameter
+ """The InputForm class is a QWidget that creates controls/inputs for each simulation parameter.
- Parameters:
- -----------
- json_file - path of parameters file to be edited.
- width - optional; used to set the width of the form controls.
- """
+ Parameters
+ ----------
+ json_file - path of parameters file to be edited.
+ width - optional; used to set the width of the form controls.
+ """
def __init__(self,
width: int = 400,
@@ -591,8 +592,7 @@ def command_valid(self) -> bool:
and self.data_paths and self.sampler and self.sampler_args)
def command(self) -> str:
- """Command equivalent to to the result of the interactive selection of
- simulator inputs."""
+ """Command equivalent to the result of the interactive selection of simulator inputs."""
if not self.command_valid():
return ''
diff --git a/bcipy/simulator/ui/obj_args_widget.py b/bcipy/simulator/ui/obj_args_widget.py
index d59113a53..6aa35a9d4 100644
--- a/bcipy/simulator/ui/obj_args_widget.py
+++ b/bcipy/simulator/ui/obj_args_widget.py
@@ -11,7 +11,8 @@
class ObjectArgInputs(QWidget):
"""Widget with inputs for parameters needed to instantiate an object for a
- given class."""
+ given class.
+ """
def __init__(self,
parent: Optional[QWidget] = None,
@@ -69,7 +70,8 @@ def _field_definition(self, name: str) -> InputField:
def _json_name_value(self, name: str, control: QWidget) -> str:
"""Returns a json partial for a name: value, quoting the value according
- to the input_type"""
+ to the input_type.
+ """
field = self._field_definition(name)
value = self._input_value(field, control)
diff --git a/bcipy/simulator/util/artifact.py b/bcipy/simulator/util/artifact.py
index 03329863a..2859ff7af 100644
--- a/bcipy/simulator/util/artifact.py
+++ b/bcipy/simulator/util/artifact.py
@@ -1,4 +1,5 @@
-""" Handles artifacts related logic ie. logs, save dir creation, result.json, ..."""
+"""Handles artifact-related logic, i.e. logs, save dir creation, result.json, ..."""
+
import datetime
import logging
import os
diff --git a/bcipy/simulator/util/generate_marker_data.py b/bcipy/simulator/util/generate_marker_data.py
index 6cb82a4ef..634343253 100644
--- a/bcipy/simulator/util/generate_marker_data.py
+++ b/bcipy/simulator/util/generate_marker_data.py
@@ -8,7 +8,7 @@
def main() -> Path:
- """"Main method used to generate a raw data file for a switch device."""
+ """Main method used to generate a raw data file for a switch device."""
parser = argparse.ArgumentParser(
description="Create raw marker data for a given session.")
parser.add_argument("data_folder",
diff --git a/bcipy/simulator/util/metrics.py b/bcipy/simulator/util/metrics.py
index b4b7a91b0..edecb742e 100644
--- a/bcipy/simulator/util/metrics.py
+++ b/bcipy/simulator/util/metrics.py
@@ -168,8 +168,7 @@ def plot_results(df: pd.DataFrame,
def report(sim_dir: str, show_plots: bool = False) -> None:
- """Summarize the data, write as a JSON file, and output a summary to
- the top level log file."""
+ """Summarize the data, write as a JSON file, and output a summary to the top level log file."""
summary = summarize(sim_dir)
save_json_data(summary, sim_dir, SUMMARY_DATA_FILE_NAME)
diff --git a/bcipy/simulator/util/state.py b/bcipy/simulator/util/state.py
index 5ae25ba43..abb6ca4ce 100644
--- a/bcipy/simulator/util/state.py
+++ b/bcipy/simulator/util/state.py
@@ -10,7 +10,8 @@
@dataclass
class SimState:
- """ Represents the state of a current session during simulation """
+ """Represents the state of a current session during simulation."""
+
target_symbol: str
current_sentence: str
target_sentence: str
@@ -20,8 +21,7 @@ class SimState:
def get_inquiry(session_dir: str, n: int) -> Dict[str, Any]:
- """Extracts an inquiry from a session.json file. Useful for debugging
- simulator output."""
+ """Extracts an inquiry from a session.json file. Useful for debugging simulator output."""
session = read_session(f"{session_dir}/{SESSION_DATA_FILENAME}")
inq = session.all_inquiries[n]
return inq.stim_evidence(session.symbol_set)
diff --git a/bcipy/simulator/util/switch_utils.py b/bcipy/simulator/util/switch_utils.py
index b60e75725..680412586 100644
--- a/bcipy/simulator/util/switch_utils.py
+++ b/bcipy/simulator/util/switch_utils.py
@@ -49,16 +49,13 @@ def has_target(triggers: List[Trigger]) -> bool:
def time_range(inquiry_triggers: List[Trigger],
time_flash: float) -> Tuple[float, float]:
- """Given a list of triggers for a given inquiry, determine the start and
- end timestamps of that inquiry."""
+ """Given a list of triggers for a given inquiry, determine the start and end timestamps of that inquiry."""
return (inquiry_triggers[0].time, inquiry_triggers[-1].time + time_flash)
def inquiry_windows(trigger_path: Path,
time_flash: float) -> List[Tuple[float, float]]:
- """Returns a list of (inquiry_start, inquiry_stop) timestamp pairs for
- all inquiries in the trigger file."""
-
+ """Returns a list of (inquiry_start, inquiry_stop) timestamp pairs for all inquiries in the trigger file."""
return [
time_range(inq_triggers, time_flash)
for inq_triggers in partition_triggers(trigger_path)
@@ -67,8 +64,7 @@ def inquiry_windows(trigger_path: Path,
def should_press_switch(inquiry_triggers: List[Trigger],
button_press_mode: ButtonPressMode) -> bool:
- """Determine if a marker should be written for the given inquiry
- depending on the presence of a target and the button press mode."""
+ """Determine if a marker should be written for the given inquiry depending on the presence of a target and the button press mode."""
return (button_press_mode == ButtonPressMode.ACCEPT
and has_target(inquiry_triggers)) or (
button_press_mode == ButtonPressMode.REJECT
diff --git a/bcipy/static/images/gui/cambi_fav.ico b/bcipy/static/images/gui/cambi_fav.ico
new file mode 100644
index 000000000..897130825
Binary files /dev/null and b/bcipy/static/images/gui/cambi_fav.ico differ
diff --git a/bcipy/task/actions.py b/bcipy/task/actions.py
index 5472c7987..ac943ce9f 100644
--- a/bcipy/task/actions.py
+++ b/bcipy/task/actions.py
@@ -1,9 +1,16 @@
# mypy: disable-error-code="assignment,arg-type"
+"""Task actions module for BCI tasks.
+
+This module provides various task actions that can be executed as part of a BCI
+experiment, including code hooks, offline analysis, intertask management, and
+report generation.
+"""
+
import glob
import logging
import subprocess
from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
from matplotlib.figure import Figure
@@ -32,8 +39,13 @@
class CodeHookAction(Task):
- """
- Action for running generic code hooks.
+ """Action for running generic code hooks.
+
+ Attributes:
+ name: Name of the task.
+ mode: Task execution mode.
+ code_hook: Code to execute.
+ subprocess: Whether to run in a subprocess.
"""
name = "CodeHookAction"
@@ -46,23 +58,43 @@ def __init__(
code_hook: Optional[str] = None,
subprocess: bool = True,
**kwargs) -> None:
+ """Initialize the code hook action.
+
+ Args:
+ parameters: Task parameters.
+ data_directory: Directory for data storage.
+ code_hook: Code to execute.
+ subprocess: Whether to run in a subprocess.
+ **kwargs: Additional keyword arguments.
+ """
super().__init__()
self.code_hook = code_hook
self.subprocess = subprocess
def execute(self) -> TaskData:
+ """Execute the code hook.
+
+ Returns:
+ TaskData: Empty task data.
+ """
if self.code_hook:
if self.subprocess:
subprocess.Popen(self.code_hook, shell=True)
-
else:
subprocess.run(self.code_hook, shell=True)
return TaskData()
class OfflineAnalysisAction(Task):
- """
- Action for running offline analysis.
+ """Action for running offline analysis.
+
+ Attributes:
+ name: Name of the task.
+ mode: Task execution mode.
+ parameters: Task parameters.
+ parameters_path: Path to parameters file.
+ data_directory: Directory containing data to analyze.
+ alert_finished: Whether to alert when analysis completes.
"""
name = "OfflineAnalysisAction"
@@ -76,6 +108,16 @@ def __init__(
last_task_dir: Optional[str] = None,
alert_finished: bool = False,
**kwargs: Any) -> None:
+ """Initialize the offline analysis action.
+
+ Args:
+ parameters: Task parameters.
+ data_directory: Directory containing data to analyze.
+ parameters_path: Path to parameters file.
+ last_task_dir: Directory of last executed task.
+ alert_finished: Whether to alert when analysis completes.
+ **kwargs: Additional keyword arguments.
+ """
super().__init__()
self.parameters = parameters
self.parameters_path = parameters_path
@@ -94,6 +136,11 @@ def execute(self) -> TaskData:
to stop execution. For example, if Exception is thrown in cross_validation due to the # of folds being
inconsistent.
+ Returns:
+ TaskData: Contains analysis results and parameters.
+
+ Raises:
+ Exception: If offline analysis fails.
"""
logger.info("Running offline analysis action")
try:
@@ -118,6 +165,21 @@ def execute(self) -> TaskData:
class IntertaskAction(Task):
+ """Action for managing transitions between tasks.
+
+ Attributes:
+ name: Name of the task.
+ mode: Task execution mode.
+ tasks: List of tasks to manage.
+ current_task_index: Index of current task.
+ save_folder: Directory for saving task data.
+ parameters: Task parameters.
+ next_task_index: Index of next task to execute.
+ task_name: Name of current task.
+ task_names: List of task names.
+ exit_callback: Function to call on exit.
+ """
+
name = "IntertaskAction"
mode = TaskMode.ACTION
tasks: List[Task]
@@ -131,11 +193,25 @@ def __init__(
tasks: Optional[List[Task]] = None,
exit_callback: Optional[Callable] = None,
**kwargs: Any) -> None:
+ """Initialize the intertask action.
+
+ Args:
+ parameters: Task parameters.
+ save_path: Directory for saving task data.
+ progress: Current progress (1-indexed).
+ tasks: List of tasks to manage.
+ exit_callback: Function to call on exit.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ AssertionError: If progress or tasks is None, or if progress < 0.
+ """
super().__init__()
self.save_folder = save_path
self.parameters = parameters
assert progress is not None and tasks is not None, "Either progress or tasks must be provided"
- self.next_task_index = progress # progress is 1-indexed, tasks is 0-indexed so we can use the same index
+ # progress is 1-indexed, tasks is 0-indexed so we can use the same index
+ self.next_task_index = progress
assert self.next_task_index >= 0, "Progress must be greater than 1 "
self.tasks = tasks
self.task_name = self.tasks[self.next_task_index].name
@@ -143,7 +219,11 @@ def __init__(
self.exit_callback = exit_callback
def execute(self) -> TaskData:
+ """Execute the intertask action.
+ Returns:
+ TaskData: Contains task state information.
+ """
run_bciui(
IntertaskGUI,
tasks=self.task_names,
@@ -160,12 +240,19 @@ def execute(self) -> TaskData:
)
def alert(self):
+ """Handle alerts (not implemented)."""
pass
class ExperimentFieldCollectionAction(Task):
- """
- Action for collecting experiment field data.
+ """Action for collecting experiment field data.
+
+ Attributes:
+ name: Name of the task.
+ mode: Task execution mode.
+ experiment_id: Identifier for the experiment.
+ save_folder: Directory for saving collected data.
+ parameters: Task parameters.
"""
name = "ExperimentFieldCollectionAction"
@@ -177,16 +264,30 @@ def __init__(
data_directory: str,
experiment_id: str = 'default',
**kwargs: Any) -> None:
+ """Initialize the experiment field collection action.
+
+ Args:
+ parameters: Task parameters.
+ data_directory: Directory for saving collected data.
+ experiment_id: Identifier for the experiment.
+ **kwargs: Additional keyword arguments.
+ """
super().__init__()
self.experiment_id = experiment_id
self.save_folder = data_directory
self.parameters = parameters
def execute(self) -> TaskData:
+ """Execute the experiment field collection.
+
+ Returns:
+ TaskData: Contains experiment metadata.
+ """
logger.info(
f"Collecting experiment field data for experiment {self.experiment_id} in save folder {self.save_folder}"
)
- start_experiment_field_collection_gui(self.experiment_id, self.save_folder)
+ start_experiment_field_collection_gui(
+ self.experiment_id, self.save_folder)
return TaskData(
save_path=self.save_folder,
task_dict={
@@ -196,8 +297,22 @@ def execute(self) -> TaskData:
class BciPyCalibrationReportAction(Task):
- """
- Action for generating a report after calibration Tasks.
+ """Action for generating a report after calibration Tasks.
+
+ Attributes:
+ name: Name of the task.
+ mode: Task execution mode.
+ parameters: Task parameters.
+ save_folder: Directory for saving reports.
+ protocol_path: Path to protocol file.
+ last_task_dir: Directory of last executed task.
+ trial_window: Time window for trial analysis.
+ report: Report instance.
+ report_sections: List of report sections.
+ all_raw_data: List of raw data.
+ default_transform: Signal transformation function.
+ type_amp: Amplifier type.
+ static_offset: Static offset value.
"""
name = "BciPyReportAction"
@@ -211,6 +326,16 @@ def __init__(
last_task_dir: Optional[str] = None,
trial_window: Optional[Tuple[float, float]] = None,
**kwargs: Any) -> None:
+ """Initialize the calibration report action.
+
+ Args:
+ parameters: Task parameters.
+ save_path: Directory for saving reports.
+ protocol_path: Path to protocol file.
+ last_task_dir: Directory of last executed task.
+ trial_window: Time window for trial analysis.
+ **kwargs: Additional keyword arguments.
+ """
super().__init__()
self.save_folder = save_path
# Currently we assume all Tasks have the same parameters, this may change in the future.
@@ -233,9 +358,12 @@ def __init__(
self.static_offset = None
def execute(self) -> TaskData:
- """Excute the report generation action.
+ """Execute the report generation action.
This assumes all data were collected using the same protocol, device, and parameters.
+
+ Returns:
+ TaskData: Contains report data and metadata.
"""
logger.info(f"Generating report in save folder {self.save_folder}")
# loop through all the files in the last_task_dir
@@ -254,14 +382,17 @@ def execute(self) -> TaskData:
task_name = path_data_dir.parts[-1].split('_')[0]
data_directories.append(path_data_dir)
# For each calibration directory, attempt to load the raw data
- signal_report_section = self.create_signal_report(path_data_dir)
- session_report = self.create_session_report(path_data_dir, task_name)
+ signal_report_section = self.create_signal_report(
+ path_data_dir)
+ session_report = self.create_session_report(
+ path_data_dir, task_name)
self.report_sections.append(session_report)
self.report.add(session_report)
self.report_sections.append(signal_report_section)
self.report.add(signal_report_section)
if data_directories:
- logger.info(f"Saving report generated from: {self.protocol_path}")
+ logger.info(
+ f"Saving report generated from: {self.protocol_path}")
else:
logger.info(f"No data found in {self.protocol_path}")
@@ -270,6 +401,7 @@ def execute(self) -> TaskData:
self.report.compile()
self.report.save()
+
return TaskData(
save_path=self.save_folder,
task_dict={
@@ -278,6 +410,14 @@ def execute(self) -> TaskData:
)
def create_signal_report(self, data_dir: Path) -> SignalReportSection:
+ """Create a report section for signal quality metrics.
+
+ Args:
+ data_dir: Directory containing signal data.
+
+ Returns:
+ SignalReportSection: Report section containing signal metrics.
+ """
raw_data = load_raw_data(Path(data_dir, f'{RAW_DATA_FILENAME}.csv'))
if not self.type_amp:
self.type_amp = raw_data.daq_type
@@ -298,11 +438,22 @@ def create_signal_report(self, data_dir: Path) -> SignalReportSection:
triggers = self.get_triggers(data_dir)
# get figure handles
- figure_handles = self.get_figure_handles(raw_data, channel_map, triggers)
- artifact_detector = self.get_artifact_detector(raw_data, device_spec, triggers)
+ figure_handles = self.get_figure_handles(
+ raw_data, channel_map, triggers)
+ artifact_detector = self.get_artifact_detector(
+ raw_data, device_spec, triggers)
return SignalReportSection(figure_handles, artifact_detector)
- def create_session_report(self, data_dir, task_name) -> SessionReportSection:
+ def create_session_report(self, data_dir: Path, task_name: str) -> SessionReportSection:
+ """Create a report section for session information.
+
+ Args:
+ data_dir: Directory containing session data.
+ task_name: Name of the task.
+
+ Returns:
+ SessionReportSection: Report section containing session info.
+ """
# get task name
summary_dict = {
"task": task_name,
@@ -314,11 +465,17 @@ def create_session_report(self, data_dir, task_name) -> SessionReportSection:
return SessionReportSection(summary_dict)
- def get_signal_model_metrics(self, data_directory: Path) -> dict:
+ def get_signal_model_metrics(self, data_directory: Path) -> Dict[str, Any]:
"""Get the signal model metrics from the session folder.
In the future, the model will save a ModelMetrics with the pkl file.
For now, we just look for the pkl file and extract the AUC from the filename.
+
+ Args:
+ data_directory: Directory containing model data.
+
+ Returns:
+ Dict[str, Any]: Dictionary of model metrics.
"""
pkl_file = None
for file in data_directory.iterdir():
@@ -334,6 +491,11 @@ def get_signal_model_metrics(self, data_directory: Path) -> dict:
return {'AUC': auc}
def set_default_transform(self, sample_rate: int) -> None:
+ """Set the default signal transformation function.
+
+ Args:
+ sample_rate: Sampling rate of the signal.
+ """
downsample_rate = self.parameters.get("down_sampling_rate")
notch_filter = self.parameters.get("notch_filter_frequency")
filter_high = self.parameters.get("filter_high")
@@ -348,7 +510,15 @@ def set_default_transform(self, sample_rate: int) -> None:
downsample_factor=downsample_rate,
)
- def find_eye_channels(self, device_spec: DeviceSpec) -> Optional[list]:
+ def find_eye_channels(self, device_spec: DeviceSpec) -> Optional[List[str]]:
+ """Find eye-tracking channels in the device specification.
+
+ Args:
+ device_spec: Device specification.
+
+ Returns:
+ Optional[List[str]]: List of eye channel names if found.
+ """
eye_channels = []
for channel in device_spec.channels:
if 'F' in channel:
@@ -357,7 +527,15 @@ def find_eye_channels(self, device_spec: DeviceSpec) -> Optional[list]:
eye_channels = None
return eye_channels
- def get_triggers(self, session: str) -> tuple:
+ def get_triggers(self, session: str) -> Tuple[List[Any], List[float], List[str]]:
+ """Get triggers from the session data.
+
+ Args:
+ session: Path to session directory.
+
+ Returns:
+ Tuple[List[Any], List[float], List[str]]: Trigger type, timing, and labels.
+ """
trigger_type, trigger_timing, trigger_label = trigger_decoder(
offset=self.static_offset,
trigger_path=f"{session}/{TRIGGER_FILENAME}",
@@ -369,7 +547,18 @@ def get_triggers(self, session: str) -> tuple:
)
return trigger_type, trigger_timing, trigger_label
- def get_figure_handles(self, raw_data, channel_map, triggers) -> List[Figure]:
+ def get_figure_handles(self, raw_data: RawData, channel_map: List[str],
+ triggers: Tuple[TriggerType, List[float], List[str]]) -> List[Figure]:
+ """Generate figures for the report.
+
+ Args:
+ raw_data: Raw signal data.
+ channel_map: List of channel names.
+ triggers: Tuple of trigger type, timing, and labels.
+
+ Returns:
+ List[Figure]: List of generated figures.
+ """
trigger_type, trigger_timing, _ = triggers
figure_handles = visualize_erp(
raw_data,
@@ -384,7 +573,18 @@ def get_figure_handles(self, raw_data, channel_map, triggers) -> List[Figure]:
)
return figure_handles
- def get_artifact_detector(self, raw_data, device_spec, triggers) -> ArtifactDetection:
+ def get_artifact_detector(self, raw_data: RawData, device_spec: DeviceSpec,
+ triggers: Tuple[TriggerType, List[float], List[str]]) -> ArtifactDetection:
+ """Create an artifact detector for the signal data.
+
+ Args:
+ raw_data: Raw signal data.
+ device_spec: Device specification.
+ triggers: Tuple of trigger type, timing, and labels.
+
+ Returns:
+ ArtifactDetection: Configured artifact detector.
+ """
eye_channels = self.find_eye_channels(device_spec)
artifact_detector = ArtifactDetection(
raw_data,
diff --git a/bcipy/task/calibration.py b/bcipy/task/calibration.py
index 22fdbae46..9f416534d 100644
--- a/bcipy/task/calibration.py
+++ b/bcipy/task/calibration.py
@@ -1,4 +1,5 @@
"""Base calibration task."""
+# mypy: disable-error-code="override"
import logging
from abc import abstractmethod
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple
@@ -126,6 +127,7 @@ def setup(
parameters: Parameters,
data_save_location: str,
fake: bool = False) -> Tuple[ClientManager, List[LslDataServer], Window]:
+ """Set up acquisition and return client manager, data servers, and window."""
# Initialize Acquisition
daq, servers = init_acquisition(
parameters, data_save_location, server=fake)
@@ -213,7 +215,7 @@ def init_session(self) -> session_data.Session:
task_data=self.session_task_data())
def session_task_data(self) -> Optional[Dict[str, Any]]:
- """"Task-specific session data"""
+ """Task-specific session data."""
return None
def trigger_type(self,
@@ -346,8 +348,7 @@ def write_trigger_data(self, timing: List[Tuple[str, float]],
convert_timing_triggers(timing, timing[0][0], self.trigger_type))
def write_offset_trigger(self) -> None:
- """Append an offset value to the end of the trigger file.
- """
+ """Append an offset value to the end of the trigger file."""
# To help support future refactoring or use of lsl timestamps only
# we write only the sample offset here.
triggers = []
diff --git a/bcipy/task/control/criteria.py b/bcipy/task/control/criteria.py
index e5da2512a..b51eb1d82 100644
--- a/bcipy/task/control/criteria.py
+++ b/bcipy/task/control/criteria.py
@@ -1,6 +1,13 @@
+"""Decision criteria module for BCI task control.
+
+This module provides classes for evaluating decision criteria in BCI tasks.
+These criteria are used to determine when to stop collecting evidence and
+make a decision based on the accumulated data.
+"""
+
import logging
from copy import copy
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
import numpy as np
@@ -10,48 +17,79 @@
class DecisionCriteria:
- """Abstract class for Criteria which can be applied to evaluate a inquiry
+ """Abstract base class for decision criteria evaluation.
+
+ This class defines the interface for criteria that can be applied to
+ evaluate whether a decision should be made based on accumulated evidence.
+
+ Attributes:
+ None
"""
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize the decision criteria.
+
+ Args:
+ **kwargs: Arbitrary keyword arguments.
+ """
pass
- def reset(self):
+ def reset(self) -> None:
+ """Reset the criteria state."""
pass
- def decide(self, series: Dict):
- """
- Apply the given criteria.
- Parameters:
- -----------
- series - series data
- - target(str): target of the series
- - time_spent(ndarray[float]): |num_trials|x1
- time spent on the inquiry
- - list_sti(list[list[str]]): presented symbols in each
- inquiry
- - list_distribution(list[ndarray[float]]): list of |alp|x1
- arrays with prob. dist. over alp
+ def decide(self, series: Dict[str, Any]) -> bool:
+ """Apply the decision criteria to the given series data.
+
+ Args:
+ series: Dictionary containing series data with the following keys:
+ target (str): Target of the series.
+ time_spent (np.ndarray): Time spent on each inquiry.
+ list_sti (List[List[str]]): Presented symbols in each inquiry.
+ list_distribution (List[np.ndarray]): Probability distributions
+ over the alphabet for each inquiry.
+
+ Returns:
+ bool: True if the criterion is met, False otherwise.
+ Raises:
+ NotImplementedError: This is an abstract method.
"""
raise NotImplementedError()
class MinIterationsCriteria(DecisionCriteria):
- """ Returns true if the minimum number of iterations have not yet
- been reached. """
+ """Criteria for ensuring a minimum number of iterations.
- def __init__(self, min_num_inq: int):
- """ Args:
- min_num_inq(int): minimum number of inquiry number before any
- termination objective is allowed to be triggered """
+ Returns true if the minimum number of iterations have not yet been reached.
+
+ Attributes:
+ min_num_inq: Minimum number of inquiries required.
+ """
+
+ def __init__(self, min_num_inq: int) -> None:
+ """Initialize the minimum iterations criteria.
+
+ Args:
+ min_num_inq: Minimum number of inquiries required before any
+ termination objective is allowed to be triggered.
+ """
self.min_num_inq = min_num_inq
- def decide(self, series: Dict):
- # Note: we use 'list_sti' parameter since this is the number of
- # inquiries displayed. The length of 'list_distribution' is 1 greater
- # than this, since the language model distribution is added before
- # the first inquiry is displayed.
+ def decide(self, series: Dict[str, Any]) -> bool:
+ """Check if minimum number of iterations has been reached.
+
+ Note: Uses 'list_sti' parameter since this is the number of inquiries
+ displayed. The length of 'list_distribution' is 1 greater than this,
+ since the language model distribution is added before the first
+ inquiry is displayed.
+
+ Args:
+ series: Dictionary containing series data.
+
+ Returns:
+ bool: True if current iterations < minimum required, False otherwise.
+ """
current_inq = len(series['list_sti'])
log.info(
f"Checking min iterations; current iteration is {current_inq}")
@@ -59,51 +97,99 @@ def decide(self, series: Dict):
class DecreasedProbabilityCriteria(DecisionCriteria):
- """Returns true if the letter with the max probability decreased from the
- last inquiry."""
+ """Criteria for detecting decreased probability of the most likely symbol.
- def decide(self, series: Dict):
+ Returns true if the letter with the max probability decreased from the
+ last inquiry.
+
+ Attributes:
+ None
+ """
+
+ def decide(self, series: Dict[str, Any]) -> bool:
+ """Check if probability of most likely symbol has decreased.
+
+ Args:
+ series: Dictionary containing series data.
+
+ Returns:
+ bool: True if the probability decreased for the same symbol, False otherwise.
+ """
if len(series['list_distribution']) < 2:
return False
prev_dist = series['list_distribution'][-2]
cur_dist = series['list_distribution'][-1]
- return np.argmax(cur_dist) == np.argmax(
- prev_dist) and np.max(cur_dist) < np.max(prev_dist)
+ return np.argmax(cur_dist) == np.argmax(prev_dist) and np.max(cur_dist) < np.max(prev_dist)
class MaxIterationsCriteria(DecisionCriteria):
- """Returns true if the max iterations have been reached."""
+ """Criteria for enforcing maximum number of iterations.
+
+ Returns true if the maximum allowed iterations have been reached.
+
+ Attributes:
+ max_num_inq: Maximum number of inquiries allowed.
+ """
- def __init__(self, max_num_inq: int):
- """ Args:
- max_num_inq(int): maximum number of inquiries allowed before
- mandatory termination """
+ def __init__(self, max_num_inq: int) -> None:
+ """Initialize the maximum iterations criteria.
+
+ Args:
+ max_num_inq: Maximum number of inquiries allowed before
+ mandatory termination.
+ """
self.max_num_inq = max_num_inq
- def decide(self, series: Dict):
- # Note: len(series['list_sti']) != len(series['list_distribution'])
- # see MinIterationsCriteria comment
+ def decide(self, series: Dict[str, Any]) -> bool:
+ """Check if maximum iterations have been reached.
+
+ Note: len(series['list_sti']) != len(series['list_distribution'])
+ See MinIterationsCriteria comment.
+
+ Args:
+ series: Dictionary containing series data.
+
+ Returns:
+ bool: True if max iterations reached, False otherwise.
+ """
current_inq = len(series['list_sti'])
if current_inq >= self.max_num_inq:
- log.info(
- "Committing to decision: max iterations have been reached.")
+ log.info("Committing to decision: max iterations have been reached.")
return True
return False
class ProbThresholdCriteria(DecisionCriteria):
- """Returns true if the commit threshold has been met."""
-
- def __init__(self, threshold: float):
- """ Args:
- threshold(float in [0,1]): A threshold on most likely
- candidate posterior. If a candidate exceeds a posterior
- the system terminates.
- """
+ """Criteria for probability threshold-based decisions.
+
+ Returns true if the commit threshold has been met.
+
+ Attributes:
+ tau: Probability threshold value.
+ """
+
+ def __init__(self, threshold: float) -> None:
+ """Initialize the probability threshold criteria.
+
+ Args:
+ threshold: A threshold value in [0,1]. If a candidate exceeds this
+ posterior probability, the system terminates.
+
+ Raises:
+ AssertionError: If threshold is not in [0,1].
+ """
assert 1 >= threshold >= 0, "stopping threshold should be in [0,1]"
self.tau = threshold
- def decide(self, series: Dict):
+ def decide(self, series: Dict[str, Any]) -> bool:
+ """Check if probability threshold has been exceeded.
+
+ Args:
+ series: Dictionary containing series data.
+
+ Returns:
+ bool: True if the threshold is exceeded, False otherwise.
+ """
current_distribution = series['list_distribution'][-1]
if np.max(current_distribution) > self.tau:
log.info("Committing to decision: posterior exceeded threshold.")
@@ -112,52 +198,79 @@ def decide(self, series: Dict):
class MarginCriteria(DecisionCriteria):
- """ Stopping criteria based on the difference of two most likely candidates.
- This condition enforces the likelihood difference between two top
- candidates to be at least a value. E.g. in 4 category case with
- a margin 0.2, the edge cases [0.6,0.4,0.,0.] and [0.4,0.2,0.2,0.2]
- satisfy the condition.
+ """Criteria based on margin between top two candidates.
+
+ This condition enforces the likelihood difference between two top
+ candidates to be at least a specified value. E.g. in 4 category case with
+ a margin 0.2, the edge cases [0.6,0.4,0.,0.] and [0.4,0.2,0.2,0.2]
+ satisfy the condition.
+
+ Attributes:
+ margin: Required margin between top candidates.
"""
- def __init__(self, margin: float):
- """ Args:
- margin(float in [0,1]): Minimum distance required between
- two most likely competing candidates to trigger termination.
- """
+ def __init__(self, margin: float) -> None:
+ """Initialize the margin criteria.
+
+ Args:
+ margin: Minimum distance required between two most likely competing
+ candidates to trigger termination.
+
+ Raises:
+ AssertionError: If margin is not in [0,1].
+ """
assert 1 >= margin >= 0, "difference margin should be in [0,1]"
self.margin = margin
- def decide(self, series: Dict):
- # Get the current posterior probability values
+ def decide(self, series: Dict[str, Any]) -> bool:
+ """Check if margin between top candidates is sufficient.
+
+ Args:
+ series: Dictionary containing series data.
+
+ Returns:
+ bool: True if margin is sufficient, False otherwise.
+ """
p = copy(series['list_distribution'][-1])
- # This criteria compares most likely candidates (best competitors)
candidates = [p[idx] for idx in list(np.argsort(p)[-2:])]
stopping_rule = np.abs(candidates[0] - candidates[1])
d = stopping_rule > self.margin
if d:
log.info("Committing to decision: margin is high enough.")
-
return d
class MomentumCommitCriteria(DecisionCriteria):
- """ Stopping criteria based on Shannon entropy on the simplex
- Attr:
- lam(float): linear combination parameter between entropy and the
- speed term
- tau(float): decision threshold
- """
-
- def __init__(self, tau: float, lam: float):
+ """Criteria based on Shannon entropy and momentum.
+
+ This stopping criteria combines Shannon entropy on the simplex with
+ a momentum term.
+
+ Attributes:
+ lam: Linear combination parameter between entropy and speed term.
+ tau: Decision threshold.
+ """
+
+ def __init__(self, tau: float, lam: float) -> None:
+ """Initialize the momentum commit criteria.
+
+ Args:
+ tau: Decision threshold value.
+ lam: Linear combination parameter.
+ """
self.lam = lam
self.tau = tau
- def reset(self):
- pass
+ def decide(self, series: Dict[str, Any]) -> bool:
+ """Evaluate momentum-based stopping criteria.
- def decide(self, series):
- eps = np.power(.1, 6)
+ Args:
+ series: Dictionary containing series data.
+ Returns:
+ bool: True if the stopping criteria are met, False otherwise.
+ """
+ eps = np.power(.1, 6)
prob_history = copy(series['list_distribution'])
p = prob_history[-1]
@@ -176,60 +289,77 @@ def decide(self, series):
else:
stopping_rule = tmp
- d = stopping_rule < tau_
+ return stopping_rule < tau_
- return d
+class CriteriaEvaluator:
+ """Evaluates decision criteria for BCI task control.
-class CriteriaEvaluator():
- """Evaluates whether an series should commit to a decision based on the
- provided criteria.
+ This class manages multiple decision criteria to determine when to commit
+ to a decision based on the accumulated evidence.
- Parameters:
- -----------
- continue_criteria: list of criteria; if any of these evaluate to true the
- decision maker continues.
- commit_criteria: list of criteria; if any of these return true and
- continue_criteria are all false, decision maker commits to a decision.
+ Attributes:
+ continue_criteria: List of criteria that must all be false to allow
+ commitment.
+ commit_criteria: List of criteria where any true value triggers
+ commitment.
"""
- def __init__(self, continue_criteria: List[DecisionCriteria],
- commit_criteria: List[DecisionCriteria]):
+ def __init__(self,
+ continue_criteria: Optional[List[DecisionCriteria]] = None,
+ commit_criteria: Optional[List[DecisionCriteria]] = None) -> None:
+ """Initialize the criteria evaluator.
+
+ Args:
+ continue_criteria: List of criteria that must all be false to allow
+ commitment.
+ commit_criteria: List of criteria where any true value triggers
+ commitment.
+ """
self.continue_criteria = continue_criteria or []
self.commit_criteria = commit_criteria or []
@classmethod
- def default(cls, min_num_inq: int, max_num_inq: int, threshold: float):
- return cls(continue_criteria=[MinIterationsCriteria(min_num_inq)],
- commit_criteria=[
- MaxIterationsCriteria(max_num_inq),
- ProbThresholdCriteria(threshold)
- ])
-
- def do_series(self):
+ def default(cls, min_num_inq: int, max_num_inq: int,
+ threshold: float) -> 'CriteriaEvaluator':
+ """Create a default CriteriaEvaluator instance.
+
+ Args:
+ min_num_inq: Minimum number of inquiries required.
+ max_num_inq: Maximum number of inquiries allowed.
+ threshold: Probability threshold for commitment.
+
+ Returns:
+ CriteriaEvaluator: Configured with default criteria.
+ """
+ return cls(
+ continue_criteria=[MinIterationsCriteria(min_num_inq)],
+ commit_criteria=[
+ MaxIterationsCriteria(max_num_inq),
+ ProbThresholdCriteria(threshold)
+ ])
+
+ def do_series(self) -> None:
+ """Reset all criteria for a new series."""
for el_ in self.continue_criteria:
el_.reset()
for el in self.commit_criteria:
el.reset()
- def should_commit(self, series: Dict):
- """Evaluates the given series; returns true if stoppage criteria has
- been met, otherwise false.
-
- Parameters:
- -----------
- series - series data
- - target(str): target of the series
- - time_spent(ndarray[float]): |num_trials|x1
- time spent on the inquiry
- - list_sti(list[list[str]]): presented symbols in each
- inquiry
- - list_distribution(list[ndarray[float]]): list of |alp|x1
- arrays with prob. dist. over alp
- """
- if any(
- criteria.decide(series)
- for criteria in self.continue_criteria):
+ def should_commit(self, series: Dict[str, Any]) -> bool:
+ """Evaluate whether to commit to a decision.
+
+ Args:
+ series: Dictionary containing series data with:
+ target (str): Target of the series.
+ time_spent (np.ndarray): Time spent on each inquiry.
+ list_sti (List[List[str]]): Presented symbols in each inquiry.
+ list_distribution (List[np.ndarray]): Probability distributions
+ over the alphabet for each inquiry.
+
+ Returns:
+ bool: True if the commitment criteria are met, False otherwise.
+ """
+ if any(criteria.decide(series) for criteria in self.continue_criteria):
return False
- return any(
- criteria.decide(series) for criteria in self.commit_criteria)
+ return any(criteria.decide(series) for criteria in self.commit_criteria)
diff --git a/bcipy/task/control/evidence.py b/bcipy/task/control/evidence.py
index 87f6b658b..83c6b68c6 100644
--- a/bcipy/task/control/evidence.py
+++ b/bcipy/task/control/evidence.py
@@ -1,7 +1,13 @@
+"""Evidence evaluation module for BCI task control.
+
+This module provides classes and functions for extracting evidence from raw device data,
+including EEG, gaze tracking, and switch input data. The module supports different types
+of evidence evaluation based on the input data type and desired output evidence type.
+"""
+
# mypy: disable-error-code="override"
-"""Classes and functions for extracting evidence from raw device data."""
import logging
-from typing import List, Optional, Type
+from typing import Any, List, Optional, Type
import numpy as np
@@ -20,20 +26,33 @@
class EvidenceEvaluator:
- """Base class for a class that can evaluate raw device data using a
- signal_model. EvidenceEvaluators are responsible for performing necessary
- preprocessing steps such as filtering and reshaping.
+ """Base class for evaluating raw device data using a signal model.
- Parameters
- ----------
- symbol_set: set of possible symbols presented
- signal_model: model trained using a calibration session of the same user.
+ This class defines the interface for evidence evaluators, which are responsible
+ for performing necessary preprocessing steps such as filtering and reshaping
+ before evaluating the evidence.
+
+ Attributes:
+ symbol_set: List of possible symbols that can be presented.
+ signal_model: Model trained using calibration session data.
+ device_spec: Specification of the input device.
"""
- def __init__(self,
- symbol_set: List[str],
- signal_model: SignalModel,
- parameters: Optional[Parameters] = None):
+ def __init__(
+ self,
+ symbol_set: List[str],
+ signal_model: SignalModel,
+ parameters: Optional[Parameters] = None) -> None:
+ """Initialize the evidence evaluator.
+
+ Args:
+ symbol_set: List of possible symbols that can be presented.
+ signal_model: Model trained using calibration session data.
+ parameters: Optional configuration parameters.
+
+ Raises:
+ AssertionError: If signal model metadata is missing or incompatible.
+ """
assert signal_model.metadata, "Metadata missing from signal model."
device_spec = signal_model.metadata.device_spec
assert ContentType(
@@ -45,31 +64,63 @@ def __init__(self,
@property
def consumes(self) -> ContentType:
- """ContentType of the data that should be input"""
+ """Get the type of data this evaluator consumes.
+
+ Returns:
+ ContentType: Type of input data required.
+ """
+ raise NotImplementedError()
@property
def produces(self) -> EvidenceType:
- """Type of evidence that is output"""
+ """Get the type of evidence this evaluator produces.
- def evaluate(self, **kwargs):
- """Evaluate the evidence"""
+ Returns:
+ EvidenceType: Type of evidence output.
+ """
+ raise NotImplementedError()
+
+ def evaluate(self, **kwargs: Any) -> np.ndarray:
+ """Evaluate the evidence from raw data.
+
+ Args:
+ **kwargs: Arbitrary keyword arguments for evaluation.
+
+ Returns:
+ np.ndarray: Evaluated evidence data.
+ """
+ raise NotImplementedError()
class EEGEvaluator(EvidenceEvaluator):
- """EvidenceEvaluator that extracts symbol likelihoods from raw EEG data.
+ """Evidence evaluator for extracting symbol likelihoods from EEG data.
- Parameters
- ----------
- symbol_set: set of possible symbols presented
- signal_model: trained signal model
+ This evaluator processes raw EEG data to compute likelihood ratios for
+ different symbols based on the ERP response.
+
+ Attributes:
+ consumes: Type of input data (EEG).
+ produces: Type of evidence output (ERP).
+ channel_map: Mapping of EEG channels.
+ transform: Signal transformation function.
+ reshape: Trial reshaping function.
"""
+
consumes = ContentType.EEG
produces = EvidenceType.ERP
- def __init__(self,
- symbol_set: List[str],
- signal_model: SignalModel,
- parameters: Optional[Parameters] = None):
+ def __init__(
+ self,
+ symbol_set: List[str],
+ signal_model: SignalModel,
+ parameters: Optional[Parameters] = None) -> None:
+ """Initialize the EEG evaluator.
+
+ Args:
+ symbol_set: List of possible symbols that can be presented.
+ signal_model: Model trained using calibration session data.
+ parameters: Optional configuration parameters.
+ """
super().__init__(symbol_set, signal_model, parameters)
self.channel_map = analysis_channels(self.device_spec.channels,
@@ -77,47 +128,60 @@ def __init__(self,
self.transform = signal_model.metadata.transform
self.reshape = TrialReshaper()
- def preprocess(self, raw_data: np.ndarray, times: List[float],
- target_info: List[str], window_length: float) -> np.ndarray:
- """Preprocess the inquiry data.
-
- Parameters
- ----------
- raw_data - C x L eeg data where C is number of channels and L is the
- signal length
- symbols - symbols displayed in the inquiry
- times - timestamps associated with each symbol
- target_info - target information about the stimuli;
- ex. ['nontarget', 'nontarget', ...]
- window_length - The length of the time between stimuli presentation
+ def preprocess(
+ self,
+ raw_data: np.ndarray,
+ times: List[float],
+ target_info: List[str],
+ window_length: float) -> np.ndarray:
+ """Preprocess the inquiry EEG data.
+
+ Args:
+ raw_data: C x L EEG data where C is number of channels and L is
+ signal length.
+ times: Timestamps associated with each symbol.
+ target_info: Target information about the stimuli
+ (e.g. ['nontarget', 'nontarget', ...]).
+ window_length: Length of time between stimuli presentation.
+
+ Returns:
+ np.ndarray: Preprocessed EEG data.
"""
transformed_data, transform_sample_rate = self.transform(
raw_data, self.device_spec.sample_rate)
# The data from DAQ is assumed to have offsets applied
- reshaped_data, _lbls = self.reshape(trial_targetness_label=target_info,
- timing_info=times,
- eeg_data=transformed_data,
- sample_rate=transform_sample_rate,
- channel_map=self.channel_map,
- poststimulus_length=window_length)
+ reshaped_data, _lbls = self.reshape(
+ trial_targetness_label=target_info,
+ timing_info=times,
+ eeg_data=transformed_data,
+ sample_rate=transform_sample_rate,
+ channel_map=self.channel_map,
+ poststimulus_length=window_length)
return reshaped_data
- # pylint: disable=arguments-differ
- def evaluate(self, raw_data: np.ndarray, symbols: List[str],
- times: List[float], target_info: List[str],
- window_length: float, *args) -> np.ndarray:
- """Evaluate the evidence.
-
- Parameters
- ----------
- raw_data - C x L eeg data where C is number of channels and L is the
- signal length
- symbols - symbols displayed in the inquiry
- times - timestamps associated with each symbol
- target_info - target information about the stimuli;
- ex. ['nontarget', 'nontarget', ...]
- window_length - The length of the time between stimuli presentation
+ def evaluate(
+ self,
+ raw_data: np.ndarray,
+ symbols: List[str],
+ times: List[float],
+ target_info: List[str],
+ window_length: float,
+ *args: Any) -> np.ndarray:
+ """Evaluate EEG evidence.
+
+ Args:
+ raw_data: C x L EEG data where C is number of channels and L is
+ signal length.
+ symbols: Symbols displayed in the inquiry.
+ times: Timestamps associated with each symbol.
+ target_info: Target information about the stimuli
+ (e.g. ['nontarget', 'nontarget', ...]).
+ window_length: Length of time between stimuli presentation.
+ *args: Additional arguments.
+
+ Returns:
+ np.ndarray: Likelihood ratios for each symbol.
"""
data = self.preprocess(raw_data, times, target_info, window_length)
return self.signal_model.compute_likelihood_ratio(
@@ -125,20 +189,34 @@ def evaluate(self, raw_data: np.ndarray, symbols: List[str],
class GazeEvaluator(EvidenceEvaluator):
- """EvidenceEvaluator that extracts symbol likelihoods from raw gaze data.
+ """Evidence evaluator for extracting symbol likelihoods from gaze data.
- Parameters
- ----------
- symbol_set: set of possible symbols presented
- gaze_model: trained gaze model
+ This evaluator processes raw eye tracking data to compute likelihoods
+ for different symbols based on gaze patterns.
+
+ Attributes:
+ consumes: Type of input data (EYETRACKER).
+ produces: Type of evidence output (EYE).
+ channel_map: Mapping of eye tracking channels.
+ transform: Signal transformation function.
+ reshape: Gaze data reshaping function.
"""
+
consumes = ContentType.EYETRACKER
produces = EvidenceType.EYE
- def __init__(self,
- symbol_set: List[str],
- signal_model: SignalModel,
- parameters: Optional[Parameters] = None):
+ def __init__(
+ self,
+ symbol_set: List[str],
+ signal_model: SignalModel,
+ parameters: Optional[Parameters] = None) -> None:
+ """Initialize the gaze evaluator.
+
+ Args:
+ symbol_set: List of possible symbols that can be presented.
+ signal_model: Model trained using calibration session data.
+ parameters: Optional configuration parameters.
+ """
super().__init__(symbol_set, signal_model, parameters)
self.channel_map = analysis_channels(self.device_spec.channels,
@@ -146,25 +224,27 @@ def __init__(self,
self.transform = signal_model.metadata.transform
self.reshape = GazeReshaper()
- def preprocess(self, raw_data: np.ndarray, times: List[float],
- flash_time: float) -> np.ndarray:
- """Preprocess the inquiry data.
+ def preprocess(
+ self,
+ raw_data: np.ndarray,
+ times: List[float],
+ flash_time: float) -> np.ndarray:
+ """Preprocess the inquiry gaze data.
- Parameters
- ----------
- raw_data - C x L eeg data where C is number of channels and L is the
- signal length. Includes all channels in devices.json
- symbols - symbols displayed in the inquiry
- times - timestamps associated with each symbol
- flash_time - duration (in seconds) of each stimulus
-
- Function
- --------
The preprocessing is functionally different than Gaze Reshaper, since
the raw data contains only one inquiry. start_idx is determined as the
start time of first symbol flashing multiplied by the sampling rate
of eye tracker. stop_idx is the index indicating the end of last
symbol flashing.
+
+ Args:
+ raw_data: C x L data where C is number of channels and L is signal
+ length. Includes all channels in devices.json.
+ times: Timestamps associated with each symbol.
+ flash_time: Duration (in seconds) of each stimulus.
+
+ Returns:
+ np.ndarray: Preprocessed gaze data (4, N_samples).
"""
if self.transform:
transformed_data, transform_sample_rate = self.transform(
@@ -174,55 +254,85 @@ def preprocess(self, raw_data: np.ndarray, times: List[float],
transform_sample_rate = self.device_spec.sample_rate
start_idx = int(self.device_spec.sample_rate * times[0])
- stop_idx = start_idx + int((times[-1] - times[0] + flash_time) * self.device_spec.sample_rate)
+ stop_idx = start_idx + int(
+ (times[-1] - times[0] + flash_time) * self.device_spec.sample_rate)
data_all_channels = transformed_data[:, start_idx:stop_idx]
# Extract left and right eye from all channels. Remove/replace nan values
left_eye, right_eye, _, _, _, _ = extract_eye_info(data_all_channels)
- reshaped_data = np.vstack((np.array(left_eye).T, np.array(right_eye).T))
-
- return reshaped_data # (4, N_samples)
+ reshaped_data = np.vstack(
+ (np.array(left_eye).T, np.array(right_eye).T))
- # pylint: disable=arguments-differ
- def evaluate(self, raw_data: np.ndarray, symbols: List[str],
- times: List[float], target_info: List[str],
- window_length: float, flash_time: float,
- stim_length: float) -> np.ndarray:
- """Evaluate the evidence.
+ return reshaped_data
- Parameters
- ----------
- raw_data - C x L eeg data where C is number of channels and L is the
- signal length
- symbols - symbols displayed in the inquiry
- times - timestamps associated with each symbol
- target_info - target information about the stimuli;
- ex. ['nontarget', 'nontarget', ...]
- window_length - The length of the time between stimuli presentation
+ def evaluate(
+ self,
+ raw_data: np.ndarray,
+ symbols: List[str],
+ times: List[float],
+ target_info: List[str],
+ window_length: float,
+ flash_time: float,
+ stim_length: float) -> np.ndarray:
+ """Evaluate gaze evidence.
+
+ Args:
+ raw_data: C x L data where C is number of channels and L is signal
+ length.
+ symbols: Symbols displayed in the inquiry.
+ times: Timestamps associated with each symbol.
+ target_info: Target information about the stimuli.
+ window_length: Length of time between stimuli presentation.
+ flash_time: Duration of each stimulus.
+ stim_length: Length of stimulus sequence.
+
+ Returns:
+ np.ndarray: Likelihood values for each symbol.
"""
data = self.preprocess(raw_data, times, flash_time)
- # We need the likelihoods in the form of p(label | gaze). predict returns the argmax of the likelihoods.
+ # We need the likelihoods in the form of p(label | gaze).
+ # predict returns the argmax of the likelihoods.
# Therefore we need predict_proba method to get the likelihoods.
- likelihood = self.signal_model.evaluate_likelihood(data, symbols, self.symbol_set)
+ likelihood = self.signal_model.evaluate_likelihood(
+ data, symbols, self.symbol_set)
return likelihood
class SwitchEvaluator(EvidenceEvaluator):
- """EvidenceEvaluator that extracts symbol likelihoods from raw Switch data.
+ """Evidence evaluator for extracting symbol likelihoods from switch data.
- Parameters
- ----------
- symbol_set: set of possible symbols presented
- signal_model: trained signal model
+ This evaluator processes raw switch input data to compute likelihoods
+ for different symbols based on button press patterns.
+
+ Attributes:
+ consumes: Type of input data (MARKERS).
+ produces: Type of evidence output (BTN).
+ button_press_mode: Mode of button press interpretation.
+ trial_count: Number of trials in stimulus sequence.
"""
+
consumes = ContentType.MARKERS
produces = EvidenceType.BTN
- def __init__(self,
- symbol_set: List[str],
- signal_model: SignalModel,
- parameters: Optional[Parameters] = None):
+ def __init__(
+ self,
+ symbol_set: List[str],
+ signal_model: SignalModel,
+ parameters: Optional[Parameters] = None) -> None:
+ """Initialize the switch evaluator.
+
+ Args:
+ symbol_set: List of possible symbols that can be presented.
+ signal_model: Model trained using calibration session data.
+ parameters: Optional configuration parameters.
+
+ Raises:
+ AssertionError: If button press mode is not supported.
+ """
super().__init__(symbol_set, signal_model, parameters)
+ if not parameters:
+ raise ValueError("Parameters required for SwitchEvaluator")
+
self.button_press_mode = ButtonPressMode(
parameters.get('preview_inquiry_progress_method'))
self.trial_count = parameters.get('stim_length')
@@ -233,12 +343,25 @@ def __init__(self,
"To run without button press evidence set the acq_mode to exclude MARKERS."
))
- def preprocess(self, raw_data: np.ndarray, times: List[float],
- target_info: List[str], window_length: float) -> np.ndarray:
- """Preprocess the inquiry data.
+ def preprocess(
+ self,
+ raw_data: np.ndarray,
+ times: List[float],
+ target_info: List[str],
+ window_length: float) -> np.ndarray:
+ """Preprocess the inquiry switch data.
Determines the return data based on whether the switch was pressed
during the inquiry and the configured ButtonPressMode.
+
+ Args:
+ raw_data: Switch input data.
+ times: Timestamps associated with each symbol.
+ target_info: Target information about the stimuli.
+ window_length: Length of time between stimuli presentation.
+
+ Returns:
+ np.ndarray: Preprocessed switch data.
"""
switch_was_pressed = np.any(raw_data)
@@ -305,8 +428,7 @@ def get_evaluator(
def find_matching_evaluator(
signal_model: SignalModel) -> Type[EvidenceEvaluator]:
- """Find the first EvidenceEvaluator compatible with the given signal
- model."""
+ """Find the first EvidenceEvaluator compatible with the given signal model."""
content_type = ContentType(signal_model.metadata.device_spec.content_type)
# Metadata may provide an EvidenceType with a model so the same data source can
# be used to produce multiple types of evidence (ex. alpha)
@@ -317,7 +439,6 @@ def find_matching_evaluator(
evidence_type = EvidenceType(model_output.upper())
except ValueError:
log.error(f"Unsupported evidence type: {model_output}")
-
return get_evaluator(content_type, evidence_type)
@@ -325,7 +446,6 @@ def init_evidence_evaluator(
symbol_set: List[str],
signal_model: SignalModel,
parameters: Optional[Parameters] = None) -> EvidenceEvaluator:
- """Find an EvidenceEvaluator that matches the given signal_model and
- initialize it."""
+ """Find an EvidenceEvaluator that matches the given signal_model and initialize it."""
evaluator_class = find_matching_evaluator(signal_model)
return evaluator_class(symbol_set, signal_model, parameters)
diff --git a/bcipy/task/control/handler.py b/bcipy/task/control/handler.py
index 6dd534a80..dd87987ab 100644
--- a/bcipy/task/control/handler.py
+++ b/bcipy/task/control/handler.py
@@ -1,3 +1,10 @@
+"""Task control handler module for BCI tasks.
+
+This module provides classes for managing decision making and evidence fusion
+in BCI tasks. It includes functionality for scheduling inquiries, managing
+task state, and making decisions based on accumulated evidence.
+"""
+
import logging
import string
from typing import Dict, List, Optional, Tuple
@@ -14,39 +21,55 @@
log = logging.getLogger(SESSION_LOG_FILENAME)
-class EvidenceFusion():
- """ Fuses likelihood evidences provided by the inference
- Attr:
- evidence_history(dict{list[ndarray]}): Dictionary of difference
- evidence types in list. Lists are ordered using the arrival
- time.
- likelihood(ndarray[]): current probability distribution over the
- set. Gets updated once new evidence arrives. """
+class EvidenceFusion:
+ """Class for fusing likelihood evidence from multiple sources.
- def __init__(self, list_name_evidence, len_dist):
- self.evidence_history = {name: [] for name in list_name_evidence}
- self.likelihood = np.ones(len_dist) / len_dist
+ This class manages the combination of evidence from different sources
+ (e.g., EEG, eye tracking) to compute a final probability distribution
+ over possible decisions.
+
+ Attributes:
+ evidence_history: Dictionary mapping evidence types to their history.
+ likelihood: Current probability distribution over the decision space.
+ """
+
+ def __init__(self, list_name_evidence: List[EvidenceType],
+ len_dist: int) -> None:
+ """Initialize the evidence fusion system.
- def update_and_fuse(self, dict_evidence):
- """ Updates the probability distribution
- Args:
- dict_evidence(dict{name: ndarray[float]}): dictionary of
- evidences (EEG (likelihood ratios) and other likelihoods)
+ Args:
+ list_name_evidence: List of evidence types to track.
+ len_dist: Length of the probability distribution (number of
+ possible decisions).
"""
- # {ERP: [], EYE: ()}
+ self.evidence_history: Dict[EvidenceType, List[np.ndarray]] = {
+ name: [] for name in list_name_evidence
+ }
+ self.likelihood = np.ones(len_dist) / len_dist
+
+ def update_and_fuse(self,
+ dict_evidence: Dict[EvidenceType,
+ np.ndarray]) -> np.ndarray:
+ """Update and fuse probability distributions with new evidence.
+
+ Args:
+ dict_evidence: Dictionary mapping evidence types to their
+ likelihood arrays.
- for key in dict_evidence.keys():
+ Returns:
+ np.ndarray: Updated probability distribution after fusion.
+ """
+ for key in dict_evidence:
tmp = dict_evidence[key][:][:]
self.evidence_history[key].append(tmp)
- # Current rule is to multiply
+ # Current fusion rule is multiplication
for value in dict_evidence.values():
self.likelihood *= value[:]
if np.isinf(np.sum(self.likelihood)):
tmp = np.zeros(len(self.likelihood))
tmp[np.where(self.likelihood == np.inf)[0][0]] = 1
-
self.likelihood = tmp
if not np.isnan(np.sum(self.likelihood)):
@@ -56,24 +79,26 @@ def update_and_fuse(self, dict_evidence):
return likelihood
- def reset_history(self):
- """ Clears evidence history """
+ def reset_history(self) -> None:
+ """Clears evidence history."""
for value in self.evidence_history.values():
del value[:]
self.likelihood = np.ones(len(self.likelihood)) / len(self.likelihood)
def save_history(self) -> None:
- """ Saves the current likelihood history """
+ """Save the current likelihood history.
+
+ Note:
+ Not currently implemented.
+ """
log.warning('save_history not implemented')
- return
@property
def latest_evidence(self) -> Dict[EvidenceType, List[float]]:
- """Latest evidence of each type in the evidence history.
+ """Get the latest evidence of each type.
- Returns
- -------
- a dictionary with an entry for all configured evidence types.
+ Returns:
+ Dict mapping evidence types to their most recent values.
"""
return {
name: list(evidence[-1]) if evidence else []
@@ -82,57 +107,54 @@ def latest_evidence(self) -> Dict[EvidenceType, List[float]]:
class DecisionMaker:
- """ Scheduler of the entire framework
- Attr:
- state(str): state of the framework, which increases in size
- by 1 after each inquiry. Elements are alphabet, ".,_,<"
- where ".": null_inquiry(no decision made)
- "_": space bar
- "<": back space
- alphabet(list[str]): list of symbols used by the framework. Can
- be switched with location of images or one hot encoded images.
- is_txt_stim(bool): whether the stimuli are text or images
- inq_constants(list[str]): optional list of letters which should appear in
- every inquiry.
- stopping_evaluator: CriteriaEvaluator - optional parameter to
- provide alternative rules for committing to a decision.
- stimuli_agent(StimuliAgent): the query selection mechanism of the
- system
- stimuli_timing(list[float]): list of timings for the stimuli ([fixation_time, stimuli_flash_time])
- stimuli_order(StimuliOrder): ordering of the stimuli (random, distributed)
- stimuli_jitter(float): jitter of the inquiry stimuli in seconds
-
- Functions:
- decide():
- Checks the criteria for making and series, using all
- evidences and decides to do an series or to collect more
- evidence
- do_series():
- Once committed an series perform updates to condition the
- distribution on the previous letter.
- schedule_inquiry():
- schedule the next inquiry using the current information
- decide_state_update():
- If committed to an series update the state using a decision
- metric.
- (e.g. pick the letter with highest likelihood)
- prepare_stimuli():
- prepares the query set for the next inquiry
- (e.g pick n-highest likely letters and randomly shuffle)
+ """Scheduler and decision maker for BCI task control.
+
+ This class manages the scheduling of inquiries and decision making based
+ on accumulated evidence. It maintains the task state and coordinates
+ the interaction between evidence collection and decision making.
+
+ Attributes:
+ state: Current state string, growing by 1 after each inquiry.
+ displayed_state: State formatted for display.
+ alphabet: List of possible symbols.
+ is_txt_stim: Whether stimuli are text or images.
+ stimuli_timing: Timing parameters for stimuli presentation.
+ stimuli_order: Order of stimuli presentation.
+ stimuli_jitter: Jitter in stimulus timing.
+ inq_constants: Symbols to include in every inquiry.
+ stopping_evaluator: Evaluator for stopping criteria.
+ stimuli_agent: Agent for selecting stimuli.
+ list_series: List of series data.
+ time: Current time.
+ inquiry_counter: Number of inquiries made.
+ last_selection: Last selected symbol.
+ """
+
+ def __init__(
+ self,
+ state: str = '',
+ alphabet: List[str] = list(string.ascii_uppercase) + [BACKSPACE_CHAR] +
+ [SPACE_CHAR],
+ is_txt_stim: bool = True,
+ stimuli_timing: List[float] = [1, .2],
+ stimuli_jitter: float = 0,
+ stimuli_order: StimuliOrder = StimuliOrder.RANDOM,
+ inq_constants: Optional[List[str]] = None,
+ stopping_evaluator: Optional[CriteriaEvaluator] = None,
+ stimuli_agent: Optional[StimuliAgent] = None) -> None:
+ """Initialize the decision maker.
+
+ Args:
+ state: Initial state string.
+ alphabet: List of possible symbols.
+ is_txt_stim: Whether stimuli are text or images.
+ stimuli_timing: [fixation_time, stimuli_flash_time].
+ stimuli_jitter: Jitter in stimulus timing (seconds).
+ stimuli_order: Order of stimuli presentation.
+ inq_constants: Symbols to include in every inquiry.
+ stopping_evaluator: Evaluator for stopping criteria.
+ stimuli_agent: Agent for selecting stimuli.
"""
-
- def __init__(self,
- state: str = '',
- alphabet: List[str] = list(string.ascii_uppercase) + [BACKSPACE_CHAR] + [SPACE_CHAR],
- is_txt_stim: bool = True,
- stimuli_timing: List[float] = [1, .2],
- stimuli_jitter: float = 0,
- stimuli_order: StimuliOrder = StimuliOrder.RANDOM,
- inq_constants: Optional[List[str]] = None,
- stopping_evaluator: CriteriaEvaluator = CriteriaEvaluator.default(min_num_inq=2,
- max_num_inq=10,
- threshold=0.8),
- stimuli_agent: Optional[StimuliAgent] = None):
self.state = state
self.displayed_state = self.form_display_state(state)
self.stimuli_timing = stimuli_timing
@@ -142,75 +164,90 @@ def __init__(self,
self.alphabet = alphabet
self.is_txt_stim = is_txt_stim
- self.list_series = [{'target': None, 'time_spent': 0,
- 'list_sti': [], 'list_distribution': [],
- 'decision': None}]
+ self.list_series = [{
+ 'target': None,
+ 'time_spent': 0,
+ 'list_sti': [],
+ 'list_distribution': [],
+ 'decision': None
+ }]
self.time = 0
self.inquiry_counter = 0
self.stopping_evaluator = stopping_evaluator
- self.stimuli_agent = stimuli_agent or RandomStimuliAgent(alphabet=self.alphabet)
+ self.stimuli_agent = stimuli_agent or RandomStimuliAgent(
+ alphabet=self.alphabet)
self.last_selection = ''
# Items shown in every inquiry
self.inq_constants = inq_constants
- def reset(self, state=''):
- """ Resets the decision maker with the initial state
- Args:
- state(str): current state of the system """
+ def reset(self, state: str = '') -> None:
+ """Reset the decision maker to initial state.
+
+ Args:
+ state: New initial state string.
+ """
self.state = state
self.displayed_state = self.form_display_state(self.state)
- self.list_series = [{'target': None, 'time_spent': 0,
- 'list_sti': [], 'list_distribution': []}]
+ self.list_series = [{
+ 'target': None,
+ 'time_spent': 0,
+ 'list_sti': [],
+ 'list_distribution': []
+ }]
self.time = 0
self.inquiry_counter = 0
self.stimuli_agent.reset()
- def form_display_state(self, state):
- """ Forms the state information or the user that fits to the
- display. Basically takes '.' and BACKSPACE_CHAR into consideration and rewrites
- the state
- Args:
- state(str): state string
- Return:
- displayed_state(str): state without '<,.' and removes
- backspaced letters """
+ def form_display_state(self, state: str) -> str:
+ """Format state string for display.
+
+ Processes special characters (backspace, dots) and formats the
+ state appropriately for display.
+
+ Args:
+ state: Raw state string.
+
+ Returns:
+ str: Formatted state string for display.
+ """
tmp = ''
for i in state:
if i == BACKSPACE_CHAR:
tmp = tmp[0:-1]
elif i != '.':
tmp += i
-
return tmp
- def update(self, state=''):
+ def update(self, state: str = '') -> None:
+ """Update the current state.
+
+ Args:
+ state: New state string.
+ """
self.state = state
self.displayed_state = self.form_display_state(state)
- def decide(self, p) -> Tuple[bool, InquirySchedule]:
- """ Once evidence is collected, decision_maker makes a decision to
- stop or not by leveraging the information of the stopping
- criteria. Can decide to do an series or schedule another inquiry.
-
- Args
- ----
- p(ndarray[float]): |A| x 1 distribution array
- |A|: cardinality of the alphabet
-
- Return
- ------
- - commitment: True if a letter is a commitment is made
- False if requires more evidence
- - inquiry schedule: Extra arguments depending on the decision
- """
+ def decide(self, p: np.ndarray) -> Tuple[bool, Optional[InquirySchedule]]:
+ """Make a decision based on current evidence.
+
+ Evaluates whether to commit to a decision or schedule another
+ inquiry based on the current probability distribution and
+ stopping criteria.
+
+ Args:
+ p: Probability distribution over possible decisions.
+ Returns:
+ Tuple containing:
+ - bool: True if committing to a decision.
+ - Optional[InquirySchedule]: Schedule for next inquiry if needed.
+ """
self.list_series[-1]['list_distribution'].append(p[:])
- # Check stopping criteria
if self.stopping_evaluator.should_commit(self.list_series[-1]):
self.do_series()
return True, None
@@ -218,34 +255,46 @@ def decide(self, p) -> Tuple[bool, InquirySchedule]:
stimuli = self.schedule_inquiry()
return False, stimuli
- def do_series(self):
- """ series refers to a commitment to a decision.
- If made, state is updated, displayed state is updated
- a new series is appended. """
+ def do_series(self) -> None:
+ """Handle commitment to a decision.
+
+ Updates state and prepares for the next series when a decision
+ is made.
+ """
self.inquiry_counter = 0
decision = self.decide_state_update()
self.last_selection = decision
self.state += decision
self.displayed_state = self.form_display_state(self.state)
- # Initialize next series
- self.list_series.append({'target': None, 'time_spent': 0,
- 'list_sti': [], 'list_distribution': []})
+ self.list_series.append({
+ 'target': None,
+ 'time_spent': 0,
+ 'list_sti': [],
+ 'list_distribution': []
+ })
self.stimuli_agent.do_series()
self.stopping_evaluator.do_series()
def schedule_inquiry(self) -> InquirySchedule:
- """ Schedules next inquiry """
+ """Schedule the next inquiry.
+
+ Returns:
+ InquirySchedule: Schedule for the next inquiry.
+ """
self.state += '.'
stimuli = self.prepare_stimuli()
self.list_series[-1]['list_sti'].append(stimuli[0])
self.inquiry_counter += 1
-
return stimuli
- def decide_state_update(self):
- """ Checks stopping criteria to commit to an series """
+ def decide_state_update(self) -> str:
+ """Determine the next state update.
+
+ Returns:
+ str: Selected symbol for state update.
+ """
idx = np.where(
self.list_series[-1]['list_distribution'][-1] ==
np.max(self.list_series[-1]['list_distribution'][-1]))[0][0]
@@ -254,13 +303,10 @@ def decide_state_update(self):
return decision
def prepare_stimuli(self) -> InquirySchedule:
- """ Given the alphabet, under a rule, prepares a stimuli for
- the next inquiry.
+ """Prepare stimuli for the next inquiry.
- Return
- ------
- stimuli(tuple[list[str],list[float],list[str]]): tuple of
- stimuli information. [0]: letter, [1]: timing, [2]: color
+ Returns:
+ InquirySchedule: Schedule containing stimuli and timing information.
"""
# querying agent decides on possible letters to be shown on the screen
diff --git a/bcipy/task/control/query.py b/bcipy/task/control/query.py
index 3e8b3c867..af7d478be 100644
--- a/bcipy/task/control/query.py
+++ b/bcipy/task/control/query.py
@@ -1,6 +1,14 @@
+"""Query module for BCI task control.
+
+This module provides classes for managing stimulus presentation in BCI tasks.
+It includes agents that determine which stimuli to present based on different
+selection strategies, such as random selection or N-best selection based on
+probability distributions.
+"""
+
import random
from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import Any, List, Optional
import numpy as np
@@ -8,48 +16,81 @@
class StimuliAgent(ABC):
+ """Abstract base class for stimulus selection agents.
+
+ This class defines the interface for agents that select stimuli to present
+ during BCI tasks. Subclasses implement different selection strategies.
+ """
+
@abstractmethod
- def reset(self):
+ def reset(self) -> None:
+ """Reset the agent's state."""
...
@abstractmethod
- def return_stimuli(self, list_distribution: np.ndarray, **kwargs):
- """ updates the agent with most likely posterior and selects queries
- Args:
- list_distribution(list[ndarray]): posterior distributions as
- stored in the decision maker
- Return:
- query(list[str]): queries """
+ def return_stimuli(self, list_distribution: np.ndarray,
+ **kwargs: Any) -> List[str]:
+ """Update agent with posterior probabilities and select queries.
+
+ Args:
+ list_distribution: List of posterior probability distributions.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ List[str]: Selected stimuli for the next query.
+ """
...
@abstractmethod
- def do_series(self):
- """ If the system decides on a class let the agent know about it """
+ def do_series(self) -> None:
+ """Handle series completion.
+
+ Called when the system decides on a class to let the agent update
+ its state accordingly.
+ """
...
class RandomStimuliAgent(StimuliAgent):
- """ An inherited class of StimuliAgent. Chooses random set of letters for
- queries instead of most likely letters.
- Attr:
- alphabet(list[str]): Query space(possible queries).
- len_query(int): number of elements in a query
- Functions:
- reset(): reset the agent
- return_stimuli(): update the agent and return a stimuli set
- do_series(): one a commitment is made update agent
- """
-
- def __init__(self, alphabet: List[str], len_query: int = 4):
+ """Stimuli agent that randomly selects queries.
+
+ This agent chooses random sets of letters for queries instead of using
+ probability-based selection.
+
+ Attributes:
+ alphabet: List of possible symbols to query from.
+ len_query: Number of symbols to include in each query.
+ """
+
+ def __init__(self, alphabet: List[str], len_query: int = 4) -> None:
+ """Initialize the random stimuli agent.
+
+ Args:
+ alphabet: List of possible symbols to query from.
+ len_query: Number of symbols to include in each query.
+ """
self.alphabet = alphabet
self.len_query = len_query
- def reset(self):
- """ This querying method is memoryless no reset needed """
+ def reset(self) -> None:
+ """Reset the agent's state.
+
+ This querying method is memoryless, so no reset is needed.
+ """
pass
- def return_stimuli(self, list_distribution: np.ndarray, constants: Optional[List[str]] = None):
- """ return random elements from the alphabet """
+ def return_stimuli(self,
+ list_distribution: np.ndarray,
+ constants: Optional[List[str]] = None) -> List[str]:
+ """Return random elements from the alphabet.
+
+ Args:
+ list_distribution: List of probability distributions (unused).
+ constants: Optional list of symbols to always include in the result.
+
+ Returns:
+ List[str]: Randomly selected symbols, with constants if provided.
+ """
tmp = [i for i in self.alphabet]
query = random.sample(tmp, self.len_query)
@@ -58,41 +99,58 @@ def return_stimuli(self, list_distribution: np.ndarray, constants: Optional[List
return query
- def do_series(self):
+ def do_series(self) -> None:
+ """Handle series completion.
+
+ This agent is stateless, so no action is needed.
+ """
pass
class NBestStimuliAgent(StimuliAgent):
- """ An inherited class of StimuliAgent. Updates the agent with N most likely
- letters based on posteriors and selects queries.
- Attr:
- alphabet(list[str]): Query space(possible queries).
- len_query(int): number of elements in a query
- Functions:
- reset(): reset the agent
- return_stimuli(): update the agent and return a stimuli set
- do_series(): one a commitment is made update agent
- """
-
- def __init__(self, alphabet: List[str], len_query: int = 4):
+ """Stimuli agent that selects the N most likely symbols.
+
+ This agent updates its selection based on posterior probabilities,
+ choosing the N symbols with highest probability for each query.
+
+ Attributes:
+ alphabet: List of possible symbols to query from.
+ len_query: Number of symbols to include in each query.
+ """
+
+ def __init__(self, alphabet: List[str], len_query: int = 4) -> None:
+ """Initialize the N-best stimuli agent.
+
+ Args:
+ alphabet: List of possible symbols to query from.
+ len_query: Number of symbols to include in each query.
+ """
self.alphabet = alphabet
self.len_query = len_query
- def reset(self):
+ def reset(self) -> None:
+ """Reset the agent's state.
+
+ This agent is stateless, so no reset is needed.
+ """
pass
def return_stimuli(self,
list_distribution: np.ndarray,
constants: Optional[List[str]] = None) -> List[str]:
- """Returns a list of the n most likely symbols based on the provided
- probabilities, where n is self.len_query. Symbols of the same
- probability will be ordered randomly.
-
- Parameters
- ----------
- list_distribution - list of lists of probabilities. Only the last list will
- be used.
- constants - optional list of symbols which should appear every result
+ """Return the N most likely symbols based on probabilities.
+
+ Selects symbols based on their probabilities in the distribution,
+ where N is self.len_query. Symbols with equal probabilities are
+ ordered randomly.
+
+ Args:
+ list_distribution: List of probability distributions. Only the
+ last distribution is used.
+ constants: Optional list of symbols to always include in the result.
+
+ Returns:
+ List[str]: Selected symbols, with constants if provided.
"""
symbol_probs = list(zip(self.alphabet, list_distribution[-1]))
randomized = random.sample(symbol_probs, len(symbol_probs))
@@ -102,5 +160,9 @@ def return_stimuli(self,
len_query=self.len_query,
always_included=constants)
- def do_series(self):
+ def do_series(self) -> None:
+ """Handle series completion.
+
+ This agent is stateless, so no action is needed.
+ """
pass
diff --git a/bcipy/task/data.py b/bcipy/task/data.py
index 51a1df069..aed8a44e7 100644
--- a/bcipy/task/data.py
+++ b/bcipy/task/data.py
@@ -1,4 +1,8 @@
-"""Module for functionality related to session-related data."""
+"""Module for functionality related to session-related data.
+
+This module provides classes and functions for managing BCI session data,
+including evidence types, inquiries, and session management.
+"""
from collections import Counter
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
@@ -9,15 +13,25 @@
def rounded(values: List[float], precision: int) -> List[float]:
"""Round the list of values to the given precision.
- Parameters
- ----------
- values - values to round
+ Args:
+ values: Values to round.
+ precision: Number of decimal places to round to.
+
+ Returns:
+ List[float]: Rounded values.
"""
return [round(value, precision) for value in values]
class EvidenceType(Enum):
- """Enum of the supported evidence types used in the various spelling tasks."""
+ """Enum of the supported evidence types used in the various spelling tasks.
+
+ Attributes:
+ LM: Language Model evidence.
+ ERP: Event-Related Potential using EEG signals.
+ BTN: Button press evidence.
+ EYE: Eye tracker evidence.
+ """
LM = 'LM' # Language Model
ERP = 'ERP' # Event-Related Potential using EEG signals
BTN = 'BTN' # Button
@@ -25,16 +39,22 @@ class EvidenceType(Enum):
@classmethod
def list(cls) -> List[str]:
- """List of evidence types"""
+ """List of evidence types.
+
+ Returns:
+ List[str]: List of evidence type names.
+ """
return [ev_type.name for ev_type in cls]
@classmethod
def deserialized(cls, serialized_name: str) -> 'EvidenceType':
"""Deserialized name of the given evidence type.
- Parameters:
- evidence_name - ex. 'lm_evidence'
+
+ Args:
+ serialized_name: Evidence name (ex. 'lm_evidence').
+
Returns:
- deserialized value: ex. EvidenceType.LM
+ EvidenceType: Deserialized value (ex. EvidenceType.LM).
"""
if serialized_name == 'eeg_evidence':
return EvidenceType.ERP
@@ -45,7 +65,11 @@ def __str__(self) -> str:
@property
def serialized(self) -> str:
- """Name used when serialized to a json file."""
+ """Name used when serialized to a json file.
+
+ Returns:
+ str: Serialized name of the evidence type.
+ """
if self == EvidenceType.ERP:
return 'eeg_evidence'
return f'{self.name.lower()}{EVIDENCE_SUFFIX}'
@@ -54,19 +78,18 @@ def serialized(self) -> str:
class Inquiry:
"""Represents a sequence of stimuli.
- Parameters:
- ----------
- stimuli - list of stimuli presented (letters, icons, etc)
- timing - duration in seconds for each stimulus
- target_info - targetness ('nontarget', 'target', etc) for each stimulus
- target_letter - current letter that the user is attempting to spell
- current_text - letters spelled so far
- target_text - word or words the user is attempting to spell
- next_display_state - text to be displayed after evaluating the current evidence
- lm_evidence - language model evidence for each stimulus
- eeg_evidence - eeg evidence for each stimulus
- likelihood - combined likelihood for each stimulus
- task_data - task-specific information about the inquiry that may be useful in training a model
+ Args:
+ stimuli: List of stimuli presented (letters, icons, etc).
+ timing: Duration in seconds for each stimulus.
+ triggers: List of (trigger_name, timestamp) tuples.
+ target_info: Targetness ('nontarget', 'target', etc) for each stimulus.
+ target_letter: Current letter that the user is attempting to spell.
+ current_text: Letters spelled so far.
+ target_text: Word or words the user is attempting to spell.
+ selection: Currently selected symbol.
+ next_display_state: Text to be displayed after evaluating evidence.
+ likelihood: Combined likelihood for each stimulus.
+ task_data: Task-specific information about the inquiry.
"""
def __init__(self,
@@ -80,7 +103,7 @@ def __init__(self,
selection: Optional[str] = None,
next_display_state: Optional[str] = None,
likelihood: Optional[List[float]] = None,
- task_data: Optional[Dict] = None) -> None:
+ task_data: Optional[Dict[str, Any]] = None) -> None:
super().__init__()
self.stimuli = stimuli
self.timing = timing
@@ -100,37 +123,53 @@ def __init__(self,
@property
def lm_evidence(self) -> List[float]:
- """Language model evidence"""
+ """Language model evidence.
+
+ Returns:
+ List[float]: Language model evidence values.
+ """
return self.evidences.get(EvidenceType.LM, [])
@property
def eeg_evidence(self) -> List[float]:
- """EEG evidence"""
+ """EEG evidence.
+
+ Returns:
+ List[float]: EEG evidence values.
+ """
return self.evidences.get(EvidenceType.ERP, [])
@property
def decision_made(self) -> bool:
- """Returns true if the result of the inquiry was a decision."""
+ """Returns true if the result of the inquiry was a decision.
+
+ Returns:
+ bool: True if a decision was made.
+ """
return self.current_text != self.next_display_state
@property
def is_correct_decision(self) -> bool:
- """Indicates whether the current selection was the target"""
+ """Indicates whether the current selection was the target.
+
+ Returns:
+ bool: True if selection matches target_letter.
+ """
if self.selection and self.target_letter:
return self.selection == self.target_letter
return False
@classmethod
- def from_dict(cls, data: dict) -> 'Inquiry':
- """Deserializes from a dict
+ def from_dict(cls, data: Dict[str, Any]) -> 'Inquiry':
+ """Deserializes from a dict.
- Parameters:
- ----------
- data - a dict in the format of the data output by the as_dict
- method.
+ Args:
+ data: A dict in the format of the data output by the as_dict method.
+
+ Returns:
+ Inquiry: New instance created from dict data.
"""
# partition into evidence data and other data.
-
evidences = {
EvidenceType.deserialized(name): value
for name, value in data.items() if name.endswith(EVIDENCE_SUFFIX)
@@ -148,9 +187,13 @@ def from_dict(cls, data: dict) -> 'Inquiry':
inquiry.evidences = evidences
return inquiry
- def as_dict(self) -> Dict:
- """Dict representation"""
- data: Dict = {
+ def as_dict(self) -> Dict[str, Any]:
+ """Dict representation.
+
+ Returns:
+ Dict[str, Any]: Dictionary containing inquiry data.
+ """
+ data: Dict[str, Any] = {
'stimuli': self.stimuli,
'timing': self.timing,
'triggers': self.triggers,
@@ -174,15 +217,18 @@ def as_dict(self) -> Dict:
def stim_evidence(self,
symbol_set: List[str],
n_most_likely: int = 5) -> Dict[str, Any]:
- """Returns a dict of stim sequence data useful for debugging. Evidences
- are paired with the appropriate symbol for easier visual
+ """Returns a dict of stim sequence data useful for debugging.
+
+ Evidences are paired with the appropriate symbol for easier visual
scanning. Also, an additional attribute is provided to display the
top n most likely symbols based on the current evidence.
- Parameters:
- -----------
- symbol_set - list of stim in the same order as the evidences.
- n_most_likely - number of most likely elements to include
+ Args:
+ symbol_set: List of stim in the same order as the evidences.
+ n_most_likely: Number of most likely elements to include.
+
+ Returns:
+ Dict[str, Any]: Dictionary containing stimulus evidence data.
"""
likelihood = dict(zip(symbol_set, self.format(self.likelihood)))
data: Dict[str, Any] = {
@@ -199,9 +245,11 @@ def stim_evidence(self,
def format(self, evidence: List[float]) -> List[float]:
"""Format the evidence for output.
- Parameters
- ----------
- evidence - list of evidence values
+ Args:
+ evidence: List of evidence values.
+
+ Returns:
+ List[float]: Formatted evidence values.
"""
if self.precision:
return rounded(evidence, self.precision)
@@ -209,7 +257,16 @@ def format(self, evidence: List[float]) -> List[float]:
class Session:
- """Represents a data collection session. Not all tasks record session data."""
+ """Represents a data collection session. Not all tasks record session data.
+
+ Args:
+ save_location: Location where session data will be saved.
+ symbol_set: List of possible symbols that can be presented.
+ task: Name of the task being performed.
+ mode: Mode of operation (e.g., 'RSVP').
+ decision_threshold: Threshold for making decisions.
+ task_data: Additional task-specific data.
+ """
def __init__(self,
save_location: str,
@@ -232,24 +289,40 @@ def __init__(self,
@property
def total_number_series(self) -> int:
- """Total number of series that contain sequences."""
+ """Total number of series that contain sequences.
+
+ Returns:
+ int: Number of non-empty series.
+ """
return len([lst for lst in self.series if lst])
@property
def total_number_decisions(self) -> int:
- """Total number of series that ended in a decision."""
+ """Total number of series that ended in a decision.
+
+ Returns:
+ int: Number of completed series.
+ """
# An alternate implementation would be to count the inquiries with
# decision_made property of true.
return len(self.series) - 1
@property
def total_inquiries(self) -> int:
- """Total number of inquiries presented."""
+ """Total number of inquiries presented.
+
+ Returns:
+ int: Total number of inquiries.
+ """
return sum([len(lst) for lst in self.series])
@property
def inquiries_per_selection(self) -> Optional[float]:
- """Inquiries per selection"""
+ """Inquiries per selection.
+
+ Returns:
+ Optional[float]: Average inquiries per selection, or None if no selections.
+ """
selections = self.total_number_decisions
if selections == 0:
return None
@@ -257,25 +330,32 @@ def inquiries_per_selection(self) -> Optional[float]:
@property
def all_inquiries(self) -> List[Inquiry]:
- """List of all Inquiries for the whole session"""
+ """List of all Inquiries for the whole session.
+
+ Returns:
+ List[Inquiry]: All inquiries from non-empty series.
+ """
return [inq for inquiries in self.series for inq in inquiries if inquiries]
def has_evidence(self) -> bool:
- """Tests whether any inquiries have evidence."""
+ """Tests whether any inquiries have evidence.
+
+ Returns:
+ bool: True if any inquiries have evidence.
+ """
return any(inq.evidences for inq in self.all_inquiries)
def add_series(self) -> None:
- """Add another series unless the last one is empty"""
+ """Add another series unless the last one is empty."""
if self.last_series():
self.series.append([])
def add_sequence(self, inquiry: Inquiry, new_series: bool = False) -> None:
- """Append sequence information
+ """Append sequence information.
- Parameters:
- -----------
- inquiry - data to append
- new_series - a True value indicates that this is the first stim of
+ Args:
+ inquiry: Data to append.
+ new_series: A True value indicates that this is the first stim of
a new series.
"""
if new_series:
@@ -283,23 +363,42 @@ def add_sequence(self, inquiry: Inquiry, new_series: bool = False) -> None:
self.last_series().append(inquiry)
def last_series(self) -> List[Inquiry]:
- """Returns the last series"""
+ """Returns the last series.
+
+ Returns:
+ List[Inquiry]: Last series of inquiries.
+ """
return self.series[-1]
def last_inquiry(self) -> Optional[Inquiry]:
- """Returns the last inquiry of the last series."""
+ """Returns the last inquiry of the last series.
+
+ Returns:
+ Optional[Inquiry]: Last inquiry if it exists.
+ """
series = self.last_series()
if series:
return series[-1]
return None
def latest_series_is_empty(self) -> bool:
- """Whether the latest series has had any inquiries added to it."""
+ """Whether the latest series has had any inquiries added to it.
+
+ Returns:
+ bool: True if latest series is empty.
+ """
return len(self.last_series()) == 0
def as_dict(self,
evidence_only: bool = False) -> Dict[str, Any]:
- """Dict representation"""
+ """Dict representation.
+
+ Args:
+ evidence_only: Whether to include only evidence-related data.
+
+ Returns:
+ Dict[str, Any]: Dictionary containing session data.
+ """
series_dict: Dict[str, Any] = {}
for i, series in enumerate(self.series):
if series:
@@ -339,13 +438,14 @@ def as_dict(self,
return info
@classmethod
- def from_dict(cls, data: dict) -> 'Session':
+ def from_dict(cls, data: Dict[str, Any]) -> 'Session':
"""Deserialize from a dict.
- Parameters:
- ----------
- data - a dict in the format of the data output by the as_dict
- method.
+ Args:
+ data: A dict in the format of the data output by the as_dict method.
+
+ Returns:
+ Session: New session instance created from dict data.
"""
session = cls(save_location=data['session'],
task=data['task'],
diff --git a/bcipy/task/demo/actions/demo_calibration_report.py b/bcipy/task/demo/actions/demo_calibration_report.py
index 776b91aaa..9ca7b79f8 100644
--- a/bcipy/task/demo/actions/demo_calibration_report.py
+++ b/bcipy/task/demo/actions/demo_calibration_report.py
@@ -9,7 +9,8 @@
if __name__ == '__main__':
import argparse
- parser = argparse.ArgumentParser(description='Generate a calibration report from a session data file.')
+ parser = argparse.ArgumentParser(
+ description='Generate a calibration report from a session data file.')
# Add the arguments: parameters and protocol
parser.add_argument(
'--parameters',
@@ -31,6 +32,7 @@
The protocol path is the path to the directory containing the calibration sessions.
"""
- action = BciPyCalibrationReportAction(parameters=parameters, save_path='.', protocol_path=args.protocol)
+ action = BciPyCalibrationReportAction(
+ parameters=parameters, save_path='.', protocol_path=args.protocol)
print('Generating Report.')
task_data = action.execute()
diff --git a/bcipy/task/demo/orchestrator/demo_orchestrator.py b/bcipy/task/demo/orchestrator/demo_orchestrator.py
index afbdd5108..fdc6f777e 100644
--- a/bcipy/task/demo/orchestrator/demo_orchestrator.py
+++ b/bcipy/task/demo/orchestrator/demo_orchestrator.py
@@ -43,7 +43,8 @@ def demo_orchestrator(parameters_path: str) -> None:
import argparse
- parser = argparse.ArgumentParser(description="Demo the SessionOrchestrator")
+ parser = argparse.ArgumentParser(
+ description="Demo the SessionOrchestrator")
parser.add_argument(
'-p',
'--parameters_path',
diff --git a/bcipy/task/exceptions.py b/bcipy/task/exceptions.py
index 7de570c24..5e29cb573 100644
--- a/bcipy/task/exceptions.py
+++ b/bcipy/task/exceptions.py
@@ -1,33 +1,79 @@
+"""Task-specific exceptions for the BciPy task module.
+
+This module defines custom exceptions that can be raised during task execution,
+registration, and evidence evaluation.
+"""
+
from typing import Any, Optional
class InsufficientDataException(Exception):
- """Insufficient Data Exception.
+ """Exception raised when task data requirements are not met.
+
+ This exception is raised when a task does not have sufficient data to
+ execute properly, such as missing calibration data or required parameters.
+
+ Args:
+ message: Description of what data was insufficient.
+ errors: Optional additional error information.
- Thrown when data requirements to execute task are violated.
+ Attributes:
+ message: The error message.
+ errors: Additional error details, if any.
"""
def __init__(self, message: str, errors: Optional[Any] = None) -> None:
super().__init__(message)
+ self.message = message
self.errors = errors
class TaskRegistryException(Exception):
- """Task Registry Exception.
+ """Exception raised when there are issues with task registration.
- Thrown when task type is unregistered.
+ This exception is raised when attempting to use an unregistered task type
+ or when there are problems with the task registry.
+
+ Args:
+ message: Description of the registration issue.
+ errors: Optional additional error information.
+
+ Attributes:
+ message: The error message.
+ errors: Additional error details, if any.
"""
def __init__(self, message: str, errors: Optional[Any] = None) -> None:
super().__init__(message)
+ self.message = message
self.errors = errors
class MissingEvidenceEvaluator(Exception):
- """Thrown when an evidence evaluator can't be found that matches the
- provided data content type input and evidence_type output"""
+ """Exception raised when a required evidence evaluator is not found.
+
+ This exception is raised when no evidence evaluator can be found that matches
+ the provided data content type input and evidence_type output requirements.
+
+ Args:
+ message: Description of the missing evaluator.
+ """
+
+ def __init__(self, message: str) -> None:
+ super().__init__(message)
+ self.message = message
class DuplicateModelEvidence(Exception):
- """Thrown from a task when more than one of the provided models produces
- the same type of evidence"""
+ """Exception raised when multiple models produce the same evidence type.
+
+ This exception is raised when more than one of the provided models produces
+ the same type of evidence, making it ambiguous which evidence should be used.
+
+ Args:
+ message: Description of the duplicate evidence.
+ """
+
+ def __init__(self, message: str) -> None:
+ super().__init__(message)
+ self.message = message
diff --git a/bcipy/task/main.py b/bcipy/task/main.py
index daadb089d..0e3c51809 100644
--- a/bcipy/task/main.py
+++ b/bcipy/task/main.py
@@ -1,7 +1,13 @@
+"""Core task module defining base classes for BciPy tasks.
+
+This module provides the foundational classes for implementing BCI tasks,
+including the abstract base Task class and supporting data structures.
+"""
+
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
-from typing import Optional
+from typing import Any, Dict, Optional
from bcipy.config import STATIC_AUDIO_PATH
from bcipy.core.parameters import Parameters
@@ -9,16 +15,33 @@
@dataclass
-class TaskData():
- """TaskData.
+class TaskData:
+ """Data structure for storing task execution results.
+
+ This class encapsulates the data returned from a task execution, including
+ the save path for any generated data and a dictionary of task-specific data.
- Data structure for storing task return data.
+ Attributes:
+ save_path: Path where task data was saved.
+ task_dict: Dictionary containing task-specific data and results.
"""
save_path: Optional[str] = None
- task_dict: Optional[dict] = None
+ task_dict: Optional[Dict[str, Any]] = None
class TaskMode(Enum):
+ """Enumeration of supported BCI task modes.
+
+ This enum defines the different types of tasks that can be executed in the BCI system.
+ Each mode represents a specific type of interaction or experiment.
+
+ Attributes:
+ CALIBRATION: Mode for system calibration tasks.
+ COPYPHRASE: Mode for copy-spelling tasks.
+ TIMING_VERIFICATION: Mode for timing verification tasks.
+ ACTION: Mode for action-based tasks.
+ TRAINING: Mode for training tasks.
+ """
CALIBRATION = "calibration"
COPYPHRASE = "copy phrase"
TIMING_VERIFICATION = "timing verification"
@@ -26,16 +49,37 @@ class TaskMode(Enum):
TRAINING = "training"
def __str__(self) -> str:
+ """Return the string value of the task mode.
+
+ Returns:
+ str: The string representation of the task mode.
+ """
return self.value
def __repr__(self) -> str:
+ """Return the string representation of the task mode.
+
+ Returns:
+ str: The string representation of the task mode.
+ """
return self.value
class Task(ABC):
- """Task.
+ """Abstract base class for BciPy tasks.
+
+ This class defines the interface that all BCI tasks must implement. It provides
+ the basic structure for task execution, setup, and cleanup.
+
+ Attributes:
+ name: Name of the task.
+ mode: Mode of operation for the task.
+ parameters: Task configuration parameters.
+ data_save_location: Location where task data should be saved.
- Base class for BciPy tasks.
+ Note:
+ Subclasses must define the 'name' and 'mode' class attributes and
+ implement the execute() method.
"""
name: str
mode: TaskMode
@@ -43,19 +87,62 @@ class Task(ABC):
data_save_location: str
def __init__(self, *args, **kwargs) -> None:
+ """Initialize the task.
+
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+
+ Raises:
+ AssertionError: If name or mode attributes are not defined.
+ """
super(Task, self).__init__()
- assert getattr(self, 'name', None) is not None, "Task must have a `name` attribute defined"
- assert getattr(self, 'mode', None) is not None, "Task must have a `mode` attribute defined"
+ assert getattr(
+ self, 'name', None) is not None, "Task must have a `name` attribute defined"
+ assert getattr(
+ self, 'mode', None) is not None, "Task must have a `mode` attribute defined"
@abstractmethod
def execute(self) -> TaskData:
+ """Execute the task.
+
+ This method must be implemented by all task subclasses to define the
+ task's execution logic.
+
+ Returns:
+ TaskData: Object containing the results of the task execution.
+ """
...
- def setup(self, *args, **kwargs):
+ def setup(self, *args, **kwargs) -> None:
+ """Set up the task before execution.
+
+ This method can be overridden by subclasses to perform any necessary
+ setup before task execution.
+
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+ """
...
- def cleanup(self, *args, **kwargs):
+ def cleanup(self, *args, **kwargs) -> None:
+ """Clean up after task execution.
+
+ This method can be overridden by subclasses to perform any necessary
+ cleanup after task execution.
+
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+ """
...
- def alert(self):
- play_sound(f"{STATIC_AUDIO_PATH}/{self.parameters['alert_sound_file']}")
+ def alert(self) -> None:
+ """Play an alert sound.
+
+ Plays the configured alert sound file to notify the user.
+ The sound file is specified in the task parameters.
+ """
+ play_sound(
+ f"{STATIC_AUDIO_PATH}/{self.parameters['alert_sound_file']}")
diff --git a/bcipy/task/orchestrator/__init__.py b/bcipy/task/orchestrator/__init__.py
index df0973e4c..b594a7eb2 100644
--- a/bcipy/task/orchestrator/__init__.py
+++ b/bcipy/task/orchestrator/__init__.py
@@ -1,3 +1,10 @@
+"""Task orchestration module for managing BCI experiment sessions.
+
+This module provides functionality for managing and executing sequences of BCI tasks,
+including task initialization, execution order, data saving, and logging. The main
+component is the SessionOrchestrator, which handles the lifecycle of task execution.
+"""
+
from bcipy.task.orchestrator.orchestrator import SessionOrchestrator
__all__ = ['SessionOrchestrator']
diff --git a/bcipy/task/orchestrator/orchestrator.py b/bcipy/task/orchestrator/orchestrator.py
index 9fcf2e92a..3090c0e9f 100644
--- a/bcipy/task/orchestrator/orchestrator.py
+++ b/bcipy/task/orchestrator/orchestrator.py
@@ -1,3 +1,9 @@
+"""Task orchestration module for managing BCI experiment sessions.
+
+This module provides functionality for managing and executing sequences of BCI tasks,
+handling task initialization, execution, logging, and data management.
+"""
+
# mypy: disable-error-code="arg-type, assignment"
import errno
import json
@@ -8,7 +14,7 @@
import time
from datetime import datetime
from logging import Logger
-from typing import List, Optional, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
from bcipy.config import (DEFAULT_EXPERIMENT_ID, DEFAULT_PARAMETERS_FILENAME,
DEFAULT_PARAMETERS_PATH, DEFAULT_USER_ID,
@@ -21,43 +27,66 @@
class SessionOrchestrator:
+ """Manages the execution of a protocol of BCI tasks.
+
+ The Session Orchestrator is responsible for managing the execution of a sequence
+ of tasks within an experiment session. It handles task initialization, execution
+ order, data saving, and logging.
+
+ Attributes:
+ tasks: List of task classes to execute.
+ task_names: List of task names in execution order.
+ parameters: Configuration parameters for the session.
+ sys_info: System information dictionary.
+ log: Session logger instance.
+ save_folder: Path where session data is saved.
+ session_data: List of data from executed tasks.
+ ready_to_execute: Whether tasks are ready to execute.
+ last_task_dir: Path to the last executed task's directory.
+ copyphrases: List of phrases for copy tasks.
+ next_phrase: Next phrase to be used in copy tasks.
+ starting_index: Starting index for copy tasks.
+ user: User identifier.
+ fake: Whether to use fake data.
+ experiment_id: Experiment identifier.
+ alert: Whether to alert when tasks complete.
+ visualize: Whether to visualize task results.
+ progress: Current task execution progress.
+ user_exit: Whether user has requested to exit.
"""
- Session Orchestrator
- --------------------
-
- The Session Orchestrator is responsible for managing the execution of a protocol of tasks. It is initialized with an
- experiment ID, user ID, and parameters file. Tasks are added to the orchestrator, which are then executed in order.
- """
- tasks: List[Type[Task]]
- task_names: List[str]
- parameters: Parameters
- sys_info: dict
- log: Logger
- save_folder: str
- session_data: List[TaskData]
- ready_to_execute: bool = False
- last_task_dir: Optional[str] = None
def __init__(
self,
experiment_id: str = DEFAULT_EXPERIMENT_ID,
user: str = DEFAULT_USER_ID,
parameters_path: str = DEFAULT_PARAMETERS_PATH,
- parameters: Parameters = None,
+ parameters: Optional[Parameters] = None,
fake: bool = False,
alert: bool = False,
visualize: bool = False
) -> None:
+ """Initialize the session orchestrator.
+
+ Args:
+ experiment_id: Identifier for the experiment session.
+ user: User identifier.
+ parameters_path: Path to parameters file.
+ parameters: Optional pre-loaded parameters object.
+ fake: Whether to use fake data for testing.
+ alert: Whether to alert when tasks complete.
+ visualize: Whether to visualize task results.
+ """
self.parameters_path = parameters_path
if not parameters:
- self.parameters = load_json_parameters(parameters_path, value_cast=True)
+ self.parameters = load_json_parameters(
+ parameters_path, value_cast=True)
else:
# This allows for the parameters to be passed in directly and modified before executions
self.parameters = parameters
- self.copyphrases = None
- self.next_phrase = None
- self.starting_index = 0
+ self.copyphrases: Optional[List[Tuple[str, int]]] = None
+ self.next_phrase: Optional[str] = None
+ self.starting_index: int = 0
self.initialize_copy_phrases()
@@ -65,37 +94,49 @@ def __init__(
self.fake = fake
self.experiment_id = experiment_id
self.sys_info = self.get_system_info()
- self.tasks = []
- self.task_names = []
- self.session_data = []
- self.save_folder = self._init_orchestrator_save_folder(self.parameters["data_save_loc"])
+ self.tasks: List[Type[Task]] = []
+ self.task_names: List[str] = []
+ self.session_data: List[TaskData] = []
+ self.save_folder = self._init_orchestrator_save_folder(
+ self.parameters["data_save_loc"])
self.logger = self._init_orchestrator_logger(self.save_folder)
self.alert = alert
- self.logger.info("Alerts are on") if self.alert else self.logger.info("Alerts are off")
+ self.logger.info("Alerts are on") if self.alert else self.logger.info(
+ "Alerts are off")
self.visualize = visualize
- self.progress = 0
+ self.progress: int = 0
self.ready_to_execute = False
self.user_exit = False
+ self.last_task_dir = None
self.logger.info("Session Orchestrator initialized successfully")
def add_task(self, task: Type[Task]) -> None:
- """Add a task to the orchestrator"""
+ """Add a single task to the execution queue.
+
+ Args:
+ task: Task class to add to the queue.
+ """
self.tasks.append(task)
self.task_names.append(task.name)
self.ready_to_execute = True
def add_tasks(self, tasks: List[Type[Task]]) -> None:
- """Add a list of tasks to the orchestrator"""
+ """Add multiple tasks to the execution queue.
+
+ Args:
+ tasks: List of task classes to add to the queue.
+ """
for task in tasks:
self.add_task(task)
self.ready_to_execute = True
def set_next_phrase(self) -> None:
- """Set the next phrase to be copied from the list of copy phrases loaded or the parameters directly.
+ """Set the next phrase for copy phrase tasks.
- If there are no more phrases to copy, the task text and spelled letters from parameters will be used.
+ If there are phrases in the copyphrases list, uses the next one.
+ Otherwise, uses the task_text from parameters.
"""
if self.copyphrases:
if len(self.copyphrases) > 0:
@@ -108,9 +149,9 @@ def set_next_phrase(self) -> None:
self.parameters['spelled_letters_count'] = self.starting_index
def initialize_copy_phrases(self) -> None:
- """Load copy phrases from a json file or take the task text if no file is provided.
+ """Load copy phrases from a JSON file.
- Expects a json file structured as follows:
+ The JSON file should be structured as:
{
"Phrases": [
[string, int],
@@ -118,8 +159,9 @@ def initialize_copy_phrases(self) -> None:
...
]
}
+
+ If no file is provided, uses task_text from parameters.
"""
- # load copy phrases from json file or take the task text if no file is provided
if self.parameters.get('copy_phrases_location'):
with open(self.parameters['copy_phrases_location'], 'r') as f:
copy_phrases = json.load(f)
@@ -132,21 +174,26 @@ def initialize_copy_phrases(self) -> None:
self.starting_index = self.parameters['spelled_letters_count']
def execute(self) -> None:
- """Executes queued tasks in order"""
+ """Execute all queued tasks in order.
+ Raises:
+ Exception: If no tasks have been added to the queue.
+ """
if not self.ready_to_execute:
msg = "Orchestrator not ready to execute. No tasks have been added."
- self.log.error(msg)
+ self.logger.error(msg)
raise Exception(msg)
- self.logger.info(f"Session Orchestrator executing tasks in order: {self.task_names}")
+ self.logger.info(
+ f"Session Orchestrator executing tasks in order: {self.task_names}")
for task in self.tasks:
self.progress += 1
if task.mode == TaskMode.COPYPHRASE:
self.set_next_phrase()
try:
# initialize the task save folder and logger
- self.logger.info(f"Initializing task {self.progress}/{len(self.tasks)} {task.name}")
+ self.logger.info(
+ f"Initializing task {self.progress}/{len(self.tasks)} {task.name}")
data_save_location = self._init_task_save_folder(task)
self._init_task_logger(data_save_location)
@@ -179,13 +226,15 @@ def execute(self) -> None:
if self.visualize:
# Visualize session data and fail silently if it errors
try:
- self.logger.info(f"Visualizing session data. Saving to {data_save_location}")
+ self.logger.info(
+ f"Visualizing session data. Saving to {data_save_location}")
subprocess.run(
f'bcipy-erp-viz -s "{data_save_location}" '
f'--parameters "{self.parameters_path}" --show --save',
shell=True)
except Exception as e:
- self.logger.info(f'Error visualizing session data: {e}')
+ self.logger.info(
+ f'Error visualizing session data: {e}')
initialized_task = None
@@ -193,7 +242,7 @@ def execute(self) -> None:
self.logger.error(f"Task {task.name} failed to execute")
self.logger.exception(e)
try:
- initialized_task.cleanup()
+ initialized_task.cleanup() # type: ignore
except BaseException:
pass
@@ -208,25 +257,52 @@ def execute(self) -> None:
self.progress = 0
def _init_orchestrator_logger(self, save_folder: str) -> Logger:
+ """Initialize the session logger.
+
+ Args:
+ save_folder: Directory to save log files.
+
+ Returns:
+ Logger: Configured logger instance.
+ """
return configure_logger(
save_folder,
PROTOCOL_LOG_FILENAME,
logging.DEBUG)
def _init_orchestrator_save_folder(self, save_path: str) -> str:
+ """Initialize the session save directory.
+
+ Args:
+ save_path: Base path for saving session data.
+
+ Returns:
+ str: Path to the created save directory.
+ """
date_time = datetime.now()
date = date_time.strftime("%Y-%m-%d")
timestamp = date_time.strftime("%Y-%m-%d_%H-%M-%S")
- # * No '/' after `save_folder` since it is included in
- # * `data_save_location` in parameters
path = f'{save_path}{self.user}/{date}/{self.experiment_id}/{timestamp}/'
os.makedirs(path)
os.makedirs(os.path.join(path, 'logs'), exist_ok=True)
return path
def _init_task_save_folder(self, task: Type[Task]) -> str:
+ """Initialize a save directory for a task.
+
+ Args:
+ task: Task class to create directory for.
+
+ Returns:
+ str: Path to the created task directory.
+
+ Raises:
+ OSError: If directory creation fails for reasons other than
+ the directory already existing.
+ """
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
- save_directory = self.save_folder + f'{task.name.replace(" ", "_")}_{timestamp}/'
+ save_directory = self.save_folder + \
+ f'{task.name.replace(" ", "_")}_{timestamp}/'
try:
# make a directory to save task data to
os.makedirs(save_directory)
@@ -244,49 +320,68 @@ def _init_task_save_folder(self, task: Type[Task]) -> str:
"type": "str",
}
)
- self.parameters.save(save_directory, name=DEFAULT_PARAMETERS_FILENAME)
+ self.parameters.save(
+ save_directory, name=DEFAULT_PARAMETERS_FILENAME)
except OSError as error:
# If the error is anything other than file existing, raise an error
if error.errno != errno.EEXIST:
raise error
+
return save_directory
def _init_task_logger(self, save_folder: str) -> None:
- configure_logger(
- save_folder,
- SESSION_LOG_FILENAME,
- logging.DEBUG)
+ """Initialize a logger for a task.
+
+ Args:
+ save_folder: Directory to save task logs.
+ """
+ configure_logger(save_folder, SESSION_LOG_FILENAME, logging.DEBUG)
def _save_data(self) -> None:
+ """Save all session data.
- self._save_procotol_data()
- # Save the remaining phrase data to a json file to be used in the next session
- if self.copyphrases and len(self.copyphrases) > 0:
- self._save_copy_phrases()
-
- def _save_procotol_data(self) -> None:
- # Save the protocol data to a json file
- with open(f'{self.save_folder}/{PROTOCOL_FILENAME}', 'w') as f:
- f.write(json.dumps({
- 'tasks': self.task_names,
- 'parameters': self.parameters_path,
- 'system_info': self.sys_info,
- }))
- self.logger.info("Protocol data successfully saved")
+ Saves protocol data and copy phrases data to their respective files.
+ """
+ self._save_protocol_data()
+ self._save_copy_phrases()
+
+ def _save_protocol_data(self) -> None:
+ """Save protocol data to a JSON file.
+
+ Saves task names, system info, and other session metadata.
+ """
+ data = {
+ 'tasks': self.task_names,
+ 'sys_info': self.sys_info,
+ 'parameters': self.parameters_path,
+ 'user': self.user,
+ 'experiment_id': self.experiment_id,
+ 'fake': self.fake
+ }
+ with open(os.path.join(self.save_folder, PROTOCOL_FILENAME), 'w') as f:
+ json.dump(data, f)
def _save_copy_phrases(self) -> None:
- # Save the copy phrases data to a json file
- with open(f'{self.save_folder}/{MULTIPHRASE_FILENAME}', 'w') as f:
- f.write(json.dumps({
- 'Phrases': self.copyphrases
- }))
- self.logger.info("Copy phrases data successfully saved")
-
- def get_system_info(self) -> dict:
+ """Save copy phrases data to a JSON file.
+
+ Only saves if copy phrases were used in the session.
+ """
+ if self.copyphrases:
+ with open(os.path.join(self.save_folder, MULTIPHRASE_FILENAME), 'w') as f:
+ json.dump({'Phrases': self.copyphrases}, f)
+
+ def get_system_info(self) -> Dict[str, Any]:
+ """Get system information.
+
+ Returns:
+ Dict[str, Any]: Dictionary containing system information.
+ """
return get_system_info()
- def close_experiment_callback(self):
- """Callback to close the experiment."""
- self.logger.info("User has exited the experiment.")
+ def close_experiment_callback(self) -> None:
+ """Callback for handling user-initiated experiment closure.
+
+ Sets the user_exit flag to true to stop task execution.
+ """
self.user_exit = True
diff --git a/bcipy/task/orchestrator/protocol.py b/bcipy/task/orchestrator/protocol.py
index 20396c07d..30f680492 100644
--- a/bcipy/task/orchestrator/protocol.py
+++ b/bcipy/task/orchestrator/protocol.py
@@ -1,77 +1,68 @@
-"""This file can define actions that can happen in a session orchestrator visit.
-To start these will be 1:1 with tasks, but later this can be extended to represent training sequences, GUI popups etc"""
+"""Protocol handling module for BciPy task orchestration.
+
+This module provides functionality for parsing and managing task protocols,
+which define sequences of actions to be executed in a session. While currently
+focused on task sequences, this can be extended to support training sequences,
+GUI interactions, and other orchestrated behaviors.
+"""
from typing import List, Type
-from bcipy.config import TASK_SEPERATOR
+from bcipy.config import TASK_SEPARATOR
from bcipy.task import Task
from bcipy.task.registry import TaskRegistry
def parse_protocol(protocol: str) -> List[Type[Task]]:
- """
- Parses a string of actions into a list of Task objects.
-
- Converts a string of actions into a list of Task objects. The string is expected
- to be in the format of 'Action1 -> Action2 -> ... -> ActionN'.
- Parameters
- ----------
- protocol : str
- A string of actions in the format of 'Action1 -> Action2 -> ... -> ActionN'.
-
- Returns
- -------
- List[TaskType]
- A list of TaskType objects that represent the actions in the input string.
+ """Parse a protocol string into a list of Task classes.
+
+ Converts a string of task names into a list of Task classes. The string
+ should be in the format 'Task1 -> Task2 -> ... -> TaskN', where each
+ task name corresponds to a registered task in the TaskRegistry.
+
+ Args:
+ protocol: String of task names separated by the task separator.
+ Format: 'Task1 -> Task2 -> ... -> TaskN'
+
+ Returns:
+ List[Type[Task]]: List of Task classes corresponding to the protocol.
+
+ Raises:
+ ValueError: If any task name in the protocol is not registered.
"""
task_registry = TaskRegistry()
- return [task_registry.get(item.strip()) for item in protocol.split(TASK_SEPERATOR)]
+ return [task_registry.get(item.strip()) for item in protocol.split(TASK_SEPARATOR)]
def validate_protocol_string(protocol: str) -> None:
- """
- Validates a string of actions.
+ """Validate a protocol string against registered tasks.
- Validates a string of actions. The string is expected to be in the format of 'Action1 -> Action2 -> ... -> ActionN'.
+ Checks that all task names in the protocol string correspond to
+ registered tasks in the TaskRegistry.
- Parameters
- ----------
- protocol : str
- A string of actions in the format of 'Action1 -> Action2 -> ... -> ActionN'.
+ Args:
+ protocol: String of task names separated by the task separator.
+ Format: 'Task1 -> Task2 -> ... -> TaskN'
- Raises
- ------
- ValueError
- If the string of actions is invalid.
+ Raises:
+ ValueError: If any task name in the protocol is not registered.
"""
- for protocol_item in protocol.split(TASK_SEPERATOR):
+ for protocol_item in protocol.split(TASK_SEPARATOR):
if protocol_item.strip() not in TaskRegistry().list():
- raise ValueError(f"Invalid task '{protocol_item}' name in protocol string.")
+ raise ValueError(
+ f"Invalid task '{protocol_item}' name in protocol string.")
-def serialize_protocol(protocol: List[Type[Task]]) -> str:
- """
- Converts a list of TaskType objects into a string of actions.
+def serialize_protocol(tasks: List[Type[Task]]) -> str:
+ """Convert a list of Task classes into a protocol string.
- Converts a list of TaskType objects into a string of actions. The string is in the format of
- 'Action1 -> Action2 -> ... -> ActionN'.
+ Creates a protocol string from a list of Task classes, using the task
+ separator to join task names.
- Parameters
- ----------
- protocol : str
- A string of actions in the format of 'Action1 -> Action2 -> ... -> ActionN'.
+ Args:
+ tasks: List of Task classes to serialize.
- Returns
- -------
- List[TaskType]
- A list of TaskType objects that represent the actions in the input string.
+ Returns:
+ str: Protocol string in format 'Task1 -> Task2 -> ... -> TaskN'.
"""
-
- return f" {TASK_SEPERATOR} ".join([item.name for item in protocol])
-
-
-if __name__ == '__main__':
- actions = parse_protocol("Matrix Calibration -> Matrix Copy Phrase")
- string = serialize_protocol(actions)
- print(actions)
- print(string)
+ return f" {TASK_SEPARATOR} ".join([item.name for item in tasks])
diff --git a/bcipy/task/paradigm/matrix/calibration.py b/bcipy/task/paradigm/matrix/calibration.py
index 1ecc355be..851bc7942 100644
--- a/bcipy/task/paradigm/matrix/calibration.py
+++ b/bcipy/task/paradigm/matrix/calibration.py
@@ -1,3 +1,10 @@
+"""Matrix calibration task module.
+
+This module provides the Matrix calibration task implementation which performs
+Matrix stimulus inquiries to elicit ERPs. The task presents a matrix of stimuli
+and highlights them according to configured parameters.
+"""
+
from typing import Any, Dict, List, Optional
from psychopy import visual
@@ -16,29 +23,41 @@
class MatrixCalibrationTask(BaseCalibrationTask):
"""Matrix Calibration Task.
- Calibration task performs an Matrix stimulus inquiry
- to elicit an ERP. Parameters change the number of stimuli
- (i.e. the subset of matrix) and for how long they will highlight.
- Parameters also change color and text / image inputs.
-
- A task begins setting up variables --> initializing eeg -->
- awaiting user input to start -->
- setting up stimuli --> highlighting inquiries -->
- saving data
-
- PARAMETERS:
- ----------
- parameters (dict)
- file_save (str)
- fake (bool)
-
+ This task performs Matrix stimulus inquiries to elicit ERPs by highlighting
+ elements in a matrix display. Parameters control the number of stimuli,
+ highlight duration, colors, and text/image inputs.
+
+ Task flow:
+ 1. Setup variables
+ 2. Initialize EEG
+ 3. Await user input
+ 4. Setup stimuli
+ 5. Perform highlighting inquiries
+ 6. Save data
+
+ Attributes:
+ name: Name of the task.
+ paradigm: Name of the paradigm.
+ parameters: Task configuration parameters.
+ file_save: Path for saving task data.
+ fake: Whether to run in fake (testing) mode.
+ window: PsychoPy window for display.
+ experiment_clock: Task timing clock.
+ display: Matrix display instance.
+ symbol_set: Set of symbols to display.
"""
+
name = 'Matrix Calibration'
paradigm = 'Matrix'
@property
def screen_info(self) -> Dict[str, Any]:
- """Screen properties"""
+ """Get screen properties.
+
+ Returns:
+ Dict[str, Any]: Dictionary containing screen size, refresh rate,
+ and units information.
+ """
return {
'screen_size_pixels': self.window.size.tolist(),
'screen_hz': get_screen_info().rate,
@@ -46,22 +65,45 @@ def screen_info(self) -> Dict[str, Any]:
}
def init_display(self) -> MatrixDisplay:
- """Initialize the display"""
+ """Initialize the matrix display.
+
+ Returns:
+ MatrixDisplay: Configured matrix display instance.
+ """
return init_matrix_display(self.parameters, self.window,
self.experiment_clock, self.symbol_set)
def exit_display(self) -> None:
+ """Clean up display resources and save screenshot.
+
+ Raises:
+ AssertionError: If display is not a MatrixDisplay instance.
+ """
assert isinstance(self.display, MatrixDisplay)
self.display.capture_grid_screenshot(self.file_save)
return super().exit_display()
def cleanup(self) -> None:
+ """Perform cleanup operations and save stimuli position data.
+
+ Raises:
+ AssertionError: If display is not a MatrixDisplay instance.
+ """
assert isinstance(self.display, MatrixDisplay)
save_stimuli_position_info(self.display.stim_positions, self.file_save,
self.screen_info)
return super().cleanup()
def session_task_data(self) -> Optional[Dict[str, Any]]:
+ """Get session task data.
+
+ Returns:
+ Optional[Dict[str, Any]]: Dictionary containing stimuli positions
+ and screen information.
+
+ Raises:
+ AssertionError: If display is not a MatrixDisplay instance.
+ """
assert isinstance(self.display, MatrixDisplay)
return {**self.display.stim_positions, **self.screen_info}
@@ -69,7 +111,17 @@ def session_task_data(self) -> Optional[Dict[str, Any]]:
def init_matrix_display(parameters: Parameters, window: visual.Window,
experiment_clock: Clock,
symbol_set: List[str]) -> MatrixDisplay:
- """Initialize the matrix display"""
+ """Initialize the matrix display with given parameters.
+
+ Args:
+ parameters: Task configuration parameters.
+ window: PsychoPy window for display.
+ experiment_clock: Task timing clock.
+ symbol_set: Set of symbols to display.
+
+ Returns:
+ MatrixDisplay: Configured matrix display instance.
+ """
info = InformationProperties(
info_color=[parameters['info_color']],
info_pos=[(parameters['info_pos_x'], parameters['info_pos_y'])],
@@ -78,7 +130,8 @@ def init_matrix_display(parameters: Parameters, window: visual.Window,
info_text=[parameters['info_text']],
)
stimuli = StimuliProperties(stim_font=parameters['font'],
- stim_pos=(parameters['matrix_stim_pos_x'], parameters['matrix_stim_pos_y']),
+ stim_pos=(
+ parameters['matrix_stim_pos_x'], parameters['matrix_stim_pos_y']),
stim_height=parameters['matrix_stim_height'],
stim_inquiry=[''] * parameters['stim_length'],
stim_colors=[parameters['stim_color']] *
diff --git a/bcipy/task/paradigm/matrix/copy_phrase.py b/bcipy/task/paradigm/matrix/copy_phrase.py
index fdde27325..fe993e813 100644
--- a/bcipy/task/paradigm/matrix/copy_phrase.py
+++ b/bcipy/task/paradigm/matrix/copy_phrase.py
@@ -1,4 +1,9 @@
-"""Defines the Copy Phrase Task which uses a Matrix display"""
+"""Matrix copy phrase task module.
+
+This module defines the Copy Phrase Task implementation using a Matrix display.
+The task allows users to copy a predefined phrase using a matrix-based interface.
+"""
+
from psychopy import visual
from bcipy.core.parameters import Parameters
@@ -14,22 +19,23 @@
class MatrixCopyPhraseTask(RSVPCopyPhraseTask):
"""Matrix Copy Phrase Task.
- Initializes and runs all needed code for executing a copy phrase task. A
- phrase is set in parameters and necessary objects (daq, display) are
- passed to this function.
-
- Parameters
- ----------
- parameters : dict,
- configuration details regarding the experiment. See parameters.json
- file_save : str,
- path location of where to save data from the session
- fake : boolean, optional
- boolean to indicate whether this is a fake session or not.
- Returns
- -------
- TaskData
+ This task allows users to copy a predefined phrase using a matrix-based
+ interface. The task initializes and runs all necessary components for
+ executing a copy phrase task.
+
+ Attributes:
+ name: Name of the task.
+ paradigm: Name of the paradigm.
+ mode: Task execution mode.
+ parameters: Task configuration parameters.
+ file_save: Path for saving task data.
+ fake: Whether to run in fake (testing) mode.
+ window: PsychoPy window for display.
+ experiment_clock: Task timing clock.
+ spelled_text: Currently spelled text.
+ PARAMETERS_USED: List of parameter names used by this task.
"""
+
name = 'Matrix Copy Phrase'
paradigm = 'Matrix'
mode = TaskMode.COPYPHRASE
@@ -100,7 +106,11 @@ class MatrixCopyPhraseTask(RSVPCopyPhraseTask):
]
def init_display(self) -> MatrixDisplay:
- """Initialize the Matrix display"""
+ """Initialize the Matrix display.
+
+ Returns:
+ MatrixDisplay: Configured matrix display instance.
+ """
return init_display(self.parameters, self.window,
self.experiment_clock, self.spelled_text)
@@ -110,8 +120,17 @@ def init_display(
win: visual.Window,
experiment_clock: Clock,
starting_spelled_text: str) -> MatrixDisplay:
- """Constructs a new Matrix display"""
+ """Initialize a new Matrix display with given parameters.
+
+ Args:
+ parameters: Task configuration parameters.
+ win: PsychoPy window for display.
+ experiment_clock: Task timing clock.
+ starting_spelled_text: Initial text to display.
+ Returns:
+ MatrixDisplay: Configured matrix display instance.
+ """
info = InformationProperties(
info_color=[parameters['info_color']],
info_pos=[(parameters['info_pos_x'], parameters['info_pos_y'])],
diff --git a/bcipy/task/paradigm/matrix/timing_verification.py b/bcipy/task/paradigm/matrix/timing_verification.py
index 05f8ac0e9..bd225ece6 100644
--- a/bcipy/task/paradigm/matrix/timing_verification.py
+++ b/bcipy/task/paradigm/matrix/timing_verification.py
@@ -1,3 +1,10 @@
+"""Matrix timing verification module.
+
+This module provides functionality for verifying display timing in Matrix tasks
+using photodiode stimuli. It alternates between solid and empty boxes that can
+be measured with a photodiode to ensure accurate stimulus presentation timing.
+"""
+
from itertools import cycle, islice, repeat
from typing import Iterator, List
@@ -11,32 +18,49 @@
class MatrixTimingVerificationCalibration(MatrixCalibrationTask):
"""Matrix Timing Verification Task.
- This task is used for verifying display timing by alternating solid and empty boxes. These
- stimuli can be used with a photodiode to ensure accurate presentations.
-
- Input:
- parameters (Dictionary)
- file_save (String)
- fake (Boolean)
+ This task verifies display timing by alternating solid and empty boxes.
+ The stimuli can be measured with a photodiode to ensure accurate
+ presentation timing.
- Output:
- TaskData
+ Attributes:
+ name: Name of the task.
+ mode: Task execution mode.
+ parameters: Task configuration parameters.
+ file_save: Path for saving task data.
+ fake: Whether to run in fake (testing) mode.
"""
+
name = 'Matrix Timing Verification'
mode = TaskMode.TIMING_VERIFICATION
def init_display(self) -> MatrixDisplay:
- """Initialize the display"""
+ """Initialize the display with transparent background.
+
+ Returns:
+ MatrixDisplay: Configured matrix display instance.
+ """
display = super().init_display()
display.start_opacity = 0.0
return display
@property
def symbol_set(self) -> List[str]:
- """Symbols used in the calibration"""
+ """Get symbols used in the calibration.
+
+ Returns:
+ List[str]: List of symbols with photodiode stimuli inserted.
+ """
return symbols_with_photodiode_stim(super().symbol_set)
def init_inquiry_generator(self) -> Iterator[Inquiry]:
+ """Initialize the inquiry generator for timing verification.
+
+ The generator alternates between solid and empty boxes with specified
+ timing parameters. A fixation point is shown between stimuli.
+
+ Returns:
+ Iterator[Inquiry]: Generator yielding timing verification inquiries.
+ """
params = self.parameters
# alternate between solid and empty boxes
@@ -63,11 +87,27 @@ def init_inquiry_generator(self) -> Iterator[Inquiry]:
def symbols_with_photodiode_stim(symbols: List[str]) -> List[str]:
- """Stim symbols with the central letters swapped out for Photodiode stim.
+ """Replace central symbols with photodiode stimuli.
+
+ Args:
+ symbols: List of symbols to modify.
+
+ Returns:
+ List[str]: Modified list with central symbols replaced by photodiode
+ stimuli.
"""
mid = int(len(symbols) / 2)
- def sym_at_index(sym, index) -> str:
+ def sym_at_index(sym: str, index: int) -> str:
+ """Get symbol at given index, replacing central indices with stimuli.
+
+ Args:
+ sym: Original symbol.
+ index: Position in symbol list.
+
+ Returns:
+ str: Original symbol or photodiode stimulus.
+ """
if index == mid:
return PhotoDiodeStimuli.SOLID.value
if index == mid + 1:
diff --git a/bcipy/task/paradigm/rsvp/calibration/calibration.py b/bcipy/task/paradigm/rsvp/calibration/calibration.py
index efd5ecb54..b00024c41 100644
--- a/bcipy/task/paradigm/rsvp/calibration/calibration.py
+++ b/bcipy/task/paradigm/rsvp/calibration/calibration.py
@@ -1,3 +1,9 @@
+"""RSVP calibration task module.
+
+This module provides the RSVP (Rapid Serial Visual Presentation) calibration task
+implementation which performs stimulus inquiries to elicit ERPs. The task presents
+stimuli in rapid succession with configurable timing and appearance parameters.
+"""
from psychopy import core, visual
from bcipy.core.parameters import Parameters
@@ -12,36 +18,59 @@
class RSVPCalibrationTask(BaseCalibrationTask):
"""RSVP Calibration Task.
- Calibration task performs an RSVP stimulus inquiry
- to elicit an ERP. Parameters will change how many stimuli
- and for how long they present. Parameters also change
- color and text / image inputs.
-
- This task progresses as follows:
+ This task performs RSVP stimulus inquiries to elicit ERPs by presenting
+ stimuli in rapid succession. Parameters control the number of stimuli,
+ presentation duration, colors, and text/image inputs.
- setting up variables --> initializing eeg --> awaiting user input to start --> setting up stimuli -->
- presenting inquiries --> saving data
+ Task flow:
+ 1. Setup variables
+ 2. Initialize EEG
+ 3. Await user input
+ 4. Setup stimuli
+ 5. Present inquiries
+ 6. Save data
- PARAMETERS:
- ----------
- parameters (dict)
- file_save (str)
- fake (bool)
+ Attributes:
+ name (str): Name of the task.
+ paradigm (str): Name of the paradigm.
+ parameters (Parameters): Task configuration parameters.
+ file_save (str): Path for saving task data.
+ fake (bool): Whether to run in fake (testing) mode.
+ window (visual.Window): PsychoPy window for display.
+ static_clock (core.StaticPeriod): Clock for static timing.
+ experiment_clock (Clock): Clock for experiment timing.
"""
- name = 'RSVP Calibration'
- paradigm = 'RSVP'
+
+ name: str = 'RSVP Calibration'
+ paradigm: str = 'RSVP'
def init_display(self) -> Display:
+ """Initialize the RSVP display.
+
+ Returns:
+ Display: Configured RSVP calibration display instance.
+ """
return init_calibration_display_task(self.parameters, self.window,
self.static_clock,
self.experiment_clock)
def init_calibration_display_task(
- parameters: Parameters, window: visual.Window,
+ parameters: Parameters,
+ window: visual.Window,
static_clock: core.StaticPeriod,
experiment_clock: Clock) -> CalibrationDisplay:
- """Initialize the display"""
+ """Initialize the RSVP calibration display.
+
+ Args:
+ parameters (Parameters): Task configuration parameters.
+ window (visual.Window): PsychoPy window for display.
+ static_clock (core.StaticPeriod): Clock for static timing.
+ experiment_clock (Clock): Clock for experiment timing.
+
+ Returns:
+ CalibrationDisplay: Configured RSVP calibration display instance.
+ """
info = InformationProperties(
info_color=[parameters['info_color']],
info_pos=[(parameters['info_pos_x'], parameters['info_pos_y'])],
@@ -51,7 +80,8 @@ def init_calibration_display_task(
)
stimuli = StimuliProperties(
stim_font=parameters['font'],
- stim_pos=(parameters['rsvp_stim_pos_x'], parameters['rsvp_stim_pos_y']),
+ stim_pos=(parameters['rsvp_stim_pos_x'],
+ parameters['rsvp_stim_pos_y']),
stim_height=parameters['rsvp_stim_height'],
stim_inquiry=[''] * parameters['stim_length'],
stim_colors=[parameters['stim_color']] * parameters['stim_length'],
@@ -72,7 +102,8 @@ def init_calibration_display_task(
stimuli,
task_bar,
info,
- preview_config=parameters.instantiate(PreviewParams),
+ preview_config=parameters.instantiate(
+ PreviewParams),
trigger_type=parameters['trigger_type'],
space_char=parameters['stim_space_char'],
full_screen=parameters['full_screen'])
diff --git a/bcipy/task/paradigm/rsvp/calibration/timing_verification.py b/bcipy/task/paradigm/rsvp/calibration/timing_verification.py
index f7411189b..579369359 100644
--- a/bcipy/task/paradigm/rsvp/calibration/timing_verification.py
+++ b/bcipy/task/paradigm/rsvp/calibration/timing_verification.py
@@ -1,4 +1,11 @@
# mypy: disable-error-code="assignment"
+"""RSVP timing verification module.
+
+This module provides functionality for verifying display timing in RSVP tasks
+using photodiode stimuli. It alternates between solid and empty boxes that can
+be measured with a photodiode to ensure accurate stimulus presentation timing.
+"""
+
from itertools import cycle, islice, repeat
from typing import Any, Iterator, List
@@ -11,19 +18,20 @@
class RSVPTimingVerificationCalibration(RSVPCalibrationTask):
- """RSVP Calibration Task.
+ """RSVP Timing Verification Task.
- This task is used for verifying display timing by alternating solid and empty boxes. These
- stimuli can be used with a photodiode to ensure accurate presentations.
+ This task verifies display timing by alternating solid and empty boxes.
+ The stimuli can be measured with a photodiode to ensure accurate
+ presentation timing.
- Input:
- parameters (Parameters)
- file_save (str)
- fake (bool)
-
- Output:
- TaskData
+ Attributes:
+ name: Name of the task.
+ mode: Task execution mode.
+ parameters: Task configuration parameters.
+ file_save: Path for saving task data.
+ fake: Whether to run in fake (testing) mode.
"""
+
name = 'RSVP Timing Verification'
mode = TaskMode.TIMING_VERIFICATION
@@ -32,6 +40,14 @@ def __init__(self,
file_save: str,
fake: bool = False,
**kwargs: Any) -> None:
+ """Initialize the RSVP timing verification task.
+
+ Args:
+ parameters: Task configuration parameters.
+ file_save: Path for saving task data.
+ fake: Whether to run in fake (testing) mode.
+ **kwargs: Additional keyword arguments.
+ """
parameters['rsvp_stim_height'] = 0.8
parameters['rsvp_stim_pos_y'] = 0.0
super(RSVPTimingVerificationCalibration,
@@ -39,10 +55,22 @@ def __init__(self,
@property
def symbol_set(self) -> List[str]:
- """Symbols used in the calibration"""
+ """Get symbols used in the calibration.
+
+ Returns:
+ List[str]: List of photodiode stimuli symbols.
+ """
return PhotoDiodeStimuli.list()
def init_inquiry_generator(self) -> Iterator[Inquiry]:
+ """Initialize the inquiry generator for timing verification.
+
+ The generator alternates between solid and empty boxes with specified
+ timing parameters. A fixation point is shown between stimuli.
+
+ Returns:
+ Iterator[Inquiry]: Generator yielding timing verification inquiries.
+ """
params = self.parameters
# alternate between solid and empty boxes
diff --git a/bcipy/task/paradigm/rsvp/copy_phrase.py b/bcipy/task/paradigm/rsvp/copy_phrase.py
index 877262571..68f401f85 100644
--- a/bcipy/task/paradigm/rsvp/copy_phrase.py
+++ b/bcipy/task/paradigm/rsvp/copy_phrase.py
@@ -1,4 +1,4 @@
-# mypy: disable-error-code="arg-type"
+# mypy: disable-error-code="arg-type, override"
import logging
from typing import Any, List, NamedTuple, Optional, Tuple
@@ -178,8 +178,10 @@ def __init__(
self.button_press_error_prob = parameters['preview_inquiry_error_prob']
self.signal_model = self.signal_models[0] if self.signal_models else None
- self.evidence_evaluators = self.init_evidence_evaluators(self.signal_models)
- self.evidence_types = self.init_evidence_types(self.signal_models, self.evidence_evaluators)
+ self.evidence_evaluators = self.init_evidence_evaluators(
+ self.signal_models)
+ self.evidence_types = self.init_evidence_types(
+ self.signal_models, self.evidence_evaluators)
self.file_save = file_save
self.save_session_every_inquiry = True
@@ -201,6 +203,7 @@ def setup(
parameters: Parameters,
data_save_location: str,
fake: bool = False) -> Tuple[ClientManager, List[LslDataServer], Window]:
+ """Set up acquisition and return client manager, data servers, and display."""
# Initialize Acquisition
daq, servers = init_acquisition(
parameters, data_save_location, server=fake)
@@ -212,9 +215,11 @@ def setup(
return daq, servers, display
def get_language_model(self) -> LanguageModel:
+ """Return the initialized language model."""
return init_language_model(self.parameters)
def get_signal_models(self) -> Optional[List[SignalModel]]:
+ """Return the list of signal models, or an empty list if fake."""
if not self.fake:
try:
signal_models = choose_signal_models(
@@ -227,6 +232,7 @@ def get_signal_models(self) -> Optional[List[SignalModel]]:
return []
def cleanup(self):
+ """Clean up resources and save session data."""
self.exit_display()
self.write_offset_trigger()
self.save_session_data()
@@ -256,6 +262,7 @@ def cleanup(self):
logger.exception(str(e))
def save_session_data(self) -> None:
+ """Save the session data and summary to disk."""
self.session.task_summary = TaskSummary(
self.session,
self.parameters["show_preview_inquiry"],
@@ -303,6 +310,7 @@ def init_evidence_types(
self, signal_models: List[SignalModel],
evidence_evaluators: List[EvidenceEvaluator]
) -> List[EvidenceType]:
+ """Initialize evidence types for the simulation."""
evidence_types = [EvidenceType.LM]
evidence_types.extend(
[evaluator.produces for evaluator in evidence_evaluators])
@@ -318,7 +326,8 @@ def default_trigger_handler(self) -> TriggerHandler:
def set(self) -> None:
"""Initialize/reset parameters used in the execute run loop."""
- self.spelled_text = str(self.copy_phrase[0: self.starting_spelled_letters()])
+ self.spelled_text = str(
+ self.copy_phrase[0: self.starting_spelled_letters()])
self.last_selection = ""
self.inq_counter = 0
self.session = Session(
@@ -364,13 +373,15 @@ def validate_parameters(self) -> None:
# ensure all required parameters are provided
for param in RSVPCopyPhraseTask.PARAMETERS_USED:
if param not in self.parameters:
- raise TaskConfigurationException(f"parameter '{param}' is required")
+ raise TaskConfigurationException(
+ f"parameter '{param}' is required")
# ensure data / query parameters are set correctly
buffer_len = self.parameters["task_buffer_length"]
prestim = self.parameters["prestim_length"]
poststim = (
- self.parameters["trial_window"][1] - self.parameters["trial_window"][0]
+ self.parameters["trial_window"][1] -
+ self.parameters["trial_window"][0]
)
if buffer_len < prestim:
raise TaskConfigurationException(
@@ -841,7 +852,8 @@ def exit_display(self) -> None:
self.rsvp.update_task_bar(text=self.spelled_text)
# Say Goodbye!
- self.rsvp.info_text = trial_complete_message(self.window, self.parameters)
+ self.rsvp.info_text = trial_complete_message(
+ self.window, self.parameters)
self.rsvp.draw_static()
self.window.flip()
@@ -911,9 +923,11 @@ def write_trigger_data(
for content_type, client in self.daq.clients_by_type.items():
label = offset_label(content_type.name)
time = (
- client.offset(self.rsvp.first_stim_time) - self.rsvp.first_stim_time
+ client.offset(self.rsvp.first_stim_time) -
+ self.rsvp.first_stim_time
)
- offset_triggers.append(Trigger(label, TriggerType.OFFSET, time))
+ offset_triggers.append(
+ Trigger(label, TriggerType.OFFSET, time))
self.trigger_handler.add_triggers(offset_triggers)
triggers = convert_timing_triggers(
@@ -975,16 +989,19 @@ def __init__(
def as_dict(self) -> dict:
"""Computes the task summary data to append to the session."""
- selections = [inq for inq in self.session.all_inquiries if inq.selection]
+ selections = [
+ inq for inq in self.session.all_inquiries if inq.selection]
correct = [inq for inq in selections if inq.is_correct_decision]
incorrect = [inq for inq in selections if not inq.is_correct_decision]
# Note that SPACE is considered a symbol
- correct_symbols = [inq for inq in correct if inq.selection != BACKSPACE_CHAR]
+ correct_symbols = [
+ inq for inq in correct if inq.selection != BACKSPACE_CHAR]
btn_presses = self.btn_press_count()
sel_count = len(selections)
- switch_per_selection = (btn_presses / sel_count) if sel_count > 0 else 0
+ switch_per_selection = (
+ btn_presses / sel_count) if sel_count > 0 else 0
accuracy = (len(correct) / sel_count) if sel_count > 0 else 0
# Note that minutes includes startup time and any breaks.
@@ -1002,9 +1019,7 @@ def as_dict(self) -> dict:
}
def btn_press_count(self) -> int:
- """Compute the number of times the switch was activated. Returns 0 if
- inquiry preview mode was off or mode was preview-only."""
-
+ """Compute the number of times the switch was activated. Returns 0 if inquiry preview mode was off or mode was preview-only."""
if not self.show_preview or self.preview_mode == 0:
return 0
@@ -1034,7 +1049,8 @@ def switch_response_time(self) -> Optional[float]:
logger.info("Could not compute switch_response_time")
return None
- response_times = [keypress.time - preview.time for preview, keypress in pairs]
+ response_times = [keypress.time -
+ preview.time for preview, keypress in pairs]
count = len(response_times)
return sum(response_times) / count if count > 0 else None
@@ -1066,7 +1082,8 @@ def _init_copy_phrase_display(
)
stimuli = StimuliProperties(
stim_font=parameters["font"],
- stim_pos=(parameters["rsvp_stim_pos_x"], parameters["rsvp_stim_pos_y"]),
+ stim_pos=(parameters["rsvp_stim_pos_x"],
+ parameters["rsvp_stim_pos_y"]),
stim_height=parameters["rsvp_stim_height"],
stim_inquiry=["A"] * parameters["stim_length"],
stim_colors=[parameters["stim_color"]] * parameters["stim_length"],
diff --git a/bcipy/task/paradigm/vep/calibration.py b/bcipy/task/paradigm/vep/calibration.py
index a7c54a18b..52ad1c6b8 100644
--- a/bcipy/task/paradigm/vep/calibration.py
+++ b/bcipy/task/paradigm/vep/calibration.py
@@ -1,4 +1,10 @@
-"""VEP Calibration task-related code"""
+"""VEP calibration task module.
+
+This module provides the VEP (Visual Evoked Potential) calibration task
+implementation. The task presents visual stimuli with different flicker rates
+to calibrate the system's response to visual evoked potentials.
+"""
+
import logging
from typing import Any, Dict, Iterator, List, Optional
@@ -24,17 +30,27 @@
class VEPCalibrationTask(BaseCalibrationTask):
"""VEP Calibration Task.
- A task begins setting up variables --> initializing eeg -->
- awaiting user input to start -->
- setting up stimuli --> highlighting inquiries -->
- saving data
+ This task calibrates the system's response to visual evoked potentials by
+ presenting visual stimuli with different flicker rates.
+
+ Task flow:
+ 1. Setup variables
+ 2. Initialize EEG
+ 3. Await user input
+ 4. Setup stimuli
+ 5. Present flickering inquiries
+ 6. Save data
- PARAMETERS:
- ----------
- parameters (dict)
- file_save (str)
- fake (bool)
+ Attributes:
+ name: Name of the task.
+ paradigm: Name of the paradigm.
+ parameters: Task configuration parameters.
+ file_save: Path for saving task data.
+ fake: Whether to run in fake (testing) mode.
+ box_colors: List of colors for stimulus boxes.
+ num_boxes: Number of stimulus boxes.
"""
+
name = 'VEP Calibration'
paradigm = 'VEP'
@@ -43,6 +59,14 @@ def __init__(self,
file_save: str,
fake: bool = False,
**kwargs: Any) -> None:
+ """Initialize the VEP calibration task.
+
+ Args:
+ parameters: Task configuration parameters.
+ file_save: Path for saving task data.
+ fake: Whether to run in fake (testing) mode.
+ **kwargs: Additional keyword arguments.
+ """
self.box_colors = [
'#00FF80', '#FFFFB3', '#CB99FF', '#FB8072', '#80B1D3', '#FF8232'
]
@@ -50,14 +74,22 @@ def __init__(self,
super().__init__(parameters, file_save, fake=fake, **kwargs)
def init_display(self) -> VEPDisplay:
- """Initialize the display"""
+ """Initialize the VEP display.
+
+ Returns:
+ VEPDisplay: Configured VEP display instance.
+ """
return init_vep_display(self.parameters, self.window,
self.experiment_clock, self.symbol_set,
self.box_colors,
fake=self.fake)
def init_inquiry_generator(self) -> Iterator[Inquiry]:
- """Initializes a generator that returns inquiries to be presented."""
+ """Initialize the inquiry generator.
+
+ Returns:
+ Iterator[Inquiry]: Generator yielding VEP calibration inquiries.
+ """
parameters = self.parameters
schedule = generate_vep_calibration_inquiries(
alp=self.symbol_set,
@@ -75,12 +107,30 @@ def init_inquiry_generator(self) -> Iterator[Inquiry]:
def trigger_type(self, symbol: str, target: str,
index: int) -> TriggerType:
+ """Get trigger type for a symbol.
+
+ Args:
+ symbol: Presented symbol.
+ target: Target symbol.
+ index: Position in sequence.
+
+ Returns:
+ TriggerType: Type of trigger to use.
+ """
if target == symbol:
return TriggerType.TARGET
return TriggerType.EVENT
def session_task_data(self) -> Dict[str, Any]:
- """Task-specific session data"""
+ """Get task-specific session data.
+
+ Returns:
+ Dict[str, Any]: Dictionary containing box configurations and
+ starting positions.
+
+ Raises:
+ AssertionError: If display is not a VEPDisplay instance.
+ """
assert isinstance(self.display, VEPDisplay)
boxes = [{
"colors": box.colors,
@@ -94,7 +144,18 @@ def session_task_data(self) -> Dict[str, Any]:
def session_inquiry_data(self,
inquiry: Inquiry) -> Optional[Dict[str, Any]]:
- """Defines task-specific session data for each inquiry."""
+ """Get task-specific data for an inquiry.
+
+ Args:
+ inquiry: Inquiry to get data for.
+
+ Returns:
+ Optional[Dict[str, Any]]: Dictionary containing target box index
+ and frequency.
+
+ Raises:
+ AssertionError: If display is not a VEPDisplay instance.
+ """
assert isinstance(self.display, VEPDisplay)
target_box = target_box_index(inquiry)
target_freq = self.display.flicker_rates[
@@ -105,7 +166,14 @@ def session_inquiry_data(self,
}
def stim_labels(self, inquiry: Inquiry) -> List[str]:
- """labels for each stimuli in the session data."""
+ """Get labels for each stimulus in the session data.
+
+ Args:
+ inquiry: Inquiry to get labels for.
+
+ Returns:
+ List[str]: List of stimulus labels.
+ """
target_box = target_box_index(inquiry)
targetness = [TriggerType.NONTARGET for _ in range(self.num_boxes)]
if target_box is not None:
@@ -115,7 +183,14 @@ def stim_labels(self, inquiry: Inquiry) -> List[str]:
def target_box_index(inquiry: Inquiry) -> Optional[int]:
- """Index of the target box."""
+ """Get the index of the target box.
+
+ Args:
+ inquiry: Inquiry to find target box in.
+
+ Returns:
+ Optional[int]: Index of target box if found, None otherwise.
+ """
target_letter, _fixation, *boxes = inquiry.stimuli
for i, box in enumerate(boxes):
if target_letter in box:
@@ -126,7 +201,19 @@ def target_box_index(inquiry: Inquiry) -> Optional[int]:
def init_vep_display(parameters: Parameters, window: visual.Window,
experiment_clock: Clock, symbol_set: List[str],
box_colors: List[str], fake: bool = False) -> VEPDisplay:
- """Initialize the display"""
+ """Initialize the VEP display.
+
+ Args:
+ parameters: Task configuration parameters.
+ window: PsychoPy window for display.
+ experiment_clock: Clock for experiment timing.
+ symbol_set: Set of symbols to display.
+ box_colors: List of colors for stimulus boxes.
+ fake: Whether to run in fake (testing) mode.
+
+ Returns:
+ VEPDisplay: Configured VEP display instance.
+ """
info = InformationProperties(
info_color=[parameters['info_color']],
info_pos=[(parameters['info_pos_x'], parameters['info_pos_y'])],
diff --git a/bcipy/task/paradigm/vep/stim_generation.py b/bcipy/task/paradigm/vep/stim_generation.py
index b39a173c4..aed53e663 100644
--- a/bcipy/task/paradigm/vep/stim_generation.py
+++ b/bcipy/task/paradigm/vep/stim_generation.py
@@ -1,4 +1,10 @@
-"""Functions related to stimuli generation for VEP tasks"""
+"""VEP stimulus generation module.
+
+This module provides functions for generating visual stimuli used in VEP
+(Visual Evoked Potential) tasks. It handles the creation of calibration
+inquiries, stimulus box configurations, and inquiry schedules.
+"""
+
import itertools
import math
import random
@@ -15,30 +21,25 @@ def generate_vep_calibration_inquiries(alp: List[str],
inquiry_count: int = 100,
num_boxes: int = 4,
is_txt: bool = True) -> InquirySchedule:
- """
- Generates VEP inquiries with target letters in all possible positions.
+ """Generate VEP inquiries with target letters in all possible positions.
In the VEP paradigm, all stimuli in the alphabet are displayed in each
inquiry. The symbols with the highest likelihoods are displayed alone
while those with lower likelihoods occur together.
- Parameters
- ----------
- alp(list[str]): stimuli
- timing(list[float]): Task specific timing for generator.
- [target, fixation, stimuli]
- color(list[str]): Task specific color for generator
- [target, fixation, stimuli]
- inquiry_count(int): number of inquiries in a calibration
- num_boxes(int): number of display boxes
- is_txt(bool): whether the stimuli type is text. False would be an image stimuli.
-
- Return
- ------
- schedule_inq(tuple(
- samples[list[list[str]]]: list of inquiries
- timing(list[list[float]]): list of timings
- color(list(list[str])): list of colors)): scheduled inquiries
+ Args:
+ alp: List of stimuli.
+ timing: Task specific timing for generator [target, fixation, stimuli].
+ color: Task specific color for generator [target, fixation, stimuli].
+ inquiry_count: Number of inquiries in a calibration.
+ num_boxes: Number of display boxes.
+ is_txt: Whether the stimuli type is text (False for image stimuli).
+
+ Returns:
+ InquirySchedule: Schedule containing inquiries, timings, and colors.
+
+ Raises:
+ AssertionError: If timing list does not contain exactly 3 values.
"""
if timing is None:
timing = [0.5, 1, 2]
@@ -64,7 +65,20 @@ def generate_vep_inquiries(symbols: List[str],
num_boxes: int = 6,
inquiry_count: int = 20,
is_txt: bool = True) -> List[List[Any]]:
- """Generates inquiries"""
+ """Generate a list of VEP inquiries.
+
+ Args:
+ symbols: List of symbols to use in inquiries.
+ num_boxes: Number of display boxes.
+ inquiry_count: Number of inquiries to generate.
+ is_txt: Whether the stimuli type is text.
+
+ Returns:
+ List[List[Any]]: List of inquiries, where each inquiry contains:
+ - Target symbol
+ - Fixation point
+ - List of symbols for each box
+ """
fixation = get_fixation(is_txt)
target_indices = random_target_positions(inquiry_count,
stim_per_inquiry=num_boxes,
@@ -86,44 +100,29 @@ def stim_per_box(num_symbols: int,
num_boxes: int = 6,
max_empty_boxes: int = 0,
max_single_sym_boxes: int = 4) -> List[int]:
- """Determine the number of stimuli per vep box.
-
- Parameters
- ----------
- num_symbols - number of symbols
- num_boxes - number of boxes
- max_empty_boxes - the maximum number of boxes which won't have any
- symbols within them.
- max_single_sym_boxes - maximum number of boxes with a single symbol
-
- Returns
- -------
- A list of length num_boxes, where each number in the list represents
- the number of symbols that should be in the box at that position.
-
- Post conditions:
- The sum of the list should equal num_symbols. Further, there should
- be at most max_empty_boxes with value of 0 and max_single_sym_boxes
- with a value of 1.
+ """Determine the number of stimuli per VEP box.
+
+ This function distributes symbols across boxes based on rules derived from
+ example sessions. It ensures a balanced distribution while allowing for
+ some empty boxes and boxes with single symbols.
+
+ Args:
+ num_symbols: Number of symbols to distribute.
+ num_boxes: Number of boxes to distribute symbols across.
+ max_empty_boxes: Maximum number of boxes that can be empty.
+ max_single_sym_boxes: Maximum number of boxes that can have a single symbol.
+
+ Returns:
+ List[int]: List where each number represents the number of symbols
+ that should be in the box at that position.
+
+ Notes:
+ - The sum of the returned list equals num_symbols
+ - There will be at most max_empty_boxes with value 0
+ - There will be at most max_single_sym_boxes with value 1
+ - Distribution is based on example sessions from:
+ https://www.youtube.com/watch?v=JNFYSeIIOrw
"""
- # Logic based off of example sessions from:
- # https://www.youtube.com/watch?v=JNFYSeIIOrw
- # [[2, 3, 5, 5, 6, 7],
- # [2, 1, 10, 1, 1, 13],
- # [3, 4, 17, 1, 1, 2],
- # [1, 1, 1, 0, 1, 24],
- # [1, 2, 1, 22, 1, 1],
- # [2, 1, 1, 21, 2, 1],
- # [1, 1, 25, 0, 1, 0]]
- # and
- # [[7, 3, 4, 9, 2, 3],
- # 1, 1, 6, 9, 7, 2],
- # 1, 2, 18, 2, 3, 2],
- # 1, 1, 4, 4, 17, 1],
- # 1, 3, 1, 1, 20, 2],
- # 1, 1, 1, 20, 3, 2],
- # 1, 1, 1, 4, 21, 0]]
-
if max_empty_boxes + max_single_sym_boxes >= num_boxes:
max_empty_boxes = 0
max_single_sym_boxes = num_boxes - 1
@@ -160,26 +159,24 @@ def generate_vep_inquiry(alphabet: List[str],
num_boxes: int = 6,
target: Optional[str] = None,
target_pos: Optional[int] = None) -> List[List[str]]:
- """Generates a single random inquiry.
-
- Parameters
- ----------
- alphabet - list of symbols from which to select.
- num_boxes - number of display areas; symbols will be partitioned into
- these areas.
- target - target symbol for the generated inquiry
- target_pos - box index that should contain the target
-
- Returns
- -------
- An inquiry represented by a list of lists, where each sublist
- represents a display box and contains symbols that should appear in that box.
-
- Post-conditions:
- Symbols will not be repeated and all symbols will be partitioned into
- one of the boxes.
+ """Generate a single random VEP inquiry.
+
+ Args:
+ alphabet: List of symbols to select from.
+ num_boxes: Number of display areas to partition symbols into.
+ target: Target symbol for the inquiry.
+ target_pos: Box index that should contain the target.
+
+ Returns:
+ List[List[str]]: List of lists where each sublist represents a display
+ box and contains symbols that should appear in that box.
+
+ Notes:
+ - Symbols will not be repeated
+ - All symbols will be partitioned into one of the boxes
+        - If target is specified, it will be placed in the box with the
+          lowest non-zero symbol count
"""
-
box_counts = stim_per_box(num_symbols=len(alphabet), num_boxes=num_boxes)
assert len(box_counts) == num_boxes
syms = [sym for sym in alphabet]
@@ -189,6 +186,7 @@ def generate_vep_inquiry(alphabet: List[str],
# Move the target to the front so it gets put in the lowest count box
# greater than 0.
syms = swapped(syms, index1=0, index2=syms.index(target))
+
# Put syms in boxes
boxes = []
sym_index = 0
diff --git a/bcipy/task/registry.py b/bcipy/task/registry.py
index 58ad32cff..035b55bbb 100644
--- a/bcipy/task/registry.py
+++ b/bcipy/task/registry.py
@@ -1,50 +1,108 @@
-"""Task Registry ; used to provide task options to the GUI and command line
-tools. User defined tasks can be added to the Registry."""
-from typing import Dict, List, Type
+"""Task Registry module for managing BciPy tasks.
+
+This module provides a registry system for BCI tasks, allowing tasks to be
+dynamically discovered and accessed by the GUI and command line tools.
+User-defined tasks can be added to the Registry.
+"""
+
+from typing import Dict, List, Type, TypeVar
from bcipy.task import Task
+# Type variable for Task subclasses
+T = TypeVar('T', bound=Task)
+
class TaskRegistry:
- registry_dict: Dict[str, Type[Task]]
+ """Registry for managing and accessing BCI tasks.
+
+ This class maintains a registry of all available task types in the system.
+ It automatically discovers and registers all non-abstract Task subclasses
+ when initialized, and provides methods for accessing and managing tasks.
- def __init__(self):
- # Collects all non-abstract subclasses of Task. type ignore is used to work around a mypy bug
- # https://github.com/python/mypy/issues/3115
+ Attributes:
+ registry_dict: Dictionary mapping task names to task classes.
+ """
+
+ def __init__(self) -> None:
+ """Initialize the task registry.
+
+ Collects all non-abstract subclasses of Task and registers them.
+ Imports task modules to ensure all tasks are discovered.
+ """
+ # Import task modules to ensure all tasks are discovered
from bcipy.task import actions # noqa
from bcipy.task.paradigm import matrix, rsvp, vep # noqa
- self.registry_dict = {}
+ self.registry_dict: Dict[str, Type[Task]] = {}
self.collect_subclasses(Task) # type: ignore[type-abstract]
- def collect_subclasses(self, cls: Type[Task]):
- """Recursively collects all non-abstract subclasses of the given class and adds them to the registry."""
+ def collect_subclasses(self, cls: Type[T]) -> None:
+ """Recursively collect and register non-abstract subclasses.
+
+ Args:
+ cls: The base class to collect subclasses from.
+
+ Note:
+ Subclasses are only registered if they have no abstract methods.
+ """
for sub_class in cls.__subclasses__():
+ # Only register non-abstract subclasses
if not getattr(sub_class, '__abstractmethods__', False):
- self.registry_dict[sub_class.name] = sub_class
+ if hasattr(sub_class, 'name'):
+ self.registry_dict[sub_class.name] = sub_class
+ else:
+ raise ValueError(f'Task class {sub_class} missing name attribute')
self.collect_subclasses(sub_class)
def get(self, task_name: str) -> Type[Task]:
- """Returns a task type based on its name property."""
+ """Get a task class by its name.
+
+ Args:
+ task_name: Name of the task to retrieve.
+
+ Returns:
+ Type[Task]: The task class.
+
+ Raises:
+ ValueError: If the task name is not registered.
+ """
if task_name in self.registry_dict:
return self.registry_dict[task_name]
raise ValueError(f'{task_name} not a registered task')
def get_all_types(self) -> List[Type[Task]]:
- """Returns a list of all registered tasks."""
+ """Get all registered task classes.
+
+ Returns:
+ List[Type[Task]]: List of all registered task classes.
+ """
return list(self.registry_dict.values())
def list(self) -> List[str]:
- """Returns a list of all registered task names."""
+ """Get names of all registered tasks.
+
+ Returns:
+ List[str]: List of registered task names.
+ """
return list(self.registry_dict.keys())
def calibration_tasks(self) -> List[Type[Task]]:
- """Returns a list of all registered calibration tasks."""
+ """Get all registered calibration tasks.
+
+ Returns:
+ List[Type[Task]]: List of registered calibration task classes.
+ """
from bcipy.task.calibration import BaseCalibrationTask
return [task for task in self.get_all_types() if issubclass(task, BaseCalibrationTask)]
def register_task(self, task: Type[Task]) -> None:
- """Registers a task with the TaskRegistry."""
- # Note that all imported tasks are automatically registered when the TaskRegistry is initialized. This
- # method allows for the registration of additional tasks after initialization.
+ """Register a new task with the registry.
+
+ This method allows registration of additional tasks after initialization.
+ Tasks imported during initialization are automatically registered.
+
+ Args:
+ task: The task class to register.
+ """
self.registry_dict[task.name] = task
diff --git a/bcipy/task/tests/core/test_actions.py b/bcipy/task/tests/core/test_actions.py
index 8f0233338..d4442f4b2 100644
--- a/bcipy/task/tests/core/test_actions.py
+++ b/bcipy/task/tests/core/test_actions.py
@@ -48,7 +48,8 @@ def test_code_hook_action_no_subprocess(self) -> None:
def test_offline_analysis_action(self) -> None:
cmd_expected = f'bcipy-train -p "{self.parameters_path}"'
- when(subprocess).run(cmd_expected, shell=True, check=True).thenReturn(None)
+ when(subprocess).run(cmd_expected,
+ shell=True, check=True).thenReturn(None)
action = OfflineAnalysisAction(
parameters=self.parameters,
data_directory=self.data_directory,
@@ -60,7 +61,8 @@ def test_offline_analysis_action(self) -> None:
def test_experiment_field_collection_action(self) -> None:
experiment_id = 'experiment_id'
- when(actions).start_experiment_field_collection_gui(experiment_id, self.data_directory).thenReturn(None)
+ when(actions).start_experiment_field_collection_gui(
+ experiment_id, self.data_directory).thenReturn(None)
action = ExperimentFieldCollectionAction(
parameters=self.parameters,
data_directory=self.data_directory,
@@ -69,7 +71,8 @@ def test_experiment_field_collection_action(self) -> None:
task_data = action.execute()
self.assertIsNotNone(task_data)
self.assertIsInstance(task_data, TaskData)
- verify(actions, times=1).start_experiment_field_collection_gui(experiment_id, self.data_directory)
+ verify(actions, times=1).start_experiment_field_collection_gui(
+ experiment_id, self.data_directory)
if __name__ == '__main__':
diff --git a/bcipy/task/tests/core/test_handler.py b/bcipy/task/tests/core/test_handler.py
index c87e76551..020527198 100644
--- a/bcipy/task/tests/core/test_handler.py
+++ b/bcipy/task/tests/core/test_handler.py
@@ -157,7 +157,8 @@ def test_prepare_stimuli(self):
len(stimuli[0][0]))
for i in range(1, len(stimuli[0][0])):
self.assertIn(stimuli[0][0][i], self.decision_maker.alphabet)
- self.assertEqual(stimuli[1][0][0:2], self.decision_maker.stimuli_timing)
+ self.assertEqual(stimuli[1][0][0:2],
+ self.decision_maker.stimuli_timing)
class TestDecisionMakerOld(unittest.TestCase):
diff --git a/bcipy/task/tests/core/test_task_main.py b/bcipy/task/tests/core/test_task_main.py
deleted file mode 100644
index 3b29b666a..000000000
--- a/bcipy/task/tests/core/test_task_main.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import unittest
-
-from bcipy.task import Task, TaskData, TaskMode
-
-
-class TestTask(unittest.TestCase):
-
- def test_task_fails_without_name(self):
- mode = TaskMode.CALIBRATION
-
- class TestTask(Task):
-
- def execute(self) -> TaskData:
- ...
-
- with self.assertRaises(AssertionError):
- TestTask(mode=mode)
-
- def test_task_fails_without_mode(self):
- name = "test task"
-
- class TestTask(Task):
-
- def execute(self) -> TaskData:
- ...
-
- with self.assertRaises(AssertionError):
- TestTask(name=name)
-
- def test_task_fails_without_execute(self):
- name = "test task"
- mode = TaskMode.CALIBRATION
-
- class TestTask(Task):
- pass
-
- with self.assertRaises(TypeError):
- TestTask(name=name, mode=mode)
-
- def test_task_initializes(self):
- name = "test task"
- mode = TaskMode.CALIBRATION
-
- class TestTask(Task):
-
- def __init__(self, name: str, mode: TaskMode):
- self.name = name
- self.mode = mode
-
- def execute(self) -> TaskData:
- ...
- task = TestTask(name=name, mode=mode)
- self.assertEqual(task.name, name)
- self.assertEqual(task.mode, mode)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/bcipy/task/tests/orchestrator/test_orchestrator.py b/bcipy/task/tests/orchestrator/test_orchestrator.py
index 2e220acc6..3e14699e7 100644
--- a/bcipy/task/tests/orchestrator/test_orchestrator.py
+++ b/bcipy/task/tests/orchestrator/test_orchestrator.py
@@ -27,8 +27,10 @@ def test_orchestrator_add_task(self) -> None:
task = mock(spec=Task)
task.name = "test task"
task.mode = "test mode"
- when(SessionOrchestrator)._init_orchestrator_save_folder(any()).thenReturn()
- when(SessionOrchestrator)._init_orchestrator_logger(any()).thenReturn(self.logger)
+ when(SessionOrchestrator)._init_orchestrator_save_folder(
+ any()).thenReturn()
+ when(SessionOrchestrator)._init_orchestrator_logger(
+ any()).thenReturn(self.logger)
orchestrator = SessionOrchestrator()
self.assertTrue(orchestrator.tasks == [])
orchestrator.add_task(task)
@@ -45,8 +47,10 @@ def test_orchestrator_add_tasks(self) -> None:
task2.name = "test task"
task2.mode = "test mode"
tasks = [task1, task2]
- when(SessionOrchestrator)._init_orchestrator_save_folder(any()).thenReturn()
- when(SessionOrchestrator)._init_orchestrator_logger(any()).thenReturn(self.logger)
+ when(SessionOrchestrator)._init_orchestrator_save_folder(
+ any()).thenReturn()
+ when(SessionOrchestrator)._init_orchestrator_logger(
+ any()).thenReturn(self.logger)
orchestrator = SessionOrchestrator()
self.assertTrue(orchestrator.tasks == [])
orchestrator.add_tasks(tasks)
@@ -63,8 +67,10 @@ def test_orchestrator_execute(self) -> None:
task.name = "test task"
task.mode = "test mode"
task.execute = lambda: TaskData()
- when(SessionOrchestrator)._init_orchestrator_save_folder(any()).thenReturn()
- when(SessionOrchestrator)._init_orchestrator_logger(any()).thenReturn(self.logger)
+ when(SessionOrchestrator)._init_orchestrator_save_folder(
+ any()).thenReturn()
+ when(SessionOrchestrator)._init_orchestrator_logger(
+ any()).thenReturn(self.logger)
when(SessionOrchestrator)._init_task_save_folder(any()).thenReturn()
when(SessionOrchestrator)._init_task_logger(any()).thenReturn()
when(SessionOrchestrator)._save_data().thenReturn()
@@ -105,12 +111,15 @@ def test_orchestrator_execute(self) -> None:
@mock_open(read_data='{"Phrases": []}')
def test_orchestrator_multiple_copyphrases_loads_from_parameters_when_set(self, mock_file):
- parameters = load_json_parameters(self.parameter_location, value_cast=True)
+ parameters = load_json_parameters(
+ self.parameter_location, value_cast=True)
copy_phrase_location = "bcipy/parameters/experiments/phrases.json"
parameters['copy_phrases_location'] = copy_phrase_location
mock_copy_phrases = {"Phrases": [["test", 0], ["test2", 1]]}
- when(SessionOrchestrator)._init_orchestrator_save_folder(any()).thenReturn()
- when(SessionOrchestrator)._init_orchestrator_logger(any()).thenReturn(self.logger)
+ when(SessionOrchestrator)._init_orchestrator_save_folder(
+ any()).thenReturn()
+ when(SessionOrchestrator)._init_orchestrator_logger(
+ any()).thenReturn(self.logger)
when(SessionOrchestrator)._init_task_save_folder(any()).thenReturn()
when(SessionOrchestrator)._init_task_logger(any()).thenReturn()
when(SessionOrchestrator)._save_data().thenReturn()
@@ -118,30 +127,35 @@ def test_orchestrator_multiple_copyphrases_loads_from_parameters_when_set(self,
orchestrator = SessionOrchestrator(parameters=parameters)
- self.assertEqual(orchestrator.copyphrases, mock_copy_phrases['Phrases'])
+ self.assertEqual(orchestrator.copyphrases,
+ mock_copy_phrases['Phrases'])
verify(json, times=1).load(mock_file)
def test_orchestrator_save_data_multiple_copyphrases_saves_remaining_phrases(self):
- when(SessionOrchestrator)._init_orchestrator_save_folder(any()).thenReturn()
- when(SessionOrchestrator)._init_orchestrator_logger(any()).thenReturn(self.logger)
+ when(SessionOrchestrator)._init_orchestrator_save_folder(
+ any()).thenReturn()
+ when(SessionOrchestrator)._init_orchestrator_logger(
+ any()).thenReturn(self.logger)
when(SessionOrchestrator)._init_task_save_folder(any()).thenReturn()
when(SessionOrchestrator)._init_task_logger(any()).thenReturn()
- when(SessionOrchestrator)._save_procotol_data().thenReturn()
+ when(SessionOrchestrator)._save_protocol_data().thenReturn()
when(SessionOrchestrator)._save_copy_phrases().thenReturn()
orchestrator = SessionOrchestrator()
orchestrator.copyphrases = [["test", 0], ["test2", 1]]
orchestrator._save_data()
- verify(SessionOrchestrator, times=1)._save_procotol_data()
+ verify(SessionOrchestrator, times=1)._save_protocol_data()
verify(SessionOrchestrator, times=1)._save_copy_phrases()
def test_orchestrator_next_phrase(self):
- when(SessionOrchestrator)._init_orchestrator_save_folder(any()).thenReturn()
- when(SessionOrchestrator)._init_orchestrator_logger(any()).thenReturn(self.logger)
+ when(SessionOrchestrator)._init_orchestrator_save_folder(
+ any()).thenReturn()
+ when(SessionOrchestrator)._init_orchestrator_logger(
+ any()).thenReturn(self.logger)
when(SessionOrchestrator)._init_task_save_folder(any()).thenReturn()
when(SessionOrchestrator)._init_task_logger(any()).thenReturn()
- when(SessionOrchestrator)._save_procotol_data().thenReturn()
+ when(SessionOrchestrator)._save_protocol_data().thenReturn()
when(SessionOrchestrator).initialize_copy_phrases().thenReturn()
orchestrator = SessionOrchestrator()
diff --git a/bcipy/task/tests/orchestrator/test_protocol.py b/bcipy/task/tests/orchestrator/test_protocol.py
index ec258d71d..b2f65f686 100644
--- a/bcipy/task/tests/orchestrator/test_protocol.py
+++ b/bcipy/task/tests/orchestrator/test_protocol.py
@@ -71,6 +71,7 @@ def test_serializes_one_task(self) -> None:
assert serialized == RSVPCalibrationTask.name
def test_serializes_multiple_tasks(self) -> None:
- sequence = [RSVPCalibrationTask, OfflineAnalysisAction, RSVPCopyPhraseTask]
+ sequence = [RSVPCalibrationTask,
+ OfflineAnalysisAction, RSVPCopyPhraseTask]
serialized = serialize_protocol(sequence)
assert serialized == 'RSVP Calibration -> OfflineAnalysisAction -> RSVP Copy Phrase'
diff --git a/bcipy/task/tests/paradigm/rsvp/calibration/test_rsvp_calibration.py b/bcipy/task/tests/paradigm/rsvp/calibration/test_rsvp_calibration.py
index 9fcbbc339..87523e0ea 100644
--- a/bcipy/task/tests/paradigm/rsvp/calibration/test_rsvp_calibration.py
+++ b/bcipy/task/tests/paradigm/rsvp/calibration/test_rsvp_calibration.py
@@ -400,7 +400,8 @@ def test_cleanup(self, save_session_mock, trigger_handler_mock):
(self.daq, self.servers, self.win))
# Mock the default cleanup
- when(bcipy.task.calibration.BaseCalibrationTask).write_offset_trigger().thenReturn(None)
+ when(bcipy.task.calibration.BaseCalibrationTask).write_offset_trigger(
+ ).thenReturn(None)
when(bcipy.task.calibration.BaseCalibrationTask).exit_display().thenReturn(None)
when(bcipy.task.calibration.BaseCalibrationTask).wait().thenReturn(None)
@@ -421,8 +422,10 @@ def test_cleanup(self, save_session_mock, trigger_handler_mock):
verify(self.daq, times=1).cleanup()
verify(self.servers[0], times=1).stop()
verify(self.win, times=1).close()
- verify(bcipy.task.calibration.BaseCalibrationTask, times=1).setup(any(), any(), any())
- verify(bcipy.task.calibration.BaseCalibrationTask, times=1).write_offset_trigger()
+ verify(bcipy.task.calibration.BaseCalibrationTask,
+ times=1).setup(any(), any(), any())
+ verify(bcipy.task.calibration.BaseCalibrationTask,
+ times=1).write_offset_trigger()
verify(bcipy.task.calibration.BaseCalibrationTask, times=1).exit_display()
verify(bcipy.task.calibration.BaseCalibrationTask, times=1).wait()
diff --git a/bcipy/task/tests/paradigm/rsvp/test_copy_phrase.py b/bcipy/task/tests/paradigm/rsvp/test_copy_phrase.py
index 5c75a8565..2e2f86145 100644
--- a/bcipy/task/tests/paradigm/rsvp/test_copy_phrase.py
+++ b/bcipy/task/tests/paradigm/rsvp/test_copy_phrase.py
@@ -111,7 +111,8 @@ def setUp(self):
}
})
self.servers = [mock()]
- when(self.daq).get_client(ContentType.EEG).thenReturn(self.eeg_client_mock)
+ when(self.daq).get_client(
+ ContentType.EEG).thenReturn(self.eeg_client_mock)
self.temp_dir = tempfile.mkdtemp()
self.model_metadata = mock({
'device_spec': device_spec,
@@ -464,7 +465,8 @@ def test_execute_fake_data_with_preview(self, process_data_mock, message_mock,
# Assertions
verify(self.copy_phrase_wrapper, times=2).initialize_series()
verify(self.display, times=1).do_inquiry()
- verify(self.copy_phrase_wrapper, times=1).add_evidence(EvidenceType.BTN, ...)
+ verify(self.copy_phrase_wrapper, times=1).add_evidence(
+ EvidenceType.BTN, ...)
self.assertEqual(self.temp_dir, result.save_path)
@patch('bcipy.task.paradigm.rsvp.copy_phrase.init_evidence_evaluator')
@@ -655,7 +657,8 @@ def test_btn_evidence_with_preview_only(self):
when(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask).setup(any(), any(), any()).thenReturn(
(self.daq, self.servers, self.win))
self.parameters['show_preview_inquiry'] = True
- self.parameters['preview_inquiry_progress_method'] = 0 # ButtonPressMode.NOTHING.value
+ # ButtonPressMode.NOTHING.value
+ self.parameters['preview_inquiry_progress_method'] = 0
task = RSVPCopyPhraseTask(
parameters=self.parameters,
@@ -693,10 +696,14 @@ def test_cleanup(self):
(self.daq, self.servers, self.win))
# Mock the default cleanup
- when(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask).write_offset_trigger().thenReturn(None)
- when(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask).exit_display().thenReturn(None)
- when(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask).save_session_data().thenReturn(None)
- when(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask).wait().thenReturn(None)
+ when(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask).write_offset_trigger(
+ ).thenReturn(None)
+ when(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask).exit_display(
+ ).thenReturn(None)
+ when(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask).save_session_data(
+ ).thenReturn(None)
+ when(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask).wait(
+ ).thenReturn(None)
# Mock the initialized cleanup
when(self.daq).stop_acquisition().thenReturn(None)
@@ -716,9 +723,12 @@ def test_cleanup(self):
verify(self.daq, times=1).cleanup()
verify(self.servers[0], times=1).stop()
verify(self.win, times=1).close()
- verify(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask, times=1).setup(any(), any(), any())
- verify(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask, times=1).write_offset_trigger()
- verify(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask, times=1).exit_display()
+ verify(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask,
+ times=1).setup(any(), any(), any())
+ verify(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask,
+ times=1).write_offset_trigger()
+ verify(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask,
+ times=1).exit_display()
verify(bcipy.task.paradigm.rsvp.copy_phrase.RSVPCopyPhraseTask, times=1).wait()
diff --git a/bcipy/tests/test_bci_main.py b/bcipy/tests/test_bci_main.py
index 76edfa019..348cd8332 100644
--- a/bcipy/tests/test_bci_main.py
+++ b/bcipy/tests/test_bci_main.py
@@ -42,7 +42,8 @@ def test_bci_main_fails_without_experiment_or_task(self) -> None:
)
def test_bcipy_main_fails_with_invalid_experiment(self) -> None:
- when(main).validate_bcipy_session(any(), any()).thenRaise(UnregisteredExperimentException)
+ when(main).validate_bcipy_session(any(), any()).thenRaise(
+ UnregisteredExperimentException)
with self.assertRaises(UnregisteredExperimentException):
bci_main(
parameter_location=self.parameters_path,
@@ -54,13 +55,16 @@ def test_bcipy_main_fails_with_invalid_experiment(self) -> None:
)
def test_bci_main_runs_with_valid_experiment(self) -> None:
- when(main).validate_bcipy_session(any(), any()).thenReturn(True) # Mock the validate_bcipy_session function
+ when(main).validate_bcipy_session(any(), any()).thenReturn(
+ True) # Mock the validate_bcipy_session function
when(main).load_json_parameters(
any(), value_cast=any()).thenReturn(
self.parameters) # Mock the load_json_parameters function
when(SessionOrchestrator).get_system_info().thenReturn(None)
- when(SessionOrchestrator)._init_orchestrator_save_folder(any()).thenReturn(None)
- when(SessionOrchestrator)._init_orchestrator_logger(any()).thenReturn(self.logger)
+ when(SessionOrchestrator)._init_orchestrator_save_folder(
+ any()).thenReturn(None)
+ when(SessionOrchestrator)._init_orchestrator_logger(
+ any()).thenReturn(self.logger)
when(SessionOrchestrator).initialize_copy_phrases().thenReturn(None)
when(SessionOrchestrator).add_tasks(any()).thenReturn(None)
when(SessionOrchestrator).execute().thenReturn(None)
@@ -83,10 +87,13 @@ def test_bci_main_runs_with_valid_experiment(self) -> None:
def test_bci_main_runs_with_valid_task(self) -> None:
when(main).validate_bcipy_session(any(), any()).thenReturn(True)
- when(main).load_json_parameters(any(), value_cast=any()).thenReturn(self.parameters)
+ when(main).load_json_parameters(
+ any(), value_cast=any()).thenReturn(self.parameters)
when(SessionOrchestrator).get_system_info().thenReturn(None)
- when(SessionOrchestrator)._init_orchestrator_save_folder(any()).thenReturn(None)
- when(SessionOrchestrator)._init_orchestrator_logger(any()).thenReturn(self.logger)
+ when(SessionOrchestrator)._init_orchestrator_save_folder(
+ any()).thenReturn(None)
+ when(SessionOrchestrator)._init_orchestrator_logger(
+ any()).thenReturn(self.logger)
when(SessionOrchestrator).initialize_copy_phrases().thenReturn(None)
when(SessionOrchestrator).add_tasks(any()).thenReturn(None)
when(SessionOrchestrator).execute().thenReturn(None)
@@ -110,10 +117,13 @@ def test_bci_main_runs_with_valid_task(self) -> None:
def test_bci_main_returns_false_with_orchestrator_execute_exception(self):
when(main).validate_bcipy_session(any(), any()).thenReturn(True)
- when(main).load_json_parameters(any(), value_cast=any()).thenReturn(self.parameters)
+ when(main).load_json_parameters(
+ any(), value_cast=any()).thenReturn(self.parameters)
when(SessionOrchestrator).get_system_info().thenReturn(None)
- when(SessionOrchestrator)._init_orchestrator_save_folder(any()).thenReturn(None)
- when(SessionOrchestrator)._init_orchestrator_logger(any()).thenReturn(self.logger)
+ when(SessionOrchestrator)._init_orchestrator_save_folder(
+ any()).thenReturn(None)
+ when(SessionOrchestrator)._init_orchestrator_logger(
+ any()).thenReturn(self.logger)
when(SessionOrchestrator).initialize_copy_phrases().thenReturn(None)
when(SessionOrchestrator).add_tasks(any()).thenReturn(None)
when(SessionOrchestrator).execute().thenRaise(Exception)
diff --git a/pyproject.toml b/pyproject.toml
index a57c6768b..c7c9cd050 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ authors = [
]
description = "Python Software for Brain-Computer Interface Development."
readme = "README.md"
-requires-python = ">3.7,<3.11"
+requires-python = ">3.8,<3.11"
classifiers = [
'License :: Other/Proprietary License',
'Topic :: Scientific/Engineering :: Human Machine Interfaces',
@@ -18,7 +18,6 @@ classifiers = [
'Intended Audience :: End Users/Desktop',
'Intended Audience :: Developers',
'Programming Language :: Python',
- 'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
]
@@ -70,6 +69,7 @@ dev = [
"coverage>=7.0",
"flake8==5.0.4",
"Flake8-pyproject==1.2.3",
+ "flake8-docstrings==1.7.0",
"mypy==1.13",
"lxml",
"mock",
@@ -81,6 +81,7 @@ dev = [
release = [
"twine==3.2.0",
"build==1.2.2.post1",
+ "pyinstaller==6.13.0",
"wheel==0.43.0",
]
@@ -149,8 +150,16 @@ max_line_length = 120
application_import_names = "bcipy"
ignore = [
+ "D100",
+ "D105",
+ "D107",
+ "D202",
+ "D204",
"D205",
"D400",
+ "D401",
+ "D403",
+ "D412",
"F841",
"F821",
"E402",
@@ -159,6 +168,13 @@ ignore = [
"W503",
"W504",
]
+exclude = [
+ "tests",
+ "__init__.py",
+ "demo",
+ "signal", # TODO: Remove when fixed
+ "gui",
+]
[tool.isort]
skip = [".gitignore"]