A Python package dedicated to the classification and exploration of Cell Painting data. This package relies on Lightning for training and evaluation of the DenseNet model.
- Ensure Python >=3.11 and `conda` are installed on your machine. The recommended installer for `conda` is Miniforge.
- Clone this repository:
  ```bash
  $ git clone https://github.com/jhuapl-bio/DeepPaint.git
  ```
- Navigate to the DeepPaint directory (containing the README):
  ```bash
  $ cd DeepPaint
  ```
- Create a `conda` virtual environment from the `environment.yml` file and activate it:
  ```bash
  $ conda env create -n <env_name> -f environment.yml
  $ conda activate <env_name>
  ```
- Install the `DeepPaint` package with pip:
  ```bash
  $ pip install .
  ```
The `DeepPaint` package can be run as a module with the command `python -m deep_paint`, which invokes the CLI. This is the entry point for training and evaluating models.
Four commands are available:

- `fit`: Train or fine-tune a model
- `validate`: Run one evaluation epoch on a validation set
- `test`: Run one test epoch on a test set
- `predict`: Get predictions from a trained model on part or all of a dataset
These commands correspond to the `lightning.pytorch.Trainer` methods of the same name. All commands accept a `--config` argument that specifies a configuration file.
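To make the command shape concrete, the sketch below reproduces only the interface described above (four subcommands, each taking `--config`) with the standard library's `argparse`. It is a hypothetical stand-in for illustration, not DeepPaint's actual CLI implementation; the `COMMANDS` table and `build_parser` helper are invented names.

```python
import argparse

# The four subcommands listed above; each mirrors a lightning.pytorch.Trainer method.
COMMANDS = {
    "fit": "Train or fine-tune a model",
    "validate": "Run one evaluation epoch on a validation set",
    "test": "Run one test epoch on a test set",
    "predict": "Get predictions from a trained model",
}


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring `python -m deep_paint <command> --config FILE`."""
    parser = argparse.ArgumentParser(prog="python -m deep_paint")
    subparsers = parser.add_subparsers(dest="command", required=True)
    for name, help_text in COMMANDS.items():
        sub = subparsers.add_parser(name, help=help_text)
        # Every command can take a path to a YAML configuration file.
        sub.add_argument("--config", default=None, help="Path to a YAML config file")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["fit", "--config", "/path/to/your_config.yaml"])
    print(args.command, args.config)
```

The subcommand-per-method layout keeps each Trainer entry point separate while sharing the same `--config` option.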
The configuration files used for training, getting model predictions, and getting model embeddings are available in the `configs` directory. Be sure to update the paths in these configuration files (they are marked with comments for convenience).
Each configuration file is a YAML file that contains all the parameters needed to train, validate, test, or get predictions from a model. It is divided into the following top-level fields:
| Field | Subclass | Description | Required? |
|---|---|---|---|
| `model` | `LightningModule` | Model architecture and hyperparameters | ✅ |
| `data` | `LightningDataModule` | Data preprocessing and augmentation | ✅ |
| `trainer` | `Trainer` | Training arguments | ✅ |
| `optimizer` | `Optimizer` | Optimizer | ❌ |
| `lr_scheduler` | `LRScheduler` | Learning rate scheduler | ❌ |
| `ckpt_path` | N/A | Path to model checkpoint | ❌ |
All fields except `trainer` and `ckpt_path` require a `class_path` parameter containing the full import path to the class. The remaining settings for that field are passed as keyword arguments to the class constructor via the `init_args` parameter.
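The `class_path`/`init_args` convention can be illustrated with a small, standard-library-only sketch that checks the top-level fields of a parsed config and resolves a `class_path` into an object with `importlib`. This is an assumption about the mechanism for illustration only; DeepPaint's actual parsing may differ. The `check_config` and `instantiate` helpers are hypothetical, and `datetime.timedelta` stands in for a real `LightningModule` so the example runs without DeepPaint installed.

```python
import importlib
from typing import Any

# From the table above: required top-level fields, and fields that need a class_path
# (every table field except trainer and ckpt_path).
REQUIRED_FIELDS = {"model", "data", "trainer"}
CLASS_PATH_FIELDS = {"model", "data", "optimizer", "lr_scheduler"}


def check_config(config: dict[str, Any]) -> list[str]:
    """Return human-readable problems with the top-level structure of a parsed config."""
    problems = [f"missing required field: {f}" for f in sorted(REQUIRED_FIELDS - set(config))]
    for field, value in config.items():
        if field in CLASS_PATH_FIELDS and not (isinstance(value, dict) and "class_path" in value):
            problems.append(f"field '{field}' needs a class_path")
    return problems


def instantiate(spec: dict[str, Any]) -> Any:
    """Import the class named by spec['class_path'] and call it with spec['init_args']."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


if __name__ == "__main__":
    # Hypothetical config dict, shaped like a parsed YAML file.
    config = {
        "model": {"class_path": "datetime.timedelta", "init_args": {"hours": 2}},
        "data": {"class_path": "datetime.timedelta", "init_args": {"days": 1}},
        "trainer": {"max_epochs": 10},
        "ckpt_path": None,
    }
    print(check_config(config))  # []
    print(instantiate(config["model"]))  # 2:00:00
```

Keeping the constructor arguments under `init_args` lets one field swap classes by changing only `class_path`.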
- Train a model:
  ```bash
  python -m deep_paint fit --config /path/to/your_config.yaml
  ```
- Run a validation epoch:
  ```bash
  python -m deep_paint validate --config /path/to/your_config.yaml
  ```
- Run a test epoch:
  ```bash
  python -m deep_paint test --config /path/to/your_config.yaml
  ```
- Get model predictions:
  ```bash
  python -m deep_paint predict --config /path/to/your_config.yaml
  ```
A custom script extracts embeddings from a trained model. It can be run with the following command:

```bash
python -m deep_paint.utils.embeddings --config /path/to/your_config.yaml
```

This config file differs slightly from the config file used for the four main commands. Refer to the `configs` directory for examples.
The `results` directory contains the following subdirectories:

- `checkpoints`: Contains model checkpoints
- `configs`: Contains configuration files used for training, getting model predictions, and getting model embeddings
- `embeddings`: Contains embeddings extracted from the model on the test set of the RxRx2 data
- `logs`: Contains CSV files extracted from TensorBoard logs
- `metadata`: Contains custom metadata used for training the DenseNet model
- `predictions`: Contains model predictions on the test set of the RxRx2 data
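The CSV files in `logs` can be inspected with the standard library alone. The sketch below assumes the column layout of TensorBoard's scalar CSV export (`Wall time`, `Step`, `Value`); the actual headers and file names in `logs` may differ. `best_step` is a hypothetical helper that returns the step with the lowest or highest value of a logged metric.

```python
import csv
import io
from typing import TextIO


def best_step(handle: TextIO, mode: str = "min") -> tuple[int, float]:
    """Return (step, value) of the best row in a scalar CSV log.

    Assumes TensorBoard's scalar export layout: 'Wall time', 'Step', 'Value' columns.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    points = [(int(float(row["Step"])), float(row["Value"])) for row in csv.DictReader(handle)]
    if not points:
        raise ValueError("log contains no rows")
    pick = min if mode == "min" else max
    return pick(points, key=lambda point: point[1])


if __name__ == "__main__":
    # In-memory stand-in with illustrative values (not real results);
    # for a real file use: with open(path, newline="") as handle: best_step(handle)
    sample = io.StringIO(
        "Wall time,Step,Value\n"
        "1700000000.0,100,0.91\n"
        "1700000100.0,200,0.64\n"
        "1700000200.0,300,0.70\n"
    )
    print(best_step(sample, mode="min"))  # (200, 0.64)
```

Use `mode="min"` for loss-like metrics and `mode="max"` for accuracy-like metrics.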
The RxRx2 dataset was used for training and evaluation of the DenseNet model. The dataset is freely available to download from the RxRx.ai website.
The `checkpoints` directory contains checkpoints for the binary and multiclass DenseNet models. These checkpoints can be used to load the trained models and make predictions.
The `notebooks` directory contains Jupyter notebooks that demonstrate the performance of the DenseNet model on the RxRx2 dataset. The notebooks contain visualizations of the model predictions and embeddings.
