diff --git a/README.md b/README.md
index 6bdd1ba..6f4d43f 100644
--- a/README.md
+++ b/README.md
@@ -7,12 +7,18 @@
+
+
+
+
+
+
@@ -32,29 +38,34 @@ This repository is built on PyTorch Lightning and Hydra to enable reproducible a
Looking for ready-to-use weights of models? We host them on Hugging Face:
-### Currently available:
+### Currently available
+
- **FEMBA** ([paper](https://arxiv.org/abs/2502.06438)) [](https://huggingface.co/thorir/FEMBA)
- **LUNA** ([paper](https://arxiv.org/abs/2510.22257)) [](https://huggingface.co/thorir/LUNA)
-
+- **TinyMyo** ([paper](https://arxiv.org/abs/2512.15729)) [](https://huggingface.co/MatteoFasulo/TinyMyo)
#### Why FEMBA?
+
- **Scales to long EEG** with linear-time Mamba (no quadratic attention).
- **Strong results** on TUAB/TUAR/TUSL with ready task-specific checkpoints.
- **Simple fine-tune path:** set `CHECKPOINT_DIR`, run `+experiment=FEMBA_finetune`.
-**➡️ Model hub:** https://huggingface.co/thorir/FEMBA
-**📄 Model card:** [FEMBA on Hugging Face](https://huggingface.co/thorir/FEMBA) — benchmarks, protocols, and efficiency notes.
-**📜 Weights license:** CC BY-ND 4.0 (use + redistribute **unmodified** weights with attribution; no redistribution of **modified** weights)
+**➡️ Model hub:**
+**📄 Model card:** [FEMBA on Hugging Face](https://huggingface.co/thorir/FEMBA) — benchmarks, protocols, and efficiency notes.
+**📜 Weights license:** CC BY-ND 4.0 (use + redistribute **unmodified** weights with attribution; no redistribution of **modified** weights)
**🧑🍳 PR-gated improvements:** If you fine-tune internally and want your variant to become an **official** FEMBA release, open a PR with configs, logs, and evals. We’ll review together; if it looks good, we’ll retrain/validate and publish an **official** FEMBA checkpoint.
**What you’ll find on the hub**
+
- `TUAB/` → abnormal EEG (base/large)
- `TUAR/` → artifact detection (tiny/base/large)
- `TUSL/` → slowing classification (variants as in the paper)
Quick download with `huggingface_hub`:
+
```bash
pip install huggingface_hub
```
+
```python
from huggingface_hub import snapshot_download
@@ -63,6 +74,7 @@ snapshot_download(repo_id="thorir/FEMBA", repo_type="model", local_dir="checkpoi
```
Use the paths directly in your runs, e.g.:
+
```bash
export DATA_PATH=/path/to/data
export CHECKPOINT_DIR=checkpoints/FEMBA/TUAR/base.safetensors
@@ -70,24 +82,28 @@ python -u run_train.py +experiment=FEMBA_finetune
```
#### Why LUNA?
+
- **Topology-agnostic** EEG via **query-based channel unification** (consistent latent across arbitrary montages).
- **Linear-in-channels** compute & memory (unifies channels **before** temporal modeling; no quadratic spatio-temporal attention).
- **Pretrained on >21k hours** (TUEG + Siena) with masked-patch reconstruction; strong transfer across datasets/montages.
- **Simple fine-tune path:** pick model size with `LUNA_{base,large,huge}.yaml`, set `pretrained_safetensors_path`, run `+experiment=LUNA_finetune`.
-**➡️ Model hub:** https://huggingface.co/thorir/LUNA
-**📄 Model card:** [LUNA on Hugging Face](https://huggingface.co/thorir/LUNA) — variants, configs, and fine-tuning walkthrough.
-**📜 Weights license:** CC BY-ND 4.0 (use + redistribute **unmodified** weights with attribution; no redistribution of **modified** weights)
+**➡️ Model hub:**
+**📄 Model card:** [LUNA on Hugging Face](https://huggingface.co/thorir/LUNA) — variants, configs, and fine-tuning walkthrough.
+**📜 Weights license:** CC BY-ND 4.0 (use + redistribute **unmodified** weights with attribution; no redistribution of **modified** weights)
**🧑🍳 PR-gated improvements:** If you fine-tune internally and want your variant to become an **official** LUNA release, open a PR with configs, logs, and evals. We’ll review; if it looks good, we’ll retrain/validate and publish an **official** LUNA checkpoint.
**What you’ll find on the hub**
+
- `Base/`, `Large/`, `Huge/` → LUNA size variants (matching `config/model/LUNA_{base,large,huge}.yaml`)
- Task-specific heads/checkpoints for common TUH downstream tasks (TUAB / TUAR / TUSL)
Quick download with `huggingface_hub`:
+
```bash
pip install huggingface_hub
```
+
```python
from huggingface_hub import snapshot_download
@@ -96,6 +112,7 @@ snapshot_download(repo_id="thorir/LUNA", repo_type="model", local_dir="checkpoin
```
Use the paths directly in your runs like here below:
+
```bash
python -u run_train.py +experiment=LUNA_finetune /model=LUNA_base \
pretrained_safetensors_path=/absolute/path/to/checkpoints/LUNA/Base/LUNA_base.safetensors
@@ -106,28 +123,80 @@ python -u run_train.py +experiment=LUNA_finetune /model=LUNA_large \
python -u run_train.py +experiment=LUNA_finetune /model=LUNA_huge \
pretrained_safetensors_path=/absolute/path/to/checkpoints/LUNA/Huge/LUNA_huge.safetensors
```
+
*If your checkpoint path contains spaces, wrap it in quotes.*
Tips:
-- TUH datasets (TUAB/TUAR/TUSL): keep `- override /data_module: finetune_data_module` and set `data_module.*.hdf5_file` to your `{train,val,test}.h5`.
-- Non-TUH (e.g., SEED-V): use `- override /data_module: subject_independent_data_module` and remove the TUH-specific `data_module` block.
+
+- TUH datasets (TUAB/TUAR/TUSL): keep `- override /data_module: finetune_data_module` and set `data_module.*.hdf5_file` to your `{train,val,test}.h5`.
+- Non-TUH (e.g., SEED-V): use `- override /data_module: subject_independent_data_module` and remove the TUH-specific `data_module` block.
- Match task settings: `classification_type` (`bc`, `mc`, `mmc`, `mcc`) and `model.num_classes` (e.g., TUSL=4, TUAB=2).
+#### Why TinyMyo?
+
+- **Ultra-lightweight**: only 3.6M parameters, suitable for microcontroller deployment.
+- **Broad generalization**: pretrained on multiple large-scale EMG datasets for versatility across tasks and sensor configurations.
+- **Strong results** on surface EMG tasks with ready task-specific checkpoints.
+
+**➡️ Model hub:**
+**📄 Model card:** [TinyMyo on Hugging Face](https://huggingface.co/MatteoFasulo/TinyMyo) — benchmarks, protocols, and efficiency notes.
+**📜 Weights license:** CC BY-ND 4.0 (use + redistribute **unmodified** weights with attribution; no redistribution of **modified** weights)
+**🧑🍳 PR-gated improvements:** If you fine-tune internally and want your variant to become an **official** TinyMyo release, open a PR with configs, logs, and evals. We’ll review together; if it looks good, we’ll retrain/validate and publish an **official** TinyMyo checkpoint.
+
+**What you’ll find on the hub**
+
+- `DB5/` → gesture classification
+- `UCI_EMG/` → gesture classification
+- `EPN612/` → gesture classification
+
+> The scripts to download and preprocess the datasets are available at:
+
+Quick download with `huggingface_hub`:
+
+```bash
+pip install huggingface_hub
+```
+
+```python
+from huggingface_hub import snapshot_download
+
+# downloads all task folders (DB5/UCI_EMG/EPN612) and safetensors into ./checkpoints/TinyMyo
+snapshot_download(repo_id="MatteoFasulo/TinyMyo", repo_type="model", local_dir="checkpoints/TinyMyo")
+```
+
+Use the paths directly in your runs, e.g.:
+
+```bash
+export DATA_PATH=/path/to/data
+export CHECKPOINT_DIR=checkpoints/TinyMyo/UCI_EMG/base.safetensors
+python -u run_train.py +experiment=TinyMyo_finetune \
+ pretrained_safetensors_path=/path/to/model.safetensors
+```
+
+**What you won’t find on the hub**
+
+- **Silent Speech**
+ - Codebase: [MatteoFasulo/silent_speech](https://github.com/MatteoFasulo/silent_speech)
+- **Generic Neuromotor Interface**
+ - Codebase: [MatteoFasulo/generic-neuromotor-interface](https://github.com/MatteoFasulo/generic-neuromotor-interface)
## Features
-* **Modular Design**: The repository is organized into modules for data loading, models, training tasks, and more, making it easy to extend and adapt for new research projects.
-* **Flexible Configuration**: We use [Hydra](https://hydra.cc/docs/intro/) to manage experiment configurations, allowing for easy customization of models, data, and training parameters.
-* **Reproducibility**: Our use of `Hydra` and PyTorch Lightning helps ensure that our experiments are reproducible.
-* **Extensible**: The repository is designed to be easily extended with new datasets, models, and tasks.
+- **Modular Design**: The repository is organized into modules for data loading, models, training tasks, and more, making it easy to extend and adapt for new research projects.
+- **Flexible Configuration**: We use [Hydra](https://hydra.cc/docs/intro/) to manage experiment configurations, allowing for easy customization of models, data, and training parameters.
+- **Reproducibility**: Our use of `Hydra` and PyTorch Lightning helps ensure that our experiments are reproducible.
+- **Extensible**: The repository is designed to be easily extended with new datasets, models, and tasks.
## Installation
+
To use BioFoundation, clone the repository and install the required dependencies.
```bash
git clone https://github.com/pulp-bio/BioFoundation.git
```
+
We recommend using a virtual environment to manage dependencies. You can use `conda` or `virtualenv` for this purpose. We have provided a `requirements.txt` file that lists the necessary packages. You can install them using pip, and optionally, you can use `conda` to create a new environment.
+
```bash
conda create -n BioFoundation
conda activate BioFoundation
@@ -135,6 +204,7 @@ pip install -r requirements.txt
```
### Path changes
+
Throughout the repository, you may find paths that need to be adjusted based on your local setup. For example, the path to the datasets in the configuration files or the scripts that process the datasets. Make sure to update these paths accordingly. They have been named "#CHANGEME" to facilitate finding them.
## Dataset Preparation
@@ -144,13 +214,16 @@ To prepare the TUH EEG datasets (see the [official source](https://isip.piconepr
1. **Download raw data** from the official sources (e.g., TUH EEG corpus).
2. **Preprocess to pickles** (windowing/labels):
+
```bash
# examples (adjust paths)
python make_datasets/process_raw_eeg.py tuab --root_dir /eeg_data/TUAB/edf --output_dir /processed_eeg
python make_datasets/process_raw_eeg.py tusl --root_dir /eeg_data/TUSL/edf --output_dir /processed_eeg
python make_datasets/process_raw_eeg.py tuar --root_dir /eeg_data/TUAR/edf --output_dir /processed_eeg
```
-3. **Bundle into HDF5:**: Use the provided script to process the raw data into HDF5 files.
+
+3. **Bundle into HDF5:**: Use the provided script to process the raw data into HDF5 files.
+
```bash
# all datasets found under /processed_eeg
python make_datasets/make_hdf5.py --prepath /processed_eeg --dataset All --remove_pkl
@@ -158,11 +231,14 @@ To prepare the TUH EEG datasets (see the [official source](https://isip.piconepr
# or a single dataset
python make_datasets/make_hdf5.py --prepath /processed_eeg --dataset TUSL --remove_pkl
```
+
You may need to edit the `prepath` variable in the script to point to the directory where you have downloaded the raw data.
-4. **Update Configs**: so `data_module.*.hdf5_file` points to your `${DATA_PATH}/_data/{train,val,test}.h5`
+4. **Update Configs**: so `data_module.*.hdf5_file` points to your `${DATA_PATH}/_data/{train,val,test}.h5`
## How to Run
+
### Pre-training
+
To run a pre-training experiment, you can use the `run_train.py` script with the appropriate configuration file. For example in the case of pre-training FEMBA:
```bash
@@ -171,6 +247,7 @@ python -u run_train.py +experiment=FEMBA_pretrain
```
### Fine-tuning
+
To run a fine-tuning experiment, you can use the `run_train.py` script with the appropriate configuration file. For example in the case of fine-tuning FEMBA:
```bash
@@ -178,13 +255,14 @@ python -u run_train.py +experiment=FEMBA_finetune
```
-> **Tip:** Pretrained FEMBA weights (TUAB/TUAR/TUSL folders) are available on 🤗 Hugging Face:
-> https://huggingface.co/thorir/FEMBA
+> **Tip:** Pretrained FEMBA weights (TUAB/TUAR/TUSL folders) are available on 🤗 Hugging Face:
+>
> Set `CHECKPOINT_DIR` to the desired `.safetensors` (e.g., `.../TUAR/base.safetensors`) before launching.
Note in both cases one needs to make sure that the dataset that specific experiment is using is downloaded and available in the correct path.
## Repository Structure
+
```
BioFoundation/
├── config # Hydra configuration files
@@ -197,42 +275,61 @@ BioFoundation/
├── tasks # PyTorch Lightning tasks
└── ...
```
+
## Contributing
+
We welcome contributions to BioFoundation! If you have a new model, dataset, or task that you would like to add, please follow the guidelines below.
+
### How to add a new dataset?
+
1. Add the code of the dataset to [`datasets`](datasets/).
2. Add the configuration file of the dataset to [`./config/dataset`](./config/dataset/).
3. If the dataset is large, consider adding a script to download it in the [`./scripts`](./scripts) directory. Make sure to document how to run the script in the README.
+
### How to add a new data module?
+
1. Add the code of the data module to [`./data_module`](./data_module).
2. Add the configuration file of the data module to [`./config/data_module`](./config/data_module).
3. If the data module requires specific datasets, make sure to document how to download and prepare them in the README.
+
### How to add a new loss function?
+
1. Add the code of the loss function to [`./criterion`](./criterion).
2. Add the configuration file of the loss function to [`./config/criterion`](./config/criterion).
+
### How to add a new task?
+
1. Add the code of the task to [`./tasks`](./tasks).
2. Add the configuration file of the task to [`./config/task`](./config/task).
3. If the task requires specific datasets or models, make sure to document how to download and prepare them in the README.
+
### How to add a new scheduler?
+
1. Add the code of the scheduler to [`./schedulers`](./schedulers).
2. Add the configuration file of the scheduler to [`./config/scheduler`](./config/scheduler).
3. If the scheduler requires specific models or tasks, make sure to document how to use it in the README.
+
### How to add a new model?
+
1. Add the code of the model to [`./models`](./models).
2. Add the configuration file of the model to [`./config/model`](./config/model).
+
### How to start a new experiment with the added model?
-1. Add experiment configuration file to [`./config/experiment`](./config/experiment).
+
+1. Add experiment configuration file to [`./config/experiment`](./config/experiment).
If you are interested, you may check the [Hydra document about it](https://hydra.cc/docs/patterns/configuring_experiments/).
2. Override the default configurations in the added experiment configuration file.
3. Run the experiment with the command:
+
```bash
python -u run_train.py +experiment=your_experiment_name
```
### Contributing improvements to FEMBA weights
-We’re excited to see what you build. Because the weights are **CC BY-ND 4.0**, redistribution of **modified** weights (e.g., LoRA/adapters, deltas, pruned or quantized variants) is **not permitted**.
+
+We’re excited to see what you build. Because the weights are **CC BY-ND 4.0**, redistribution of **modified** weights (e.g., LoRA/adapters, deltas, pruned or quantized variants) is **not permitted**.
If you fine-tune internally and believe your results should become an **official** FEMBA release, please open a PR with:
+
- exact **configs**, **seeds**, and **training scripts**,
- **environment** and **hardware** details,
- **evaluation protocol** (TUAB/TUAR/TUSL), **splits**, and full **metrics** (AUROC/AUPR/BA, FLOPs, memory),
@@ -243,7 +340,9 @@ Maintainers will review; if accepted, we will retrain/validate and publish a new
## General Tips
### How to use distributed data parallel?
+
In your experiment configuration file, add the following arguments
+
```yaml
trainer:
accelerator: gpu # Using GPU
@@ -253,6 +352,7 @@ trainer:
```
### How to save GPU memory?
+
1. Try fairscale checkpointing first. Check [here](https://fairscale.readthedocs.io/en/stable/api/nn/checkpoint/checkpoint_activations.html) and [here](https://github.com/ofsoundof/GRL-Image-Restoration/blob/main/models/networks/grl.py#L134)
2. Use sharded training. Check [here](https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html).
@@ -264,16 +364,15 @@ For questions and support, please open an issue on the GitHub repository.
If you find this work useful, please cite the respective papers:
-
```bibtex
@misc{tegon2025fembaefficientscalableeeg,
- title={FEMBA: Efficient and Scalable EEG Analysis with a Bidirectional Mamba Foundation Model},
+ title={FEMBA: Efficient and Scalable EEG Analysis with a Bidirectional Mamba Foundation Model},
author={Anna Tegon and Thorir Mar Ingolfsson and Xiaying Wang and Luca Benini and Yawei Li},
year={2025},
eprint={2502.06438},
archivePrefix={arXiv},
primaryClass={cs.LG},
- url={https://arxiv.org/abs/2502.06438},
+ url={https://arxiv.org/abs/2502.06438},
}
@inproceedings{doner2025luna,
title={{LUNA}: Efficient and Topology-Agnostic Foundation Model for {EEG} Signal Analysis},
@@ -282,11 +381,20 @@ If you find this work useful, please cite the respective papers:
year={2025},
url={https://openreview.net/forum?id=uazfjnFL0G}
}
+
+@misc{fasulo2025tinymyotinyfoundationmodel,
+ title={TinyMyo: a Tiny Foundation Model for Flexible EMG Signal Processing at the Edge},
+ author={Matteo Fasulo and Giusy Spacone and Thorir Mar Ingolfsson and Yawei Li and Luca Benini and Andrea Cossettini},
+ year={2025},
+ eprint={2512.15729},
+ archivePrefix={arXiv},
+ primaryClass={eess.SP},
+ url={https://arxiv.org/abs/2512.15729},
+}
```
## License
-This project is licensed under the Apache License 2.0. See the [LICENSE](./LICENSE) file for details.
+This project is licensed under the Apache License 2.0. See the [LICENSE](./LICENSE) file for details.
-**Note on model weights:** Pretrained weights are hosted at https://huggingface.co/thorir/FEMBA and https://huggingface.co/thorir/LUNA and licensed under **CC BY-ND 4.0**. You may use and redistribute the **unmodified** weights with attribution. Redistribution of **modified** weights is not permitted. To upstream improvements, please open a PR; accepted changes will be released as **official** checkpoints.
-
+**Note on model weights:** Pretrained weights are hosted at , , and and licensed under **CC BY-ND 4.0**. You may use and redistribute the **unmodified** weights with attribution. Redistribution of **modified** weights is not permitted. To upstream improvements, please open a PR; accepted changes will be released as **official** checkpoints.
diff --git a/config/data_module/emg_finetune_data_module.yaml b/config/data_module/emg_finetune_data_module.yaml
new file mode 100644
index 0000000..c881b50
--- /dev/null
+++ b/config/data_module/emg_finetune_data_module.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2025 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Matteo Fasulo *
+#*----------------------------------------------------------------------------*
+data_module:
+ _target_: data_module.finetune_data_module.FinetuneDataModule
+ name: "emg"
+ cfg:
+ num_workers: ${num_workers}
+ batch_size: ${batch_size}
+ train:
+ _target_: 'datasets.emg_finetune_dataset.EMGDataset'
+ hdf5_file: ${env:DATA_PATH}/UCI_EMG/EMG_data_for_gestures-master/h5/train.h5
+ finetune: true
+ val:
+ _target_: 'datasets.emg_finetune_dataset.EMGDataset'
+ hdf5_file: ${env:DATA_PATH}/UCI_EMG/EMG_data_for_gestures-master/h5/val.h5
+ finetune: true
+ test:
+ _target_: 'datasets.emg_finetune_dataset.EMGDataset'
+ hdf5_file: ${env:DATA_PATH}/UCI_EMG/EMG_data_for_gestures-master/h5/test.h5
+ finetune: true
diff --git a/config/data_module/emg_pretrain_data_module.yaml b/config/data_module/emg_pretrain_data_module.yaml
new file mode 100644
index 0000000..55b293f
--- /dev/null
+++ b/config/data_module/emg_pretrain_data_module.yaml
@@ -0,0 +1,40 @@
+# @package _global_
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2025 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Matteo Fasulo *
+#*----------------------------------------------------------------------------*
+data_module:
+ _target_: 'data_module.pretrain_data_module.PretrainDataModule'
+ name: "emg"
+ cfg:
+ num_workers: ${num_workers}
+ batch_size: ${batch_size}
+ test: null
+ train_val_split_ratio: 0.8
+ datasets:
+ demo_dataset: null
+ emg2pose:
+ _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
+ data_dir: "${env:DATA_PATH}/emg2pose/h5/"
+ db6:
+ _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
+ data_dir: "${env:DATA_PATH}/ninapro/DB6/h5/"
+ pad_up_to_max_chans: 16
+ db7:
+ _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
+ data_dir: "${env:DATA_PATH}/ninapro/DB7/h5/"
+ pad_up_to_max_chans: 16
diff --git a/config/experiment/TinyMyo_finetune.yaml b/config/experiment/TinyMyo_finetune.yaml
new file mode 100644
index 0000000..9532c5f
--- /dev/null
+++ b/config/experiment/TinyMyo_finetune.yaml
@@ -0,0 +1,100 @@
+# @package _global_
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2025 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Matteo Fasulo *
+#*----------------------------------------------------------------------------*
+tag: EMG_finetune
+
+gpus: -1
+num_nodes: 1
+num_workers: 8
+batch_size: 32
+max_epochs: 50
+
+training: True
+final_validate: True
+final_test: True
+finetune_pretrained: True
+resume: False
+
+layerwise_lr_decay: 0.90
+scheduler_type: cosine
+
+pretrained_checkpoint_path: null
+pretrained_safetensors_path: null
+
+finetuning:
+ freeze_layers: False
+
+io:
+ base_output_path: ${env:DATA_PATH}
+ checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints
+ version: 0
+
+defaults:
+ - override /data_module: emg_finetune_data_module
+ - override /model: TinyMyo_finetune
+ - override /scheduler: cosine
+ - override /task: finetune_task_TinyMyo
+ - override /criterion: finetune_criterion
+
+masking:
+ patch_size: [1, 20]
+ masking_ratio: 0.50
+ unmasked_loss_coeff: 0.1
+
+input_normalization:
+ normalize: False
+
+model:
+ num_classes: 6
+ classification_type: "ml"
+
+trainer:
+ accelerator: gpu
+ num_nodes: ${num_nodes}
+ devices: ${gpus}
+ strategy: auto
+ max_epochs: ${max_epochs}
+
+model_checkpoint:
+ save_last: True
+ monitor: "val_loss"
+ mode: "min"
+ save_top_k: 1
+
+callbacks:
+ early_stopping:
+ _target_: 'pytorch_lightning.callbacks.EarlyStopping'
+ monitor: "val_loss"
+ patience: 7
+ mode: "min"
+ verbose: True
+
+optimizer:
+ optim: 'AdamW'
+ lr: 5e-4
+ betas: [0.9, 0.98]
+ weight_decay: 0.01
+
+scheduler:
+ trainer: ${trainer}
+ min_lr: 1e-5
+ warmup_lr_init: 1e-5
+ warmup_epochs: 5
+ total_training_opt_steps: ${max_epochs}
+ t_in_epochs: True
diff --git a/config/experiment/TinyMyo_pretrain.yaml b/config/experiment/TinyMyo_pretrain.yaml
new file mode 100644
index 0000000..3485edd
--- /dev/null
+++ b/config/experiment/TinyMyo_pretrain.yaml
@@ -0,0 +1,79 @@
+# @package _global_
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2025 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Matteo Fasulo *
+#*----------------------------------------------------------------------------*
+tag: EMG_pretrain
+
+gpus: -1
+num_nodes: 1
+num_workers: 8
+batch_size: 128
+max_epochs: 50
+
+final_validate: True
+final_test: False
+
+pretrained_checkpoint_path: null
+io:
+ base_output_path: ${env:DATA_PATH}
+ checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints
+ version: 0
+
+defaults:
+ - override /data_module: emg_pretrain_data_module
+ - override /model: TinyMyo_pretrain
+ - override /scheduler: cosine
+ - override /task: pretrain_task_TinyMyo
+ - override /criterion: pretrain_criterion
+
+masking:
+ patch_size: [1, 20]
+ masking_ratio: 0.50
+ unmasked_loss_coeff: 0.1
+
+input_normalization:
+ normalize: True
+
+scheduler:
+ trainer: ${trainer}
+ min_lr: 1e-6
+ warmup_lr_init: 1e-6
+ warmup_epochs: 10
+ total_training_opt_steps: ${max_epochs}
+ t_in_epochs: True
+
+trainer:
+ accelerator: gpu
+ num_nodes: ${num_nodes}
+ devices: ${gpus}
+ strategy: auto
+ max_epochs: ${max_epochs}
+ gradient_clip_val: 3
+ accumulate_grad_batches: 8
+
+model_checkpoint:
+ save_last: True
+ monitor: "val_loss"
+ mode: "min"
+ save_top_k: 1
+
+optimizer:
+ optim: 'AdamW'
+ lr: 1e-4
+ betas: [0.9, 0.98]
+ weight_decay: 0.01
diff --git a/config/model/TinyMyo_finetune.yaml b/config/model/TinyMyo_finetune.yaml
new file mode 100644
index 0000000..045cf2f
--- /dev/null
+++ b/config/model/TinyMyo_finetune.yaml
@@ -0,0 +1,33 @@
+# @package _global_
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2025 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Matteo Fasulo *
+#*----------------------------------------------------------------------------*
+model:
+ _target_: models.TinyMyo.TinyMyo
+ img_size: 1000
+ patch_size: 20
+ in_chans: 16
+ embed_dim: 192
+ n_layer: 8
+ n_head: 3
+ mlp_ratio: 4
+ qkv_bias: True
+ attn_drop: 0.1
+ proj_drop: 0.1
+ drop_path: 0.1
+ num_classes: ${num_classes}
diff --git a/config/model/TinyMyo_pretrain.yaml b/config/model/TinyMyo_pretrain.yaml
new file mode 100644
index 0000000..eb489d5
--- /dev/null
+++ b/config/model/TinyMyo_pretrain.yaml
@@ -0,0 +1,33 @@
+# @package _global_
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2025 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Matteo Fasulo *
+#*----------------------------------------------------------------------------*
+model:
+ _target_: models.TinyMyo.TinyMyo
+ img_size: 1000
+ patch_size: 20
+ in_chans: 16
+ embed_dim: 192
+ n_layer: 8
+ n_head: 3
+ mlp_ratio: 4
+ qkv_bias: True
+ attn_drop: 0.1
+ proj_drop: 0.1
+ drop_path: 0.1
+ num_classes: 0
diff --git a/config/task/finetune_task_TinyMyo.yaml b/config/task/finetune_task_TinyMyo.yaml
new file mode 100644
index 0000000..c414998
--- /dev/null
+++ b/config/task/finetune_task_TinyMyo.yaml
@@ -0,0 +1,21 @@
+# @package _global_
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2025 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Matteo Fasulo *
+#*----------------------------------------------------------------------------*
+task:
+ _target_: 'tasks.finetune_task_EMG.FinetuneTask'
diff --git a/config/task/pretrain_task_TinyMyo.yaml b/config/task/pretrain_task_TinyMyo.yaml
new file mode 100644
index 0000000..fe52f1a
--- /dev/null
+++ b/config/task/pretrain_task_TinyMyo.yaml
@@ -0,0 +1,21 @@
+# @package _global_
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2025 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Matteo Fasulo *
+#*----------------------------------------------------------------------------*
+task:
+ _target_: 'tasks.pretrain_task_EMG.MaskTask'
diff --git a/data_module/finetune_data_module.py b/data_module/finetune_data_module.py
index 82777f9..a3d25d5 100644
--- a/data_module/finetune_data_module.py
+++ b/data_module/finetune_data_module.py
@@ -1,22 +1,22 @@
-#*----------------------------------------------------------------------------*
-#* Copyright (C) 2025 ETH Zurich, Switzerland *
-#* SPDX-License-Identifier: Apache-2.0 *
-#* *
-#* Licensed under the Apache License, Version 2.0 (the "License"); *
-#* you may not use this file except in compliance with the License. *
-#* You may obtain a copy of the License at *
-#* *
-#* http://www.apache.org/licenses/LICENSE-2.0 *
-#* *
-#* Unless required by applicable law or agreed to in writing, software *
-#* distributed under the License is distributed on an "AS IS" BASIS, *
-#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
-#* See the License for the specific language governing permissions and *
-#* limitations under the License. *
-#* *
-#* Author: Anna Tegon *
-#* Author: Thorir Mar Ingolfsson *
-#*----------------------------------------------------------------------------*
+# *----------------------------------------------------------------------------*
+# * Copyright (C) 2025 ETH Zurich, Switzerland *
+# * SPDX-License-Identifier: Apache-2.0 *
+# * *
+# * Licensed under the Apache License, Version 2.0 (the "License"); *
+# * you may not use this file except in compliance with the License. *
+# * You may obtain a copy of the License at *
+# * *
+# * http://www.apache.org/licenses/LICENSE-2.0 *
+# * *
+# * Unless required by applicable law or agreed to in writing, software *
+# * distributed under the License is distributed on an "AS IS" BASIS, *
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+# * See the License for the specific language governing permissions and *
+# * limitations under the License. *
+# * *
+# * Author: Anna Tegon *
+# * Author: Thorir Mar Ingolfsson *
+# *----------------------------------------------------------------------------*
from typing import Optional
import pytorch_lightning as pl
@@ -76,7 +76,7 @@ def train_dataloader(self):
batch_size=self.cfg.batch_size,
shuffle=True,
num_workers=self.cfg.num_workers,
- drop_last=True, # Drop last incomplete batch to keep batch sizes consistent
+ drop_last=True, # Drop last incomplete batch to keep batch sizes consistent
pin_memory=True,
)
@@ -105,7 +105,7 @@ def test_dataloader(self):
drop_last=False,
pin_memory=True,
)
-
+
def predict_dataloader(self):
"""
Returns the DataLoader for prediction with shuffling disabled.
diff --git a/data_module/pretrain_data_module.py b/data_module/pretrain_data_module.py
index 2680577..a54e551 100644
--- a/data_module/pretrain_data_module.py
+++ b/data_module/pretrain_data_module.py
@@ -1,30 +1,28 @@
-#*----------------------------------------------------------------------------*
-#* Copyright (C) 2025 ETH Zurich, Switzerland *
-#* SPDX-License-Identifier: Apache-2.0 *
-#* *
-#* Licensed under the Apache License, Version 2.0 (the "License"); *
-#* you may not use this file except in compliance with the License. *
-#* You may obtain a copy of the License at *
-#* *
-#* http://www.apache.org/licenses/LICENSE-2.0 *
-#* *
-#* Unless required by applicable law or agreed to in writing, software *
-#* distributed under the License is distributed on an "AS IS" BASIS, *
-#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
-#* See the License for the specific language governing permissions and *
-#* limitations under the License. *
-#* *
-#* Author: Anna Tegon *
-#* Author: Thorir Mar Ingolfsson *
-#*----------------------------------------------------------------------------*
+# *----------------------------------------------------------------------------*
+# * Copyright (C) 2025 ETH Zurich, Switzerland *
+# * SPDX-License-Identifier: Apache-2.0 *
+# * *
+# * Licensed under the Apache License, Version 2.0 (the "License"); *
+# * you may not use this file except in compliance with the License. *
+# * You may obtain a copy of the License at *
+# * *
+# * http://www.apache.org/licenses/LICENSE-2.0 *
+# * *
+# * Unless required by applicable law or agreed to in writing, software *
+# * distributed under the License is distributed on an "AS IS" BASIS, *
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+# * See the License for the specific language governing permissions and *
+# * limitations under the License. *
+# * *
+# * Author: Anna Tegon *
+# * Author: Thorir Mar Ingolfsson *
+# *----------------------------------------------------------------------------*
from typing import Optional
+
import pytorch_lightning as pl
-from torch.utils.data import (
- DataLoader,
- ConcatDataset,
- Dataset,
-)
import torch
+from torch.utils.data import ConcatDataset, DataLoader
+
class PretrainDataModule(pl.LightningDataModule):
"""
@@ -44,12 +42,12 @@ class PretrainDataModule(pl.LightningDataModule):
def __init__(
self,
- datasets: [torch.utils.data.Dataset],
+ datasets: [torch.utils.data.Dataset],
test=None,
cfg=None,
name="",
train_val_split_ratio=0.8,
- **kwargs
+ **kwargs,
):
super().__init__()
@@ -82,7 +80,7 @@ def __init__(
self.name = name
self.cfg = cfg
self.batch_size = self.cfg.batch_size
-
+
print(len(self.train), len(self.val))
def setup(self, stage: Optional[str] = None):
@@ -122,9 +120,10 @@ def train_dataloader(self):
self.train_dataset,
num_workers=self.cfg.num_workers,
pin_memory=True,
+ persistent_workers=True,
shuffle=True,
batch_size=self.batch_size,
- drop_last=True
+ drop_last=True,
)
def val_dataloader(self):
@@ -145,7 +144,8 @@ def val_dataloader(self):
self.val_dataset,
num_workers=self.cfg.num_workers,
pin_memory=True,
+ persistent_workers=True,
shuffle=False,
batch_size=self.batch_size,
- drop_last=True
+ drop_last=True,
)
diff --git a/datasets/emg_finetune_dataset.py b/datasets/emg_finetune_dataset.py
new file mode 100644
index 0000000..e6ad10f
--- /dev/null
+++ b/datasets/emg_finetune_dataset.py
@@ -0,0 +1,145 @@
+# *----------------------------------------------------------------------------*
+# * Copyright (C) 2025 ETH Zurich, Switzerland *
+# * SPDX-License-Identifier: Apache-2.0 *
+# * *
+# * Licensed under the Apache License, Version 2.0 (the "License"); *
+# * you may not use this file except in compliance with the License. *
+# * You may obtain a copy of the License at *
+# * *
+# * http://www.apache.org/licenses/LICENSE-2.0 *
+# * *
+# * Unless required by applicable law or agreed to in writing, software *
+# * distributed under the License is distributed on an "AS IS" BASIS, *
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+# * See the License for the specific language governing permissions and *
+# * limitations under the License. *
+# * *
+# * Author: Matteo Fasulo *
+# *----------------------------------------------------------------------------*
+from collections import deque
+from typing import Tuple, Union
+
+import h5py
+import torch
+
+
+class EMGDataset(torch.utils.data.Dataset):
+ """
+ A PyTorch Dataset class for loading EMG (Electromyography) data from HDF5 files.
+ This dataset supports lazy loading of data from HDF5 files, with optional caching
+ to improve performance during training. It can be used for both fine-tuning (with labels)
+ and inference (without labels) modes. The class handles data preprocessing, such as
+ converting to tensors and optional unsqueezing.
+ Attributes:
+ hdf5_file (str): Path to the HDF5 file containing the dataset.
+ unsqueeze (bool): Whether to add an extra dimension to the input data (default: False).
+ finetune (bool): If True, loads both data and labels; if False, loads only data (default: True).
+ cache_size (int): Maximum number of samples to cache in memory (default: 1500).
+ use_cache (bool): Whether to use caching for faster access (default: True).
+ regression (bool): If True, treats labels as regression targets (float); else, classification (long) (default: False).
+ num_samples (int): Total number of samples in the dataset, determined from HDF5 file.
+ data (h5py.File or None): Handle to the opened HDF5 file (lazy-loaded).
+ X_ds (h5py.Dataset or None): Dataset handle for input data.
+ Y_ds (h5py.Dataset or None): Dataset handle for labels (if finetune is True).
+ cache (dict): Dictionary for caching data items (if use_cache is True).
+ cache_queue (deque): Queue to track the order of cached items for LRU eviction.
+ Note:
+ - The HDF5 file is expected to have 'data' and 'label' datasets.
+ - Caching uses an LRU (Least Recently Used) eviction policy.
+ - Suitable for use with PyTorch DataLoader for batched loading.
+ """
+
+ def __init__(
+ self,
+ hdf5_file: str,
+ unsqueeze: bool = False,
+ finetune: bool = True,
+ cache_size: int = 1500,
+ use_cache: bool = True,
+ regression: bool = False,
+ ):
+ self.hdf5_file = hdf5_file
+ self.unsqueeze = unsqueeze
+ self.cache_size = cache_size
+ self.finetune = finetune
+ self.use_cache = use_cache
+ self.regression = regression
+
+ self.data = None
+ self.X_ds = None
+ self.Y_ds = None
+
+ # Open once to get length, then close immediately
+ with h5py.File(self.hdf5_file, "r") as f:
+ self.num_samples = f["data"].shape[0]
+
+ if self.use_cache:
+ self.cache: dict[int, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
+ self.cache_queue = deque()
+
+ def _open_file(self) -> None:
+ # 'rdcc_nbytes' to increase the raw data chunk cache size
+ self.data = h5py.File(self.hdf5_file, "r", rdcc_nbytes=1024 * 1024 * 4)
+ if self.data is not None:
+ self.X_ds = self.data["data"]
+ self.Y_ds = self.data["label"]
+
+ def __len__(self) -> int:
+ return self.num_samples
+
+ def __getitem__(self, index):
+ # Check Cache
+ if self.use_cache and index in self.cache:
+ return self._process_data(self.cache[index])
+
+ # Open file (Lazy Loading for Multiprocessing)
+ if self.data is None:
+ self._open_file()
+
+ # Read Data, HDF5 slicing returns numpy array
+ X_np = self.X_ds[index]
+ X = torch.from_numpy(X_np).float()
+
+ if self.finetune:
+ Y_np = self.Y_ds[index]
+ if self.regression:
+ Y = torch.from_numpy(Y_np).float()
+ else:
+ # Ensure scalar is converted properly
+ Y = torch.tensor(Y_np, dtype=torch.long)
+
+ data_item = (X, Y)
+ else:
+ data_item = X
+
+ # Update Cache
+ if self.use_cache:
+ # If cache is full, remove oldest item from dict AND queue
+ if len(self.cache) >= self.cache_size:
+ oldest_index = self.cache_queue.popleft()
+ del self.cache[oldest_index]
+
+ self.cache[index] = data_item
+ self.cache_queue.append(index)
+
+ return self._process_data(data_item)
+
+ def _process_data(self, data_item):
+ """Helper to handle squeezing/returning uniformly."""
+ if self.finetune:
+ X, Y = data_item
+ else:
+ X = data_item
+ Y = None
+
+ if self.unsqueeze:
+ X = X.unsqueeze(0)
+
+ if self.finetune:
+ return X, Y
+ else:
+ return X
+
+ def __del__(self):
+ if self.data is not None:
+ self.data.close()
diff --git a/datasets/emg_pretrain_dataset.py b/datasets/emg_pretrain_dataset.py
new file mode 100644
index 0000000..15a3248
--- /dev/null
+++ b/datasets/emg_pretrain_dataset.py
@@ -0,0 +1,136 @@
+# *----------------------------------------------------------------------------*
+# * Copyright (C) 2025 ETH Zurich, Switzerland *
+# * SPDX-License-Identifier: Apache-2.0 *
+# * *
+# * Licensed under the Apache License, Version 2.0 (the "License"); *
+# * you may not use this file except in compliance with the License. *
+# * You may obtain a copy of the License at *
+# * *
+# * http://www.apache.org/licenses/LICENSE-2.0 *
+# * *
+# * Unless required by applicable law or agreed to in writing, software *
+# * distributed under the License is distributed on an "AS IS" BASIS, *
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+# * See the License for the specific language governing permissions and *
+# * limitations under the License. *
+# * *
+# * Author: Matteo Fasulo *
+# *----------------------------------------------------------------------------*
+import os
+import threading
+from collections import deque
+from typing import Optional
+
+import h5py
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+
+# thread-local storage for per-worker file handle
+_thread_local = threading.local()
+
+
+def _get_h5_handle(path):
+ h5f = getattr(_thread_local, "h5f", None)
+ if h5f is None or h5f.filename != path:
+ h5f = h5py.File(path, "r")
+ _thread_local.h5f = h5f
+ return h5f
+
+
+class EMGPretrainDataset(Dataset):
+ """
+ A PyTorch Dataset class for loading EMG (electromyography) data from HDF5 files for pretraining purposes.
+ This dataset discovers all .h5 files in the specified directory, builds an index of samples across all files,
+ and provides access to individual samples. It supports optional caching to improve performance, channel padding,
+ and squeezing of the data tensor.
+ Args:
+ data_dir (str): Path to the directory containing .h5 files.
+ squeeze (bool, optional): Whether to squeeze the data tensor. Defaults to False.
+ cache_size (int, optional): Size of the cache. Defaults to 1500.
+ use_cache (bool, optional): Enable caching. Defaults to True.
+ pad_up_to_max_chans (int | None, optional): Number of channels to pad to. Defaults to None.
+ max_samples (int | None, optional): Limit the total number of samples. Defaults to None.
+ Raises:
+ RuntimeError: If no .h5 files are found in the data directory.
+ Note:
+ The .h5 files are expected to have a 'data' dataset with shape (N, C, T), where N is the number of samples,
+ C is the number of channels, and T is the number of time points.
+ Caching uses a simple LRU mechanism with a deque to track access order.
+ The __del__ method ensures that any open HDF5 file handles are closed upon deletion.
+ """
+
+ def __init__(
+ self,
+ data_dir: str,
+ squeeze: bool = False,
+ cache_size: int = 1500,
+ use_cache: bool = True,
+ pad_up_to_max_chans: Optional[int] = None,
+ max_samples: Optional[int] = None,
+ ):
+ super().__init__()
+ self.squeeze = squeeze
+ self.cache_size = cache_size
+ self.use_cache = use_cache
+ self.pad_up_to_max_chans = pad_up_to_max_chans
+
+ # discover all .h5 files
+ self.file_paths = sorted(os.path.join(data_dir, fn) for fn in os.listdir(data_dir) if fn.endswith(".h5"))
+ if not self.file_paths:
+ raise RuntimeError(f"No .h5 files in {data_dir!r}")
+
+ # build index of (file_path, sample_idx)
+ self.index_map = []
+ for fp in self.file_paths:
+ with h5py.File(fp, "r") as h5f:
+ n = h5f["data"].shape[0]
+ for i in range(n):
+ self.index_map.append((fp, i))
+ if max_samples is not None:
+ self.index_map = self.index_map[:max_samples]
+
+ # Cache to store recently accessed samples
+ if use_cache:
+ self.cache = {}
+ self.cache_queue = deque(maxlen=self.cache_size)
+
+ def __len__(self):
+ return len(self.index_map)
+
+ def __getitem__(self, index):
+ if self.use_cache and index in self.cache:
+ cached_data = self.cache[index]
+ X = cached_data
+ else:
+ fp, local_idx = self.index_map[index]
+ h5f = _get_h5_handle(fp)
+ np_x = h5f["data"][local_idx] # shape (C, T)
+ X = torch.from_numpy(np_x).float()
+
+ if self.use_cache:
+ # If cache is full, remove oldest item from dict AND queue
+ if len(self.cache) >= self.cache_size:
+ oldest_index = self.cache_queue.popleft()
+ del self.cache[oldest_index]
+
+ self.cache[index] = X
+ self.cache_queue.append(index)
+
+ # squeeze if requested
+ if self.squeeze:
+ X = X.unsqueeze(0)
+
+ # pad channels if requested
+ if self.pad_up_to_max_chans is not None:
+ C = X.shape[0]
+ to_pad = self.pad_up_to_max_chans - C
+ if to_pad > 0:
+ X = F.pad(X, (0, 0, 0, to_pad)) # (channels, time) -> pad channels
+
+ return X
+
+ def __del__(self):
+ h5f = getattr(_thread_local, "h5f", None)
+ if h5f is not None:
+ h5f.close()
diff --git a/docs/model/TinyMyo.md b/docs/model/TinyMyo.md
new file mode 100644
index 0000000..35a332d
--- /dev/null
+++ b/docs/model/TinyMyo.md
@@ -0,0 +1,249 @@
+# **TinyMyo: A Tiny Foundation Model for EMG Signals**
+
+**TinyMyo** is a lightweight **3.6M-parameter** Transformer-based foundation model (FM) for **surface EMG (sEMG)**. It is designed for **broad generalization** across datasets, sensor configurations, domains, and tasks, while remaining efficient enough for **ultra-low-power edge deployment** on microcontrollers.
+
+TinyMyo is the **first EMG foundation model** demonstrated on a microcontroller (GAP9), achieving an inference time of **0.785 s**, energy of **44.91 mJ**and power envelope of **57.18 mW**.
+
+---
+
+## **1. Default Input Assumptions**
+
+Unless otherwise specified, TinyMyo uses:
+
+* **Channels**: 16
+* **Sampling Rate**: 2000 Hz
+* **Segment Length**: 1000 samples (0.5 s)
+* **Windowing**: 50% overlap during pretraining
+* **Preprocessing**:
+
+ * 4th-order **20–450 Hz bandpass**
+ * **Notch filter** at 50 Hz
+ * Per-channel min–max normalization (pretraining)
+ * Per-channel z-score normalization (downstream)
+
+Datasets with fewer than 16 channels are *zero-padded* only during pretraining.
+
+---
+
+## **2. Pretraining Overview**
+
+TinyMyo is pretrained using **masked reconstruction** across three heterogeneous large-scale EMG datasets:
+
+| Dataset | Subjects | fs | Channels | Size |
+| ----------- | -------- | ------- | -------- | ------- |
+| Ninapro DB6 | 10 | 2000 Hz | 14 | 20.3 GB |
+| Ninapro DB7 | 22 | 2000 Hz | 12 | 30.9 GB |
+| EMG2Pose | 192 | 2000 Hz | 16 | 431 GB |
+
+### **Tokenization: Channel-Independent Patches**
+
+Unlike 2D (channel-mixing) tokenizers in EEG FMs, TinyMyo uses **strictly per-channel patching**:
+
+* Patch length: **20 samples**
+* Patch stride: **20 samples**
+* Tokens per channel: **50**
+* Sequence length: **800 tokens** (16 x 50)
+* Positional encoding: **RoPE** (Rotary Position Embeddings)
+
+This preserves electrode-specific information while letting attention learn cross-channel relationships.
+
+### **Transformer Encoder**
+
+* **8 layers**
+* **3 heads**
+* Embedding dim: **192**
+* Pre-LayerNorm
+* Dropout & drop-path: **0.1**
+
+### **Lightweight Decoder**
+
+A simple **linear layer** (≈ **3.9k params**) reconstructs masked patches.
+Following SimMIM philosophy, the minimal decoder forces the encoder to learn structured latent representations.
+
+### **Masking Objective**
+
+* **50% random masking** with a learnable [MASK] token
+* Reconstruction loss = **Smooth L1**
+
+$$
+ \mathcal{L} = \mathcal{L}*{\text{masked}} + 0.1 \cdot \mathcal{L}*{\text{visible}}
+$$
+
+### **Training Setup**
+
+* Optimizer: **AdamW** (β=(0.9, 0.98), wd=0.01)
+* LR: **1x10⁻⁴**, cosine decay
+* Batch size: **512** with gradient accumulation
+* Epochs: **50** with 10-epoch warm-up
+* Hardware: **4x NVIDIA GH200 GPUs**
+
+---
+
+## **3. Architecture Summary**
+
+### **Model Variant**
+
+| Variant | Params | (Layers, dim) |
+| ----------- | -------- | ------------- |
+| **TinyMyo** | **3.6M** | (8, 192) |
+
+### **Pipeline**
+
+**Pretraining**
+
+```
+EMG -> Channel-indep. patching -> Masking -> Transformer Encoder -> Linear Decoder -> Patch reconstruction
+```
+
+**Downstream**
+
+```
+EMG -> Patching -> Transformer Encoder -> Channel fusion -> Temporal pooling -> Task-specific head
+```
+
+---
+
+## **4. Downstream Tasks**
+
+TinyMyo supports three major categories:
+
+---
+
+### **4.1 Hand Gesture Classification**
+
+Evaluated on:
+
+* **Ninapro DB5** (52 classes, 10 subjects)
+* **EPN-612** (5 classes, 612 subjects)
+* **UCI EMG** (6 classes, 36 subjects)
+* **Generic Neuromotor Interface** (Meta wristband; 9 gestures)
+ * Repository: [MatteoFasulo/generic-neuromotor-interface](https://github.com/MatteoFasulo/generic-neuromotor-interface)
+
+>Note: Additional details on generic non-invasive neuromotor interface dataset and instructions on how to run experiments can be found in the linked repository inside the `notebooks` folder.
+
+**Pipeline**
+
+* EMG filtering: **20–90 Hz** bandpass + 50 Hz notch
+* Windows:
+
+ * **200 ms** (best for DB5)
+ * **1000 ms** (best for EPN & UCI)
+* Per-channel z-scoring
+* Linear classification head
+
+ * Input: **C x 192**
+ * Params: typically **<40k**
+
+**Performance (Fine-tuned)**
+
+| Dataset | Metric | Result |
+| ------------------------ | -------- | ----------------- |
+| **Ninapro DB5 (200 ms)** | Accuracy | **89.41 ± 0.16%** |
+| **EPN-612 (1000 ms)** | Accuracy | **96.74 ± 0.09%** |
+| **UCI EMG (1000 ms)** | Accuracy | **97.56 ± 0.32%** |
+| **Neuromotor Interface** | CLER | **0.153 ± 0.006** |
+
+TinyMyo achieves **state-of-the-art** on DB5, EPN-612, and UCI.
+
+---
+
+### **4.2 Hand Kinematic Regression**
+
+Dataset: **Ninapro DB8**
+Task: Regress **5 joint angles (DoA)**
+Preprocessing: z-score only; windows of **200 ms** or **1000 ms**
+
+**Regression head (788k params)**
+
+* Depthwise + pointwise convolutions
+* Upsampling
+* Global average pooling
+* Linear projection to 5 outputs
+
+**Performance (Fine-tuned)**
+
+* **MAE = 8.77 ± 0.12°** (1000 ms window)
+
+Although previous works achieve lower MAE (≈6.89°), those models are **subject-specific**, whereas TinyMyo trains **one model across all subjects**, a significantly harder problem.
+
+---
+
+### **4.3 Speech Production & Speech Recognition**
+
+Dataset: **Gaddy Silent Speech**
+(8 channels, 1000 Hz, face/neck EMG)
+Repository: [MatteoFasulo/silent_speech](https://github.com/MatteoFasulo/silent_speech)
+>Note: Additional details on Silent Speech dataset and instructions on how to run experiments can be found in the linked repository.
+
+#### **Speech Production (EMG -> MFCC -> HiFi-GAN -> Audio)**
+
+Pipeline:
+
+1. Residual downsampling blocks
+2. TinyMyo encoder
+3. Linear projection to **26-dim MFCC**
+4. HiFi-GAN vocoder (pretrained)
+
+**WER (Fine-tuned):**
+
+* **33.54 ± 1.12%**
+
+Comparable to SoA (≈32%) with **>90% fewer parameters** in the transduction model.
+
+#### **Speech Recognition (EMG -> Text)**
+
+* Same encoder + residual front-end
+* Linear projection to 37 characters
+* **CTC loss**
+* 4-gram LM + beam search
+
+**WER:**
+
+* **33.95 ± 0.97%**
+
+Although not surpassing the multimodal MONA-LISA (12.2%), TinyMyo is vastly smaller and EMG-only.
+
+---
+
+## **5. Edge Deployment**
+
+TinyMyo is deployed on **GAP9 (RISC-V, ultra-low power)**.
+
+Key elements:
+
+* **INT8 quantization**, including attention
+* Hierarchical streaming:
+
+ * L3 -> L2 (slab streaming)
+ * L2 -> L1 (tile streaming)
+* Integer softmax, integer LayerNorm, integer GELU
+* Static liveness-based memory arena
+
+**Runtime (NinaPro EPN612 pipeline):**
+
+* **0.785 s inference time**
+* **44.91 mJ energy**
+* **57.18 mW average power**
+
+This is the **first demonstration of an EMG FM on a microcontroller**.
+
+---
+
+## **6. Results Summary**
+
+### **Pretraining**
+
+* Smooth L1 reconstruction with high fidelity
+* Total FLOPs: ~4.0G
+
+### **Downstream SoA Highlights**
+
+* **DB5:** 89.41%
+* **EPN-612:** 96.74%
+* **UCI EMG:** 97.56%
+* **Neuromotor:** 0.153 CLER
+* **DB8 Regression:** MAE 8.77°
+* **Speech Production:** WER 33.54%
+* **Speech Recognition:** WER 33.95%
+
+Overall TinyMyo matches or exceeds state-of-the-art while being on par with or smaller than prior EMG foundation models.
diff --git a/docs/model/logo/TinyMyo_logo.png b/docs/model/logo/TinyMyo_logo.png
new file mode 100644
index 0000000..6cff2f0
Binary files /dev/null and b/docs/model/logo/TinyMyo_logo.png differ
diff --git a/models/README.md b/models/README.md
index 434cead..f56b85c 100644
--- a/models/README.md
+++ b/models/README.md
@@ -1,7 +1,11 @@
Copyright (C) 2025 ETH Zurich, Switzerland. SPDX-License-Identifier: Apache-2.0. See LICENSE file at the root of the repository for details.
# Models
+
This directory contains the implementations of the deep learning models used in the **BioFoundation** project. Each model is defined as a PyTorch `nn.Module` and is designed to be configurable and extensible for various research tasks.
+
## Available Models
+
- **FEMBA**: A lightweight EEG model designed for both pretraining and fine-tuning tasks. For a more detailed description of the model check the [documentation](../docs/model/FEMBA.md).
-- **LUNA**: An efficient EEG model specifically designed for handling different types of electrode configurations. For a more detailed description of the model check the [documentation](../docs/model/LUNA.md).
\ No newline at end of file
+- **LUNA**: An efficient EEG model specifically designed for handling different types of electrode configurations. For a more detailed description of the model check the [documentation](../docs/model/LUNA.md).
+- **TinyMyo**: A 3.6M-parameter Transformer-based foundation model for surface EMG (sEMG). It is pretrained on >480 GB of EMG data and optimized for ultra-low-power, real-time deployment, including microcontrollers (GAP9) where it achieves an inference time of 0.785 s, energy of 44.91 mJ and power envelope of 57.18 mW. For a more detailed description of the model check the [documentation](../docs/model/TinyMyo.md).
diff --git a/models/TinyMyo.py b/models/TinyMyo.py
new file mode 100644
index 0000000..5e15687
--- /dev/null
+++ b/models/TinyMyo.py
@@ -0,0 +1,764 @@
+import math
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from timm.layers.drop import DropPath
+from timm.layers.mlp import Mlp
+from timm.layers.weight_init import trunc_normal_ as __call_trunc_normal_
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0):
+ __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
+
+
+# https://docs.pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html#RotaryPositionalEmbeddings
+@dataclass(eq=False)
+class RotaryPositionalEmbeddings(nn.Module):
+ """
+ This class implements Rotary Positional Embeddings (RoPE)
+ proposed in https://arxiv.org/abs/2104.09864.
+
+ Reference implementation (used for correctness verification)
+ can be found here:
+ https://github.com/meta-llama/llama/blob/main/llama/model.py#L80
+
+ In this implementation we cache the embeddings for each position upto
+ ``max_seq_len`` by computing this during init.
+
+ Args:
+ dim (int): Embedding dimension. This is usually set to the dim of each
+ head in the attention module computed as ``embed_dim // num_heads``
+ max_seq_len (int): Maximum expected sequence length for the
+ model, if exceeded the cached freqs will be recomputed
+ base (int): The base for the geometric progression used to compute
+ the rotation angles
+ """
+
+ dim: int
+ max_seq_len: int = 4096
+ base: int = 10_000
+
+ def __post_init__(self):
+ super().__init__()
+ self.rope_init()
+
+ def rope_init(self):
+ """Initialize the RoPE embeddings and cache."""
+ # ensure dim is int
+ dim = int(self.dim)
+ theta = 1.0 / (
+ self.base
+ ** (torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim)
+ )
+ self.register_buffer("theta", theta, persistent=False)
+ self.build_rope_cache(self.max_seq_len)
+
+ def build_rope_cache(self, max_seq_len: int = 4096) -> None:
+ """Build the RoPE cache for positions up to max_seq_len."""
+ # Create position indexes `[0, 1, ..., max_seq_len - 1]`
+ seq_idx = torch.arange(0, max_seq_len, dtype=torch.float32) # type: ignore
+
+ # Outer product of theta and position index; output tensor has
+ # a shape of [max_seq_len, dim // 2]
+ idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()
+
+ # cache includes both the cos and sin components and so the output shape is
+ # [max_seq_len, dim // 2, 2]
+ cache: torch.Tensor = torch.stack(
+ [torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1
+ )
+ self.register_buffer("cache", cache, persistent=False)
+
+ def forward(
+ self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): input tensor with shape
+ ``[b, s, n_h, h_d]``
+ input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
+ of each token. During training, this is used to indicate the positions
+ of each token relative to its sample when packed, shape [b, s].
+ During inference, this indicates the position of the current token.
+ If none, assume the index of the token is its position id. Default is None.
+
+ Returns:
+ torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]``
+
+ Notation used for tensor shapes:
+ - b: batch size
+ - s: sequence length
+ - n_h: num heads
+ - h_d: head dim
+ """
+ # input tensor has shape [b, s, n_h, h_d]
+ seq_len = x.size(1)
+
+ # extract the values based on whether input_pos is set or not
+ if input_pos is None:
+ rope_cache = self.cache[:seq_len] # type: ignore
+ else:
+ input_pos = input_pos.to(torch.long)
+ rope_cache = self.cache[input_pos] # type: ignore
+
+ # reshape input; the last dimension is used for computing the output.
+ # Cast to float to match the reference implementation
+ # tensor has shape [b, s, n_h, h_d // 2, 2]
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+
+ # reshape the cache for broadcasting
+ # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples,
+ # otherwise has shape [1, s, 1, h_d // 2, 2]
+ rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
+
+ # tensor has shape [b, s, n_h, h_d // 2, 2]
+ x_out = torch.stack(
+ [
+ xshaped[..., 0] * rope_cache[..., 0]
+ - xshaped[..., 1] * rope_cache[..., 1],
+ xshaped[..., 1] * rope_cache[..., 0]
+ + xshaped[..., 0] * rope_cache[..., 1],
+ ],
+ -1,
+ )
+
+ # tensor has shape [b, s, n_h, h_d]
+ x_out = x_out.flatten(3)
+ return x_out.type_as(x)
+
+
+@dataclass(eq=False)
+class PatchEmbedWaveformKeepChans(nn.Module):
+ """
+ Patch embedding layer for waveform data that keeps channel information.
+
+ This module embeds patches from waveform inputs while preserving the channel dimension.
+ It uses a 2D convolution to project patches into an embedding space, and rearranges
+ the output to flatten patches across channels and time.
+
+ Args:
+ img_size (int): The size of the input waveform in the time dimension.
+ patch_size (int): The size of each patch in the time dimension.
+ in_chans (int): Number of input channels.
+ embed_dim (int): Dimensionality of the embedding space.
+ """
+
+ img_size: int
+ patch_size: int
+ in_chans: int
+ embed_dim: int
+
+ def __post_init__(self):
+ super().__init__()
+ self.num_patches = (self.img_size // self.patch_size) * self.in_chans
+ self.proj = nn.Conv2d(
+ 1,
+ self.embed_dim,
+ kernel_size=(1, self.patch_size),
+ stride=(1, self.patch_size),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: Input waveform tensor of shape (B, C, T)
+ Returns:
+ Patch embeddings of shape (B, N, D) where N is number of patches and D is embed_dim
+ """
+ # shared projection layer across channels
+ x = self.proj(x.unsqueeze(1))
+ x = rearrange(x, "B D C t -> B (C t) D")
+ return x
+
+
+@dataclass(eq=False)
+class PatchingModule(nn.Module):
+ """Image to Patch Embedding of choice according to the parameters given."""
+
+ img_size: int
+ patch_size: int
+ in_chans: int
+ embed_dim: int
+
+ def __post_init__(self):
+ super().__init__()
+ self.patch_embed = PatchEmbedWaveformKeepChans(
+ self.img_size, self.patch_size, self.in_chans, self.embed_dim
+ )
+ self.num_patches = self.patch_embed.num_patches
+ self.init_patch_embed()
+
+ def init_patch_embed(self):
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+ w = self.patch_embed.proj.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.patch_embed(x)
+
+
+@dataclass(eq=False)
+class RotarySelfAttentionBlock(nn.Module):
+ """
+ A self-attention block that incorporates rotary positional embeddings (RoPE) for enhanced positional awareness.
+
+ This module implements multi-head self-attention with rotary positional embeddings applied to query and key tensors,
+ followed by scaled dot-product attention. It is designed for transformer-based architectures, particularly in vision
+ or sequence modeling tasks where positional information is crucial.
+
+ Attributes:
+ dim (int): The dimensionality of the input and output features.
+ num_heads (int): The number of attention heads.
+ rope (RotaryPositionalEmbeddings): The rotary positional embedding module.
+ scale (float): The scaling factor for attention logits.
+ qkv (nn.Linear): Linear layer for projecting input to query, key, and value.
+ attn_drop_fn (nn.Dropout): Dropout layer for attention weights.
+ proj (nn.Linear): Linear layer for projecting attention output.
+ proj_drop (nn.Dropout): Dropout layer for projection output.
+ """
+
+ dim: int
+ num_heads: int = 8
+ qkv_bias: bool = False
+ attn_drop: float = 0.0
+ proj_drop: float = 0.0
+
+ def __post_init__(self):
+ super().__init__()
+ head_dim = self.dim // self.num_heads
+ self.rope = RotaryPositionalEmbeddings(
+ dim=head_dim,
+ max_seq_len=1024,
+ base=10_000,
+ )
+ self.qkv = nn.Linear(self.dim, self.dim * 3, bias=self.qkv_bias)
+ self.attn_drop_fn = nn.Dropout(self.attn_drop)
+ self.proj = nn.Linear(self.dim, self.dim)
+ self.p_drop = nn.Dropout(self.proj_drop)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass for rotary self-attention block."""
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ ) # (K, B, H, N, D)
+ q, k, v = qkv.unbind(0) # each: (B, H, N, D)
+
+ q = self.rope(q)
+ k = self.rope(k)
+
+ # pylint: disable=not-callable
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_drop if self.training else 0.0,
+ is_causal=False,
+ enable_gqa=False,
+ )
+
+ x = x.transpose(2, 1).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.p_drop(x)
+ return x
+
+
+@dataclass(eq=False)
+class RotaryTransformerBlock(nn.Module):
+ """
+ A transformer block that incorporates rotary self-attention for enhanced positional encoding.
+
+ This block applies layer normalization, rotary self-attention, and a multi-layer perceptron (MLP)
+ in sequence, with drop paths for regularization. It follows a standard transformer
+ architecture but uses rotary embeddings to improve handling of sequential data.
+
+ Args:
+ dim (int): The dimensionality of the input and output features.
+ num_heads (int): The number of attention heads in the self-attention mechanism.
+ mlp_ratio (float, optional): The ratio of hidden features in the MLP to the input dimension. Defaults to 4.0.
+ qkv_bias (bool, optional): Whether to include bias terms in the query, key, and value projections. Defaults to False.
+ drop (float, optional): Dropout rate for the MLP and attention projections. Defaults to 0.0.
+ attn_drop (float, optional): Dropout rate specifically for the attention weights. Defaults to 0.0.
+ drop_path (float, optional): Drop path rate for stochastic depth regularization. Defaults to 0.0.
+ """
+
+ dim: int
+ num_heads: int
+ mlp_ratio: float = 4.0
+ qkv_bias: bool = False
+ drop: float = 0.0
+ attn_drop: float = 0.0
+ drop_path: float = 0.0
+ norm_layer = nn.LayerNorm
+
+ def __post_init__(self):
+ super().__init__()
+ self.norm1 = self.norm_layer(self.dim)
+ self.attn = RotarySelfAttentionBlock(
+ dim=self.dim,
+ num_heads=self.num_heads,
+ qkv_bias=self.qkv_bias,
+ attn_drop=self.attn_drop,
+ proj_drop=self.drop,
+ )
+ self.drop_path1 = (
+ DropPath(self.drop_path) if self.drop_path > 0.0 else nn.Identity()
+ )
+ self.drop_path2 = (
+ DropPath(self.drop_path) if self.drop_path > 0.0 else nn.Identity()
+ )
+ self.norm2 = self.norm_layer(self.dim)
+ self.mlp = Mlp(
+ in_features=self.dim,
+ hidden_features=int(self.dim * self.mlp_ratio),
+ act_layer=nn.GELU,
+ drop=self.drop,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x + self.drop_path1(self.attn(self.norm1(x)))
+ x = x + self.drop_path2(self.mlp(self.norm2(x)))
+ return x
+
+
+@dataclass(eq=False)
+class PatchReconstructionHead(nn.Module):
+ """
+ A neural network module for reconstructing image patches from token embeddings.
+
+ This head takes token embeddings as input and projects them back to the pixel space
+ for patch reconstruction. It is designed for use in vision transformer models where
+ patches are embedded and then reconstructed.
+
+ img_size (int): The size of the input image.
+ patch_size (int): The size of each patch.
+ in_chans (int): Number of input channels.
+ embed_dim (int): Dimensionality of the embedding space.
+
+ Attributes:
+ in_chans (int): Number of input channels.
+ img_size (int): The size of the input image.
+ patch_size (int): The size of each patch.
+ embed_dim (int): Dimensionality of the embedding space.
+ reconstruction_shape (int): Shape of the reconstructed patch, equal to patch_size.
+ decoder_pred (nn.Linear): Linear layer for projecting embeddings to pixel space.
+ """
+
+ img_size: int
+ patch_size: int
+ in_chans: int
+ embed_dim: int
+
+ def __post_init__(self):
+ super().__init__()
+ self.reconstruction_shape = self.patch_size
+
+ # Projection from embed space to pixel space
+ self.decoder_pred = nn.Linear(self.embed_dim, self.reconstruction_shape)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ No cls token is expected
+ Args:
+ x: [B, num_tokens, embed_dim] - token embeddings
+ """
+ x = self.decoder_pred(x)
+ return x
+
+
+@dataclass(eq=False)
+class EMGClassificationHead(nn.Module):
+ """
+ A classification head for EMG (Electromyography) data processing, designed to classify token embeddings into a specified number of classes.
+
+ This module takes token embeddings as input, applies a reduction strategy (either mean or concatenation across channels),
+ averages across patches, and then uses a linear classifier to produce logits for classification.
+
+ embed_dim (int): Dimensionality of the token embeddings.
+ num_classes (int): Number of output classes for classification.
+ in_chans (int): Number of input channels (e.g., EMG channels).
+ reduction (str): Reduction strategy for combining channel features. Options are "mean" or "concat".
+ - "mean": Averages across channels, resulting in feature dimension of embed_dim.
+ - "concat": Concatenates across channels, resulting in feature dimension of in_chans * embed_dim. Defaults to "concat".
+
+ Attributes:
+ classifier (nn.Linear): Linear layer for final classification, mapping from reduced feature dimension to num_classes.
+ """
+
+ embed_dim: int
+ num_classes: int
+ in_chans: int
+ reduction: Literal["mean", "concat"] = "concat"
+
+ def __post_init__(self):
+ super().__init__()
+ # after reduction, feature_dim to either embed_dim or in_chans*embed_dim
+ feat_dim = (
+ self.embed_dim
+ if self.reduction == "mean"
+ else self.in_chans * self.embed_dim
+ )
+
+ self.classifier = nn.Linear(feat_dim, self.num_classes)
+
+ # init weights
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m: nn.Module):
+ if isinstance(m, nn.Linear):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: token embeddings, shape (B, num_tokens, embed_dim)
+ Returns:
+ logits: (B, num_classes)
+ """
+ _, N, _ = x.shape
+ num_patches = N // self.in_chans
+
+ if self.reduction == "mean":
+ # Reshape to (B, in_chans, num_patches, embed_dim)
+ x = rearrange(x, "b (c p) d -> b c p d", c=self.in_chans, p=num_patches)
+ # Take mean across the channels (in_chans)
+ x = x.mean(dim=1) # (B, num_patches, embed_dim)
+ elif self.reduction == "concat":
+ # Reshape to (B, num_patches, embed_dim * in_chans)
+ x = rearrange(x, "b (c p) d -> b p (c d)", c=self.in_chans, p=num_patches)
+ else:
+ raise ValueError(f"Unknown reduction type: {self.reduction}")
+
+ # average across patches
+ x = x.mean(dim=1) # (B, feat_dim)
+
+ # apply projection to get logits
+ logits = self.classifier(x)
+ return logits
+
+
+@dataclass(eq=False)
+class EMGRegressionHead(nn.Module):
+ """
+ A regression head for EMG (Electromyography) signals using convolutional layers.
+
+ This module processes embedded features from a transformer model to perform
+ regression, predicting output signals of a specified dimension and length. It supports
+ different reduction methods for combining channel and patch features, followed by
+ convolutional layers for regression, and upsampling to a target sequence length.
+
+ Args:
+ in_chans (int): Number of input channels (e.g., EMG channels).
+ embed_dim (int): Dimension of the input embeddings.
+ output_dim (int): Dimension of the output regression targets.
+ reduction (str): Method to reduce features across channels.
+ "mean" averages embeddings, "concat" concatenates them. Defaults to "concat".
+ hidden_dim (int): Hidden dimension for the convolutional layers. Defaults to 256.
+ dropout (float): Dropout probability applied after the first convolution. Defaults to 0.1.
+ target_length (int): Desired length of the output sequence. If the input length differs,
+ linear interpolation is used to upsample. Defaults to 500.
+
+ Attributes:
+ in_chans (int): Number of input channels.
+ embed_dim (int): Dimension of the embeddings.
+ output_dim (int): Dimension of the output.
+ reduction (str): Reduction method used.
+ dropout (float): Dropout rate.
+ target_length (int): Target output sequence length.
+ """
+
+ in_chans: int
+ embed_dim: int
+ output_dim: int
+ reduction: Literal["mean", "concat"] = "concat"
+ hidden_dim: int = 256
+ dropout: float = 0.1
+ target_length: int = 500
+
+ def __post_init__(self):
+ super().__init__()
+ feat_dim = (
+ self.embed_dim
+ if self.reduction == "mean"
+ else self.in_chans * self.embed_dim
+ )
+
+ self.regressor = nn.Sequential(
+ nn.Conv1d(feat_dim, self.hidden_dim, kernel_size=1),
+ nn.SiLU(),
+ nn.Dropout(self.dropout),
+ # depthwise 3x3 conv: groups=hidden_dim to hidden_dimx3 params
+ nn.Conv1d(
+ self.hidden_dim,
+ self.hidden_dim,
+ kernel_size=3,
+ padding=1,
+ groups=self.hidden_dim,
+ ),
+ nn.SiLU(),
+ # pointwise 1x1 back to output_dim
+ nn.Conv1d(self.hidden_dim, self.output_dim, kernel_size=1),
+ )
+
+ # Initialize weights
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Conv1d):
+ nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the model.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape (B, num_tokens, token_dim), where B is batch size,
+ num_tokens is the number of tokens, and token_dim is the token dimension.
+
+ Returns:
+ torch.Tensor: Output tensor of shape (B, target_length, output_dim), where target_length
+ is the target sequence length and output_dim is the output dimension.
+ """
+ # x: (B, num_tokens, token_dim)
+ if self.reduction == "mean":
+ x = rearrange(x, "b (c p) d -> b p d", c=self.in_chans)
+ elif self.reduction == "concat":
+ x = rearrange(x, "b (c p) d -> b p (c d)", c=self.in_chans)
+ else:
+ raise ValueError(f"Unknown reduction type: {self.reduction}")
+
+ # conv head expects (B, C, L)
+ x = x.transpose(1, 2) # (B, feat_dim, num_patches)
+ x = self.regressor(x) # (B, output_dim, num_patches)
+
+ # now upsample to target length
+ if x.size(-1) != self.target_length:
+ x = F.interpolate(
+ x, size=self.target_length, mode="linear", align_corners=False
+ )
+
+ # x: (B, output_dim, target_length)
+ out = x.transpose(1, 2) # (B, target_length, output_dim)
+ return out
+
+
+@dataclass(eq=False)
+class TinyMyo(nn.Module):
+ """
+ TinyMyo is a bidirectional Transformer model based on the Vision Transformer (ViT) architecture, adapted for electromyography (EMG) signal processing.
+ It supports multiple tasks including pretraining (reconstruction), classification, and regression.
+
+ The model uses a patch-based embedding approach to process input signals, followed by a series of transformer blocks with rotary position embeddings.
+ It includes a masking token for pretraining tasks and different heads for various downstream tasks.
+
+ Args:
+ img_size (int, optional): The size of the input signal (temporal dimension). Defaults to 1000.
+ patch_size (int, optional): The size of each patch for embedding. Defaults to 20.
+ in_chans (int, optional): Number of input channels (e.g., EMG channels). Defaults to 16.
+ embed_dim (int, optional): Dimensionality of the embedding space. Defaults to 192.
+ n_layer (int, optional): Number of transformer layers. Defaults to 8.
+ n_head (int, optional): Number of attention heads. Defaults to 3.
+ mlp_ratio (int, optional): Ratio for expanding the MLP hidden dimension. Defaults to 4.
+ qkv_bias (bool, optional): Whether to include bias in QKV projections. Defaults to True.
+ attn_drop (float, optional): Dropout rate for attention. Defaults to 0.1.
+ proj_drop (float, optional): Dropout rate for projections. Defaults to 0.1.
+ drop_path (float, optional): Stochastic depth drop path rate. Defaults to 0.1.
+ norm_layer (nn.Module, optional): Normalization layer class. Defaults to nn.LayerNorm.
+ task (str, optional): Task type, one of "pretraining", "classification", or "regression". Defaults to "classification".
+ classification_type (str, optional): Type of classification (e.g., "ml" for multi-label). Defaults to "ml".
+ reduction_type (str, optional): Type of reduction to apply, either "mean" or "concat". Defaults to "concat".
+ num_classes (int, optional): Number of classes for classification or output dimension for regression. Defaults to 53.
+ reg_target_len (int, optional): Target length for regression output. Defaults to 500.
+ """
+
+ img_size: int = 1000
+ patch_size: int = 20
+ in_chans: int = 16
+ embed_dim: int = 192
+ n_layer: int = 8
+ n_head: int = 3
+ mlp_ratio: int = 4
+ qkv_bias: bool = True
+ attn_drop: float = 0.1
+ proj_drop: float = 0.1
+ drop_path: float = 0.1
+ norm_layer = nn.LayerNorm
+ task: Literal["pretraining", "classification", "regression"] = "classification"
+ classification_type: Literal["ml", "mc"] = "ml"
+ reduction_type: Literal["concat", "mean"] = "concat"
+ num_classes: int = 53
+ reg_target_len: int = 500
+
+ def __post_init__(self):
+ super().__init__()
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+ self.patch_embedding = PatchingModule(
+ img_size=self.img_size,
+ patch_size=self.patch_size,
+ in_chans=self.in_chans,
+ embed_dim=self.embed_dim,
+ )
+ self.num_patches = self.patch_embedding.num_patches
+
+ self.blocks = nn.ModuleList(
+ [
+ RotaryTransformerBlock(
+ dim=self.embed_dim,
+ num_heads=self.n_head,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ drop=self.proj_drop,
+ attn_drop=self.attn_drop,
+ drop_path=self.drop_path,
+ )
+ for _ in range(self.n_layer)
+ ]
+ )
+ self.norm = self.norm_layer(self.embed_dim)
+
+ if (
+ self.task == "pretraining" or self.num_classes == 0
+ ): # reconstruction (pre-training)
+ self.model_head = PatchReconstructionHead(
+ img_size=self.img_size,
+ patch_size=self.patch_size,
+ in_chans=self.in_chans,
+ embed_dim=self.embed_dim,
+ )
+ elif self.task == "classification" and self.num_classes > 0:
+ self.model_head = EMGClassificationHead(
+ embed_dim=self.embed_dim,
+ num_classes=self.num_classes,
+ in_chans=self.in_chans,
+ reduction=self.reduction_type,
+ )
+ elif self.task == "regression":
+ self.model_head = EMGRegressionHead(
+ in_chans=self.in_chans,
+ embed_dim=self.embed_dim,
+ output_dim=self.num_classes,
+ reduction=self.reduction_type,
+ target_length=self.reg_target_len,
+ )
+ else:
+ raise ValueError(f"Unknown task type {self.task}")
+ self.initialize_weights()
+
+ # Some checks
+ assert (
+ self.img_size % self.patch_size == 0
+ ), f"img_size ({self.img_size}) must be divisible by patch_size ({self.patch_size})"
+
+ def initialize_weights(self):
+ """Initializes the model weights."""
+ # Encodings Initializations code taken from the LaBraM paper
+ trunc_normal_(self.mask_token, std=0.02)
+
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ def _init_weights(self, m):
+ """Initializes the model weights."""
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def fix_init_weight(self):
+ """
+ Rescales the weights of attention and MLP layers to improve training stability.
+
+ For each layer, weights are divided by sqrt(2 * layer_id).
+ """
+
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks, start=1):
+ attn_proj = getattr(getattr(layer, "attn", None), "proj", None)
+ if attn_proj is not None:
+ rescale(attn_proj.weight.data, layer_id)
+
+ mlp_fc2 = getattr(getattr(layer, "mlp", None), "fc2", None)
+ if mlp_fc2 is not None:
+ rescale(mlp_fc2.weight.data, layer_id)
+
+ def prepare_tokens(
+ self, x_signal: torch.Tensor, mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Prepares input tokens by embedding patches and applying masking if provided.
+ Args:
+ x_signal (torch.Tensor): Input signal tensor of shape (B, C, T).
+ mask (Optional[torch.Tensor]): Optional mask tensor of shape (B, C, T) indicating which patches to mask.
+ Returns:
+ torch.Tensor: Prepared token embeddings of shape (B, N, D) where N is number of patches and D is embed_dim.
+ """
+ x_patched = self.patch_embedding(x_signal) # [B, N, D]
+ x_masked = x_patched.clone() # (B, N, D), N = C * num_patches_per_channel
+ if mask is not None:
+ mask_tokens = self.mask_token.repeat(
+ x_masked.shape[0], x_masked.shape[1], 1
+ ) # (B, N, D) N = C * num_patches_per_channel
+ mask = rearrange(
+ mask, "B C (S P) -> B (C S) P", P=self.patch_size
+ ) # (B, C, T) -> (B, N, P)
+ mask = (
+ (mask.sum(dim=-1) > 0).unsqueeze(-1).float()
+ ) # (B, N, 1), since a patch is either fully masked or not
+ x_masked = torch.where(mask.bool(), mask_tokens, x_masked)
+ return x_masked
+
+ def forward(
+ self, x_signal: torch.Tensor, mask: Optional[torch.BoolTensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Forward pass of the TinyMyo model.
+
+ This method processes the input signal tensor through the transformer blocks,
+ applies normalization, and then either reconstructs the signal or performs
+ classification/regression based on the model's configuration.
+
+ Args:
+ x_signal (torch.Tensor): The input signal tensor of shape [B, C, T],
+ where B is batch size, C is number of channels, and T is the temporal dimension.
+ mask (Optional[torch.BoolTensor]): Optional boolean mask tensor for
+ masking certain tokens during processing. If None, no masking is applied.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - If num_classes == 0 (reconstruction mode): The reconstructed signal
+ tensor of shape [B, N, patch_size] and the original input signal tensor.
+ - Otherwise (classification/regression mode): The output tensor of shape
+ [B, Out] and the original input signal tensor.
+ """
+ x_original = x_signal.clone()
+ x = self.prepare_tokens(x_signal, mask=mask)
+
+ # forward pass through transformer blocks
+ for blk in self.blocks:
+ x = blk(x)
+ x_latent = self.norm(x) # [B, N, D]
+
+ if self.num_classes == 0: # reconstruction
+ x_reconstructed = self.model_head(x_latent) # [B, N, patch_size]
+ return x_reconstructed, x_original
+
+ else: # classification or regression
+ x_out = self.model_head(x_latent) # [B, Out]
+ return x_out, x_original
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..e8d315a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,51 @@
+[project]
+name = "BioFoundation"
+version = "0.1.0"
+description = "BioFoundation is a flexible and extensible codebase for deep learning with biological signals."
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+ "einops",
+ "fvcore",
+ "omegaconf",
+ "hydra-core",
+ "scipy",
+ "h5py",
+ "python-dateutil",
+ "lightning",
+ "tqdm",
+ "pandas",
+ "mne",
+ "psutil",
+ "pyyaml",
+ "tensorboardX",
+ "torch>=2.4.0, <2.8.0",
+ "scikit-learn",
+ "tensorboard",
+ "tensorrt",
+ "torch_optimizer",
+ "tables",
+ "timm",
+ "torchaudio",
+ "torcheeg",
+ "torchmetrics",
+ "warmup-scheduler",
+ "safetensors",
+ "rotary-embedding-torch",
+]
+
+[[tool.uv.index]]
+name = "pytorch-cu126"
+url = "https://download.pytorch.org/whl/cu126"
+explicit = true
+
+[tool.uv.sources]
+warmup-scheduler = { git = "https://github.com/ildoonet/pytorch-gradual-warmup-lr.git" }
+
+[tool.black]
+line-length = 88
+target-version = ["py311"]
+
+[tool.isort]
+profile = "black"
+multi_line_output = 3
diff --git a/run_train.py b/run_train.py
index 4be7371..0431cfc 100644
--- a/run_train.py
+++ b/run_train.py
@@ -1,58 +1,68 @@
-#*----------------------------------------------------------------------------*
-#* Copyright (C) 2025 ETH Zurich, Switzerland *
-#* SPDX-License-Identifier: Apache-2.0 *
-#* *
-#* Licensed under the Apache License, Version 2.0 (the "License"); *
-#* you may not use this file except in compliance with the License. *
-#* You may obtain a copy of the License at *
-#* *
-#* http://www.apache.org/licenses/LICENSE-2.0 *
-#* *
-#* Unless required by applicable law or agreed to in writing, software *
-#* distributed under the License is distributed on an "AS IS" BASIS, *
-#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
-#* See the License for the specific language governing permissions and *
-#* limitations under the License. *
-#* *
-#* Author: Thorir Mar Ingolfsson *
-#* Author: Anna Tegon *
-#* Author: Berkay Döner *
-#*----------------------------------------------------------------------------*
-import os
+# *----------------------------------------------------------------------------*
+# * Copyright (C) 2025 ETH Zurich, Switzerland *
+# * SPDX-License-Identifier: Apache-2.0 *
+# * *
+# * Licensed under the Apache License, Version 2.0 (the "License"); *
+# * you may not use this file except in compliance with the License. *
+# * You may obtain a copy of the License at *
+# * *
+# * http://www.apache.org/licenses/LICENSE-2.0 *
+# * *
+# * Unless required by applicable law or agreed to in writing, software *
+# * distributed under the License is distributed on an "AS IS" BASIS, *
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+# * See the License for the specific language governing permissions and *
+# * limitations under the License. *
+# * *
+# * Author: Thorir Mar Ingolfsson *
+# * Author: Anna Tegon *
+# * Author: Berkay Döner *
+# * Author: Matteo Fasulo *
+# *----------------------------------------------------------------------------*
import logging
+import os
import os.path as osp
+from datetime import datetime
from logging import Logger
+
import hydra
import pytorch_lightning as pl
+import torch
+import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
-from util.train_utils import find_last_checkpoint_path
-from pytorch_lightning import Trainer
-from pytorch_lightning.strategies import DDPStrategy
+from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning import seed_everything
-from datetime import datetime
+from pytorch_lightning.strategies import DDPStrategy
+from util.train_utils import find_last_checkpoint_path
-os.environ['DATA_PATH'] = "#CHANGEME"
-os.environ['CHECKPOINT_DIR'] = '#CHANGEME'
+for env_var in ["DATA_PATH", "CHECKPOINT_DIR"]:
+ env_var_value = os.environ.get(env_var)
+ if env_var_value is None or env_var_value == "#CHANGEME":
+ raise RuntimeError(f"Environment variable {env_var} is not set. Please set it before running the script.")
OmegaConf.register_new_resolver("env", lambda key: os.getenv(key))
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
logger: Logger = logging.getLogger(__name__)
+# Set float32 matmul precision to high for better performance on supported hardware
+torch.set_float32_matmul_precision("high")
+
+
def train(cfg: DictConfig):
seed_everything(cfg.seed)
-
- date_format = "%d_%m_%H-%M"
+ date_format = "%d_%m_%H-%M"
# Create your version_name
- version= f"{cfg.tag}_{datetime.now().strftime(date_format)}"
-
+ version = f"{cfg.tag}_{datetime.now().strftime(date_format)}"
+
# tensorboard
- tb_logger = TensorBoardLogger(save_dir=osp.expanduser(cfg.io.base_output_path), name=cfg.tag, version=version)
+ tb_logger = TensorBoardLogger(
+ save_dir=osp.expanduser(cfg.io.base_output_path), name=cfg.tag, version=version
+ )
# DataLoader
print("===> Loading datasets")
@@ -63,21 +73,21 @@ def train(cfg: DictConfig):
model = hydra.utils.instantiate(cfg.task, cfg)
print(model)
- safetensors_path = cfg.get('pretrained_safetensors_path', None)
- checkpoint_path = cfg.get('pretrained_checkpoint_path', None)
+ safetensors_path = cfg.get("pretrained_safetensors_path", None)
+ checkpoint_path = cfg.get("pretrained_checkpoint_path", None)
# Load pretrained checkpoint
if safetensors_path is not None:
print(f"===> Loading pretrained safetensors from {safetensors_path}")
# Assuming your model has this method
- model.load_safetensors_checkpoint(safetensors_path)
+ model.load_safetensors_checkpoint(safetensors_path)
elif checkpoint_path is not None:
print(f"===> Loading pretrained checkpoint from {checkpoint_path}")
model.load_pretrained_checkpoint(checkpoint_path)
else:
- print("No pretrained checkpoint provided. Proceeding without loading.")
+ print("No pretrained checkpoint provided. Proceeding without loading.")
- # New Checkpoint dipath
+ # New Checkpoint dirpath
checkpoint_dirpath = cfg.io.checkpoint_dirpath
checkpoint_dirpath = osp.join(checkpoint_dirpath, cfg.tag, version)
print(f"Checkpoint path: {checkpoint_dirpath}")
@@ -86,7 +96,9 @@ def train(cfg: DictConfig):
last_ckpt = find_last_checkpoint_path(checkpoint_dirpath)
print(f"last_ckpt_{last_ckpt}")
print("===> Checkpoint callbacks")
- model_checkpoint = ModelCheckpoint(dirpath=checkpoint_dirpath, **cfg.model_checkpoint)
+ model_checkpoint = ModelCheckpoint(
+ dirpath=checkpoint_dirpath, **cfg.model_checkpoint
+ )
model_summary = pl.callbacks.ModelSummary(max_depth=4)
callbacks = [model_checkpoint, model_summary]
@@ -113,26 +125,64 @@ def train(cfg: DictConfig):
)
# Train the model
+ results: dict = {}
if cfg.training:
print("===> Start training")
trainer.fit(model, data_module, ckpt_path=last_ckpt)
best_ckpt = model_checkpoint.best_model_path
-
+
print(f"Best checkpoint path: {best_ckpt}")
print(f"Best model score: {model_checkpoint.best_model_score}")
if cfg.final_validate:
print("===> Start validation")
- trainer.validate(model, data_module,ckpt_path=best_ckpt)
+ trainer.validate(model, data_module, ckpt_path=best_ckpt)
+
if cfg.final_test:
- print("===> Start testing")
- trainer.test(model, data_module, ckpt_path=last_ckpt)
+ # rank 0 only
+ # Validate and test run on 1 device only (i.e. no distributed data parallelism)
+ # This is to ensure reproducibility of metrics reported.
+
+ del data_module, trainer
+ print("Destroying process group...")
+ if dist.is_initialized():
+ dist.destroy_process_group()
+ print("Destroyed process group.")
+
+ if pl.utilities.rank_zero_only.rank == 0:
+ print("Re-instantiating LightningDataModule for evaluation...")
+ data_module = hydra.utils.instantiate(cfg.data_module)
+ results, trainer = _run_test(
+ module=model,
+ datamodule=data_module,
+ results=results,
+ accelerator=cfg.trainer.accelerator,
+ last_ckpt=best_ckpt,
+ )
if not cfg.training:
trainer.save_checkpoint(f"{checkpoint_dirpath}/last.ckpt")
+@pl.utilities.rank_zero_only
+def _run_test(
+ module: pl.LightningModule,
+ datamodule: pl.LightningDataModule,
+ results,
+ accelerator,
+ last_ckpt,
+):
+ trainer = pl.Trainer(
+ accelerator=accelerator,
+ devices=1,
+ )
+ print("===> Start testing")
+ test_results = trainer.test(module, datamodule=datamodule, ckpt_path=last_ckpt)
+ results["test_metrics"] = test_results
+ return results, trainer
+
+
@hydra.main(config_path="./config", config_name="defaults", version_base="1.1")
def run(cfg: DictConfig):
print(f"PyTorch-Lightning Version: {pl.__version__}")
@@ -142,5 +192,5 @@ def run(cfg: DictConfig):
if __name__ == "__main__":
# Ensure environment variables are set before Hydra processes the config
- os.environ['HYDRA_FULL_ERROR'] = '1'
- run()
\ No newline at end of file
+ os.environ["HYDRA_FULL_ERROR"] = "1"
+ run()
diff --git a/tasks/finetune_task_EMG.py b/tasks/finetune_task_EMG.py
new file mode 100644
index 0000000..8599443
--- /dev/null
+++ b/tasks/finetune_task_EMG.py
@@ -0,0 +1,339 @@
+# *----------------------------------------------------------------------------*
+# * Copyright (C) 2025 ETH Zurich, Switzerland *
+# * SPDX-License-Identifier: Apache-2.0 *
+# * *
+# * Licensed under the Apache License, Version 2.0 (the "License"); *
+# * you may not use this file except in compliance with the License. *
+# * You may obtain a copy of the License at *
+# * *
+# * http://www.apache.org/licenses/LICENSE-2.0 *
+# * *
+# * Unless required by applicable law or agreed to in writing, software *
+# * distributed under the License is distributed on an "AS IS" BASIS, *
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+# * See the License for the specific language governing permissions and *
+# * limitations under the License. *
+# * *
+# * Author: Matteo Fasulo *
+# *----------------------------------------------------------------------------*
+from typing import Optional
+
+import hydra
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import torch_optimizer as torch_optim
+from omegaconf import DictConfig
+from safetensors.torch import load_file
+from torchmetrics import MetricCollection
+from torchmetrics.classification import (
+ AUROC,
+ Accuracy,
+ AveragePrecision,
+ CohenKappa,
+ F1Score,
+ Precision,
+ Recall,
+)
+
+from util.train_utils import MinMaxNormalization
+
+
+class FinetuneTask(pl.LightningModule):
+ """
+ PyTorch Lightning module for fine-tuning a classification model, with support for:
+
+ - Classification types:
+ - `bc`: Binary Classification
+ - `ml`: Multi-Label Classification
+
+ - Metric logging during training, validation, and testing, including accuracy, precision, recall, F1 score, AUROC, and more
+ - Optional input normalization with configurable normalization functions
+ - Custom optimizer support including SGD, Adam, AdamW, and LAMB
+ - Learning rate schedulers with configurable scheduling strategies
+ - Layer-wise learning rate decay for fine-grained learning rate control across model blocks
+ """
+
+ def __init__(self, hparams: DictConfig):
+ """
+ Initialize the FinetuneTask module.
+
+ Args:
+ hparams (DictConfig): Hyperparameters and configuration loaded via Hydra.
+ """
+ super().__init__()
+ self.save_hyperparameters(hparams)
+ self.model = hydra.utils.instantiate(self.hparams.model)
+ self.num_classes = self.hparams.model.num_classes
+ self.classification_type = self.hparams.model.classification_type
+
+ # Enable normalization if specified in parameters
+ self.normalize = False
+ if "input_normalization" in self.hparams and self.hparams.input_normalization.normalize:
+ self.normalize = True
+ self.normalize_fct = MinMaxNormalization()
+
+ # Loss function
+ self.criterion = nn.CrossEntropyLoss(label_smoothing=0.10)
+
+ # Classification mode detection
+ if not isinstance(self.num_classes, int):
+ raise TypeError("Number of classes must be an integer.")
+ elif self.num_classes < 2:
+ raise ValueError("Number of classes must be at least 2.")
+ elif self.num_classes == 2:
+ self.classification_task = "binary"
+ else:
+ self.classification_task = "multiclass"
+
+ # Metrics
+ label_metrics = MetricCollection(
+ {
+ "micro_acc": Accuracy(
+ task=self.classification_task,
+ num_classes=self.num_classes,
+ average="micro",
+ ),
+ "macro_acc": Accuracy(
+ task=self.classification_task,
+ num_classes=self.num_classes,
+ average="macro",
+ ),
+ "recall": Recall(task="multiclass", num_classes=self.num_classes, average="macro"),
+ "precision": Precision(
+ task=self.classification_task,
+ num_classes=self.num_classes,
+ average="macro",
+ ),
+ "f1": F1Score(
+ task=self.classification_task,
+ num_classes=self.num_classes,
+ average="macro",
+ ),
+ "cohen_kappa": CohenKappa(task=self.classification_task, num_classes=self.num_classes),
+ }
+ )
+ logit_metrics = MetricCollection(
+ {
+ "auroc": AUROC(
+ task=self.classification_task,
+ num_classes=self.num_classes,
+ average="macro",
+ ),
+ "average_precision": AveragePrecision(
+ task=self.classification_task,
+ num_classes=self.num_classes,
+ average="macro",
+ ),
+ }
+ )
+ self.train_label_metrics = label_metrics.clone(prefix="train/")
+ self.val_label_metrics = label_metrics.clone(prefix="val/")
+ self.test_label_metrics = label_metrics.clone(prefix="test/")
+ self.train_logit_metrics = logit_metrics.clone(prefix="train/")
+ self.val_logit_metrics = logit_metrics.clone(prefix="val/")
+ self.test_logit_metrics = logit_metrics.clone(prefix="test/")
+
+ def load_pretrained_checkpoint(self, model_ckpt):
+ """
+ Load a pretrained model checkpoint and unfreeze specific layers for fine-tuning.
+ """
+ assert self.model.model_head is not None
+ print("Loading pretrained checkpoint from .ckpt file")
+ checkpoint = torch.load(model_ckpt, map_location="cpu", weights_only=False)
+ state_dict = checkpoint["state_dict"]
+ self.load_state_dict(state_dict, strict=False)
+ for name, param in self.model.named_parameters():
+ if self.hparams.finetuning.freeze_layers:
+ param.requires_grad = False
+ if "model_head" in name:
+ param.requires_grad = True # Unfreeze model head
+
+ print("Pretrained model ready.")
+
+ def load_safetensors_checkpoint(self, model_ckpt):
+ """
+ Load a pretrained model checkpoint in safetensors format and unfreeze specific layers for fine-tuning.
+ """
+ assert self.model.model_head is not None
+ print("Loading pretrained safetensors checkpoint")
+ state_dict = load_file(model_ckpt)
+ self.load_state_dict(state_dict, strict=False)
+
+ for name, param in self.model.named_parameters():
+ if self.hparams.finetuning.freeze_layers:
+ param.requires_grad = False
+ if "model_head" in name:
+ param.requires_grad = True
+
+ print("Pretrained model ready.")
+
+ def generate_fake_mask(self, batch_size, C, T):
+ """
+ Create a dummy mask tensor to simulate attention masking.
+
+ Args:
+ batch_size (int): Number of samples.
+ C (int): Number of channels.
+ T (int): Temporal dimension.
+
+ Returns:
+ torch.Tensor: Boolean mask tensor of shape (B, C, T).
+ """
+ return torch.zeros(batch_size, C, T, dtype=torch.bool).to(self.device)
+
+ def _step(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> dict:
+ """
+ Perform forward pass and post-process predictions.
+
+ Args:
+ X (torch.Tensor): Input tensor.
+
+ Returns:
+ dict: Dictionary containing predicted labels, probabilities, and logits.
+ """
+ y_pred_logits, _ = self.model(X, mask=mask)
+
+ if self.classification_type in ("bc", "ml"):
+ y_pred_probs = torch.softmax(y_pred_logits, dim=1)
+ y_pred_label = torch.argmax(y_pred_probs, dim=1)
+
+ else:
+ raise NotImplementedError(f"No valid classification type: {self.classification_type}")
+
+ return {
+ "label": y_pred_label,
+ "probs": y_pred_probs,
+ "logits": y_pred_logits,
+ }
+
+ def training_step(self, batch, batch_idx):
+ X, y = batch
+ if self.normalize:
+ X = self.normalize_fct(X)
+ mask = self.generate_fake_mask(X.shape[0], X.shape[1], X.shape[2])
+ y_pred = self._step(X, mask=mask)
+ loss = self.criterion(y_pred["logits"], y)
+
+ self.train_label_metrics(y_pred["label"], y)
+ self.train_logit_metrics(self._handle_binary(y_pred["logits"]), y)
+ self.log_dict(self.train_label_metrics, on_step=True, on_epoch=False)
+ self.log_dict(self.train_logit_metrics, on_step=True, on_epoch=False)
+ self.log(
+ "train_loss",
+ loss,
+ on_step=True,
+ on_epoch=True,
+ prog_bar=True,
+ logger=True,
+ sync_dist=True,
+ )
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ X, y = batch
+ if self.normalize:
+ X = self.normalize_fct(X)
+ mask = self.generate_fake_mask(X.shape[0], X.shape[1], X.shape[2])
+ y_pred = self._step(X, mask=mask)
+ loss = self.criterion(y_pred["logits"], y)
+
+ self.val_label_metrics(y_pred["label"], y)
+ self.val_logit_metrics(self._handle_binary(y_pred["logits"]), y)
+ self.log_dict(self.val_label_metrics, on_step=False, on_epoch=True)
+ self.log_dict(self.val_logit_metrics, on_step=False, on_epoch=True)
+ self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True)
+ return loss
+
+ def test_step(self, batch, batch_idx):
+ X, y = batch
+ if self.normalize:
+ X = self.normalize_fct(X)
+ mask = self.generate_fake_mask(X.shape[0], X.shape[1], X.shape[2])
+ y_pred = self._step(X, mask=mask)
+ loss = self.criterion(y_pred["logits"], y)
+
+ self.test_label_metrics(y_pred["label"], y)
+ self.test_logit_metrics(self._handle_binary(y_pred["logits"]), y)
+ self.log_dict(self.test_label_metrics, on_step=False, on_epoch=True)
+ self.log_dict(self.test_logit_metrics, on_step=False, on_epoch=True)
+ self.log("test_loss", loss, prog_bar=True, logger=True, sync_dist=True)
+ return loss
+
+ def lr_scheduler_step(self, scheduler, metric):
+ """
+ Custom scheduler step function for step-based LR schedulers
+ """
+ scheduler.step(epoch=self.current_epoch)
+
+ def configure_optimizers(self):
+ """
+ Configure the optimizer and learning rate scheduler.
+
+ Returns:
+ dict: Configuration dictionary with optimizer and LR scheduler.
+ """
+ num_blocks = self.hparams.model.n_layer
+ params_to_pass = []
+ base_lr = self.hparams.optimizer.lr
+ decay_factor = self.hparams.layerwise_lr_decay
+
+ for name, param in self.model.named_parameters():
+ lr = base_lr
+ if "mamba_blocks" in name or "norm_layers" in name:
+ block_nr = int(name.split(".")[1])
+ lr *= decay_factor ** (num_blocks - block_nr)
+ params_to_pass.append({"params": param, "lr": lr})
+
+ if self.hparams.optimizer.optim == "SGD":
+ optimizer = torch.optim.SGD(params_to_pass, lr=base_lr, momentum=self.hparams.optimizer.momentum)
+ elif self.hparams.optimizer.optim == "Adam":
+ optimizer = torch.optim.Adam(
+ params_to_pass,
+ lr=base_lr,
+ weight_decay=self.hparams.optimizer.weight_decay,
+ )
+ elif self.hparams.optimizer.optim == "AdamW":
+ optimizer = torch.optim.AdamW(
+ params_to_pass,
+ lr=base_lr,
+ weight_decay=self.hparams.optimizer.weight_decay,
+ betas=self.hparams.optimizer.betas,
+ )
+ elif self.hparams.optimizer.optim == "LAMB":
+ optimizer = torch_optim.Lamb(params_to_pass, lr=base_lr)
+ else:
+ raise NotImplementedError("No valid optimizer name")
+
+ if self.hparams.scheduler_type == "multi_step_lr":
+ scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer)
+ else:
+ scheduler = hydra.utils.instantiate(
+ self.hparams.scheduler,
+ optimizer=optimizer,
+ total_training_opt_steps=self.trainer.estimated_stepping_batches,
+ )
+
+ lr_scheduler_config = {
+ "scheduler": scheduler,
+ "interval": "epoch",
+ "frequency": 1,
+ "monitor": "val_loss",
+ }
+
+ return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
+
+ def _handle_binary(self, preds):
+ """
+ Special handling for binary classification probabilities.
+
+ Args:
+ preds (torch.Tensor): Logit outputs.
+
+ Returns:
+ torch.Tensor: Probabilities for the positive class.
+ """
+ if self.classification_task == "binary" and self.classification_type != "mc":
+ return preds[:, 1].squeeze()
+ else:
+ return preds
diff --git a/tasks/pretrain_task_EMG.py b/tasks/pretrain_task_EMG.py
new file mode 100644
index 0000000..c535075
--- /dev/null
+++ b/tasks/pretrain_task_EMG.py
@@ -0,0 +1,295 @@
+# *----------------------------------------------------------------------------*
+# * Copyright (C) 2025 ETH Zurich, Switzerland *
+# * SPDX-License-Identifier: Apache-2.0 *
+# * *
+# * Licensed under the Apache License, Version 2.0 (the "License"); *
+# * you may not use this file except in compliance with the License. *
+# * You may obtain a copy of the License at *
+# * *
+# * http://www.apache.org/licenses/LICENSE-2.0 *
+# * *
+# * Unless required by applicable law or agreed to in writing, software *
+# * distributed under the License is distributed on an "AS IS" BASIS, *
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+# * See the License for the specific language governing permissions and *
+# * limitations under the License. *
+# * *
+# * Author: Matteo Fasulo *
+# *----------------------------------------------------------------------------*
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch_optimizer as torch_optim
+from omegaconf import DictConfig
+
+from util.train_utils import MinMaxNormalization
+
+
+class MaskTask(pl.LightningModule):
+ """
+ PyTorch Lightning module for training a model with masked reconstruction.
+
+ Args:
+ hparams (DictConfig): Parameters and configurations loaded via Hydra.
+ """
+
+ def __init__(self, hparams: DictConfig):
+ super().__init__()
+ self.save_hyperparameters(hparams)
+ self.model = hydra.utils.instantiate(self.hparams.model)
+ self.criterion = hydra.utils.instantiate(self.hparams.criterion)
+ self.patch_size = self.hparams.masking.patch_size
+ self.masking_ratio = self.hparams.masking.masking_ratio
+ self.unmasked_loss_coeff = self.hparams.masking.unmasked_loss_coeff
+
+ # Enable normalization if specified in parameters
+ self.normalize = False
+ if "input_normalization" in self.hparams and self.hparams.input_normalization.normalize:
+ self.normalize = True
+ self.normalize_fct = MinMaxNormalization()
+
+ def generate_mask(self, batch_size, C, T):
+ """
+ Generate per-sample patch-level boolean masks (MAE-style).
+
+ Returns:
+ mask_full (torch.BoolTensor): Shape (B, C, T)
+ True = masked element
+ """
+ patch_H, patch_W = self.patch_size
+ num_patches_H = C // patch_H
+ num_patches_W = T // patch_W
+ N = num_patches_H * num_patches_W
+
+ # Number of patches to mask per sample
+ num_to_mask = int(N * self.masking_ratio)
+
+ # Generate patch-level mask (B, N) - vectorized
+ mask_patches = torch.zeros(batch_size, N, dtype=torch.bool, device=self.device)
+
+ for b in range(batch_size):
+ selected = torch.randperm(N, device=self.device)[:num_to_mask]
+ mask_patches[b, selected] = True
+
+ # unpatchify using reshape and repeat_interleave
+ # (B, N) -> (B, num_patches_H, num_patches_W)
+ mask_patches_2d = mask_patches.reshape(batch_size, num_patches_H, num_patches_W)
+
+ # Expand to full shape using repeat_interleave
+ # (B, num_patches_H, num_patches_W) -> (B, C, T)
+ mask_full = mask_patches_2d.repeat_interleave(patch_H, dim=1).repeat_interleave(patch_W, dim=2)
+
+ return mask_full
+
+ def unpatchify(self, x_patches: torch.Tensor, in_chans: int) -> torch.Tensor:
+ """
+ Convert patch embeddings (B, N, P) back to waveform (B, C, T)
+
+ Args:
+ x_patches: (B, N, P)
+ in_chans: number of channels C
+ Returns:
+ x_reconstructed: (B, C, T)
+ """
+ B, N, P = x_patches.shape
+ num_patches_per_chan = N // in_chans
+ x_recon = x_patches.reshape(B, in_chans, num_patches_per_chan * P)
+ return x_recon
+
+ def training_step(self, batch, batch_idx):
+ """
+ Training step: apply mask, normalize and compute loss.
+
+ Args:
+ batch (torch.Tensor): Input batch.
+ batch_idx (int): Batch index.
+
+ Returns:
+ torch.Tensor: Loss value.
+ """
+ X = batch
+ mask = self.generate_mask(X.shape[0], X.shape[1], X.shape[2])
+
+ if self.normalize:
+ X = self.normalize_fct(X)
+
+ x_reconstructed, x_original = self.model(X, mask=mask) # x_reconstructed: (B, N, P)
+
+ # unpatchify to original signal shape (B, C, T)
+ x_reconstructed_unpatched = self.unpatchify(x_reconstructed, self.hparams.model.in_chans)
+
+ # Compute loss on masked parts and unmasked parts (with coefficient)
+ masked_loss, unmasked_loss = self.criterion(x_reconstructed_unpatched, x_original, mask)
+ loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss
+
+ self.log(
+ "train_loss",
+ masked_loss,
+ on_step=True,
+ on_epoch=True,
+ prog_bar=True,
+ logger=True,
+ )
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ """
+ Validation step: apply mask, normalize, compute loss and log signals.
+
+ Args:
+ batch (torch.Tensor): Input batch.
+ batch_idx (int): Batch index.
+
+ Returns:
+ torch.Tensor: Loss value.
+ """
+ X = batch
+ mask = self.generate_mask(X.shape[0], X.shape[1], X.shape[2])
+
+ if self.normalize:
+ X = self.normalize_fct(X)
+
+ x_reconstructed, x_original = self.model(X, mask=mask) # x_reconstructed: (B, N, P)
+
+ # unpatchify to original signal shape (B, C, T)
+ x_reconstructed_unpatched = self.unpatchify(x_reconstructed, self.hparams.model.in_chans)
+
+ # Compute loss on masked parts and unmasked parts (with coefficient)
+ masked_loss, unmasked_loss = self.criterion(x_reconstructed_unpatched, x_original, mask)
+ loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss
+
+ self.log(
+ "val_loss",
+ loss,
+ prog_bar=True,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+
+ # Fixed indices for logging signals
+ random_indices = [6, 16, 30]
+
+ # Log signals with mask only for the first validation batch
+ if batch_idx == 0:
+ self.log_signals_with_mask(
+ x_original.float(),
+ x_reconstructed_unpatched.float(),
+ mask,
+ batch_indices=random_indices,
+ batch_idx=batch_idx,
+ )
+ return loss
+
+ def configure_optimizers(self):
+ """
+ Configure optimizer and scheduler based on parameters.
+
+ Returns:
+ dict: Dictionary with optimizer and scheduler for PyTorch Lightning.
+ """
+ if self.hparams.optimizer.optim == "SGD":
+ optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.optimizer.lr, momentum=0.9)
+ elif self.hparams.optimizer.optim == "Adam":
+ optimizer = torch.optim.Adam(
+ self.model.parameters(),
+ lr=self.hparams.optimizer.lr,
+ weight_decay=self.hparams.optimizer.weight_decay,
+ )
+ elif self.hparams.optimizer.optim == "AdamW":
+ optimizer = torch.optim.AdamW(
+ self.model.parameters(),
+ lr=self.hparams.optimizer.lr,
+ weight_decay=self.hparams.optimizer.weight_decay,
+ )
+ elif self.hparams.optimizer.optim == "LAMB":
+ optimizer = torch_optim.Lamb(
+ self.model.parameters(),
+ lr=self.hparams.optimizer.lr,
+ )
+ else:
+ raise NotImplementedError("No valid optim name")
+
+ scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer)
+ lr_scheduler_config = {
+ "scheduler": scheduler,
+ "interval": "epoch",
+ "frequency": 1,
+ "monitor": "val_loss",
+ }
+
+ return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
+
+ def lr_scheduler_step(self, scheduler, metric):
+ scheduler.step(epoch=self.current_epoch)
+
+ def log_signals_with_mask(self, original, reconstructed, mask=None, batch_indices=None, batch_idx=None):
+ """
+ Log original and reconstructed signals highlighting masked regions.
+
+ Args:
+ original (torch.Tensor): Original signals.
+ reconstructed (torch.Tensor): Signals reconstructed by the model.
+ mask (torch.BoolTensor, optional): Applied mask.
+ batch_indices (list[int], optional): Batch indices to log.
+ batch_idx (int, optional): Current batch index.
+ """
+ patch_H, patch_W = self.patch_size
+ batch_size, C, T = original.shape
+
+ for batch_idx in batch_indices:
+ original_signal = original[batch_idx]
+ reconstructed_signal = reconstructed[batch_idx]
+
+ fig, ax = plt.subplots(1, 1, figsize=(15, 6))
+
+ # Limit visualization to the first patch_H channels
+ original_signal_c2 = original_signal[:patch_H, :]
+ reconstructed_signal_c2 = reconstructed_signal[:patch_H, :]
+
+ ax.plot(
+ original_signal_c2[0].cpu().numpy(),
+ label="Original Channel 0",
+ color="blue",
+ alpha=0.7,
+ )
+ ax.plot(
+ reconstructed_signal_c2[0].cpu().numpy(),
+ label="Reconstructed Channel 0",
+ color="orange",
+ alpha=0.7,
+ )
+
+ if mask is not None:
+ mask_c2 = mask[batch_idx, :patch_H, :]
+ indices = []
+
+ # Highlight masked regions with a light gray transparent band
+ for i in range(patch_H):
+ for j in range(T // patch_W):
+ if mask_c2[i, j * patch_W : (j + 1) * patch_W].all():
+ ax.axvspan(
+ j * patch_W,
+ (j + 1) * patch_W,
+ color="lightgray",
+ alpha=0.1,
+ )
+ indices.append(j)
+
+ # Remove duplicates and sort highlighted indices
+ indices_array = np.array(indices)
+ indices_array = np.unique(indices_array)
+
+ ax.set_title(f"Signal Reconstruction - batch_ {batch_idx}")
+ ax.legend()
+
+ # Log the figure on TensorBoard with batch and index in the title
+ self.logger.experiment.add_figure(
+ f"Original and Reconstructed Signals with Mask (batch_0_ {batch_idx}, F1 = 0)",
+ fig,
+ self.current_epoch,
+ )
+ plt.close(fig)
diff --git a/util/ckpt_to_safetensor.py b/util/ckpt_to_safetensor.py
new file mode 100644
index 0000000..2441f74
--- /dev/null
+++ b/util/ckpt_to_safetensor.py
@@ -0,0 +1,74 @@
+# *----------------------------------------------------------------------------*
+# * Copyright (C) 2025 ETH Zurich, Switzerland *
+# * SPDX-License-Identifier: Apache-2.0 *
+# * *
+# * Licensed under the Apache License, Version 2.0 (the "License"); *
+# * you may not use this file except in compliance with the License. *
+# * You may obtain a copy of the License at *
+# * *
+# * http://www.apache.org/licenses/LICENSE-2.0 *
+# * *
+# * Unless required by applicable law or agreed to in writing, software *
+# * distributed under the License is distributed on an "AS IS" BASIS, *
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+# * See the License for the specific language governing permissions and *
+# * limitations under the License. *
+# * *
+# * Author: Matteo Fasulo *
+# *----------------------------------------------------------------------------*
+import argparse
+
+import torch
+from safetensors.torch import save_file
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Convert a PyTorch Lightning checkpoint to a safetensors file."
+ )
+ parser.add_argument(
+ "--ckpt_path",
+ type=str,
+ required=True,
+ help="Path to the PyTorch Lightning checkpoint file.",
+ )
+ parser.add_argument(
+ "--safetensor_path",
+ type=str,
+ default="model.safetensors",
+ help="Path to save the converted safetensors file.",
+ )
+ parser.add_argument(
+ "--exclude_keys",
+ type=str,
+ nargs="*",
+ default=[],
+ help="List of keys to exclude from the safetensors file.",
+ )
+ parser.add_argument(
+ "--verbose",
+ action="store_true",
+ help="If set, print the keys of the parameters being saved.",
+ )
+ args = parser.parse_args()
+
+ # Load the PyTorch Lightning checkpoint
+ ckpt = torch.load(args.ckpt_path, map_location="cpu", weights_only=False)
+ # Extract the model state_dict and filter out excluded keys if any
+ parameters = {
+ k: v for k, v in ckpt["state_dict"].items() if k not in args.exclude_keys
+ }
+
+ # Options for verbose output - list the keys being saved
+ if args.verbose:
+ print("The following keys will be saved in the safetensors file:")
+ for key in parameters.keys():
+ print(f" - {key}")
+
+ # Save the parameters in safetensors format
+ safetensor_path = (
+ args.safetensor_path
+ if args.safetensor_path.endswith(".safetensors")
+ else args.safetensor_path + ".safetensors"
+ )
+ save_file(parameters, safetensor_path)
+ print(f"Safetensors file saved to {safetensor_path}")
diff --git a/util/train_utils.py b/util/train_utils.py
index 4bdbab9..b44f634 100644
--- a/util/train_utils.py
+++ b/util/train_utils.py
@@ -1,30 +1,31 @@
-#*----------------------------------------------------------------------------*
-#* Copyright (C) 2025 ETH Zurich, Switzerland *
-#* SPDX-License-Identifier: Apache-2.0 *
-#* *
-#* Licensed under the Apache License, Version 2.0 (the "License"); *
-#* you may not use this file except in compliance with the License. *
-#* You may obtain a copy of the License at *
-#* *
-#* http://www.apache.org/licenses/LICENSE-2.0 *
-#* *
-#* Unless required by applicable law or agreed to in writing, software *
-#* distributed under the License is distributed on an "AS IS" BASIS, *
-#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
-#* See the License for the specific language governing permissions and *
-#* limitations under the License. *
-#* *
-#* Author: Thorir Mar Ingolfsson *
-#* Author: Anna Tegon *
-#* Author: Berkay Döner *
-#*----------------------------------------------------------------------------*
+# *----------------------------------------------------------------------------*
+# * Copyright (C) 2025 ETH Zurich, Switzerland *
+# * SPDX-License-Identifier: Apache-2.0 *
+# * *
+# * Licensed under the Apache License, Version 2.0 (the "License"); *
+# * you may not use this file except in compliance with the License. *
+# * You may obtain a copy of the License at *
+# * *
+# * http://www.apache.org/licenses/LICENSE-2.0 *
+# * *
+# * Unless required by applicable law or agreed to in writing, software *
+# * distributed under the License is distributed on an "AS IS" BASIS, *
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+# * See the License for the specific language governing permissions and *
+# * limitations under the License. *
+# * *
+# * Author: Thorir Mar Ingolfsson *
+# * Author: Anna Tegon *
+# * Author: Berkay Döner *
+# * Author: Matteo Fasulo *
+# *----------------------------------------------------------------------------*
-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-from typing import Optional
import os
import os.path as osp
+from typing import Optional
+
import torch
-import torch.nn as nn
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
def find_last_checkpoint_path(checkpoint_dir: Optional[str]) -> Optional[str]:
@@ -39,6 +40,7 @@ def find_last_checkpoint_path(checkpoint_dir: Optional[str]) -> Optional[str]:
return last_checkpoint_filepath
+
class RobustQuartileNormalize:
def __init__(self, q_lower, q_upper):
self.q_lower = q_lower
@@ -46,4 +48,16 @@ def __init__(self, q_lower, q_upper):
def __call__(self, tensor):
iqr = self.q_upper - self.q_lower
- return (tensor - self.q_lower) / (iqr + 1e-8)
\ No newline at end of file
+ return (tensor - self.q_lower) / (iqr + 1e-8)
+
+
+class MinMaxNormalization:
+ def __init__(self, eps: float = 1e-8):
+ self.eps = eps
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, C, T)
+ maxv = x.amax(dim=-1, keepdim=True)
+ minv = x.amin(dim=-1, keepdim=True)
+ x = (x - minv) / (maxv - minv + self.eps)
+ return (x - 0.5) * 2 # scale to [-1, 1]