15 changes: 11 additions & 4 deletions README.md
@@ -5,7 +5,7 @@
[![Documentation](https://img.shields.io/badge/docs-read%20the%20docs-blue)](https://generativeproteomics.readthedocs.io/en/latest/)
[![HuggingFace](https://img.shields.io/badge/Hugging_Face-grey?style=flat&logo=huggingface&color=grey)](https://huggingface.co/QuantitativeBiology)

**GainPro** is a PyTorch implementation of Generative Adversarial Imputation Networks (GAIN) [[1]](#1) for imputing missing iBAQ values in proteomics datasets. The package provides a unified command-line interface with multiple imputation methods including basic GAIN, GAIN-DANN (domain-adaptive), and pre-trained HuggingFace models.
**GainPro** is a PyTorch implementation of Generative Adversarial Imputation Networks (GAIN) [[1]](#1) for imputing missing iBAQ values in proteomics datasets. The package provides a unified command-line interface with multiple imputation methods, including basic GAIN, GAIN-DANN (domain-adaptive), and pre-trained GAIN-DANN models from HuggingFace.

## Table of Contents

@@ -22,7 +22,7 @@

- **Basic GAIN**: Simple Generator + Discriminator architecture for general-purpose imputation
- **GAIN-DANN**: Domain-adaptive imputation with Encoder/Decoder architecture
- **Pre-trained Models**: Easy access to HuggingFace pre-trained models
- **Pre-trained Models**: Easy access to HuggingFace pre-trained GAIN-DANN models
- **Median Imputation**: Simple baseline method
- **Flexible CLI**: Unified `gainpro` command with intuitive subcommands
- **Python API**: Full programmatic access to all functionality
@@ -141,14 +141,21 @@ Use a trained GAIN-DANN checkpoint for imputation:
```bash
gainpro impute --checkpoint checkpoints/your_model --input data.csv --output imputed.csv
```

### `gainpro download` - HuggingFace Pre-trained Models
### `gainpro download` - HuggingFace Pre-trained GAIN-DANN Models

Download and use pre-trained models from HuggingFace:
Download and use pre-trained GAIN-DANN models from HuggingFace:

```bash
gainpro download --input data.csv --output imputed.csv
```
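
The source repository can also be selected explicitly. Judging from the CLI options added in `gainpro/gainpro.py`, `--model-id` defaults to `QuantitativeBiology/GAIN_DANN_model`:

```bash
gainpro download --model-id QuantitativeBiology/GAIN_DANN_model --input data.csv --output imputed.csv
```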

**Note:** This command is specifically designed for GAIN-DANN models. It requires the HuggingFace repository to contain:
- `config.json` - Model configuration
- `pytorch_model.bin` - Model weights
- `modeling_gain_dann.py` - Model architecture file with `GainDANN` class

The model must follow the GAIN-DANN interface (its forward pass returns an `(x_reconstructed, x_domain)` tuple); other model types are not currently supported.
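
For orientation, here is a minimal sketch of that interface. Only the class name and the `(x_reconstructed, x_domain)` return contract come from this repository; the layer sizes and internals below are illustrative assumptions:

```python
# Minimal sketch of the expected GAIN-DANN interface.
# ASSUMPTION: layer sizes and internals are illustrative, not the real
# architecture; only the forward() contract (returning a tuple) reflects
# what gainpro relies on.
import torch
import torch.nn as nn

class GainDANN(nn.Module):
    def __init__(self, n_features: int, n_domains: int, hidden: int = 128):
        super().__init__()
        # Encoder/decoder pair that reconstructs (imputes) the input features
        self.encoder = nn.Sequential(nn.Linear(n_features, hidden), nn.ReLU())
        self.decoder = nn.Linear(hidden, n_features)
        # Adversarial domain head predicting which domain a sample came from
        self.domain_head = nn.Linear(hidden, n_domains)

    def forward(self, x: torch.Tensor):
        h = self.encoder(x)
        x_reconstructed = self.decoder(h)
        x_domain = self.domain_head(h)
        return x_reconstructed, x_domain  # the tuple gainpro unpacks after model(x)
```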

### `gainpro median` - Median Imputation

Simple median imputation baseline:
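
By analogy with the other subcommands, a plausible invocation looks like this (the exact flags are an assumption, not confirmed here):

```bash
gainpro median --input data.csv --output imputed.csv
```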
95 changes: 65 additions & 30 deletions gainpro/gainpro.py
@@ -171,18 +171,18 @@ def train(config_file, save):
params_gain_dann = ParamsGainDann.read_hyperparameters(config_file)
params = params_gain_dann.to_dict()

# Read dataset
# Read dataset - use DataDANN for GAIN-DANN models
if params_gain_dann.path_dataset_missing:
logger.info(f"Loading dataset with missing values: {params_gain_dann.path_dataset_missing}")
dataset_missing = pd.read_csv(params_gain_dann.path_dataset_missing, index_col=0)
dataset_missing = dataset_missing.iloc[:, 8500:] # TODO: Remove start_col hardcoding
data = Data(
data = DataDANN(
dataset_path=params_gain_dann.path_dataset,
dataset_missing=dataset_missing,
start_col=8500
)
else:
data = Data(
data = DataDANN(
dataset_path=params_gain_dann.path_dataset,
miss_rate=params["miss_rate"],
start_col=8500
@@ -252,8 +252,8 @@ def impute(checkpoint_dir, input_file, output_file, miss_rate):
df = df.loc[:, list(common_proteins)]
logger.info(f"Using {len(common_proteins)} common proteins")

# Create data object
data = Data(df, miss_rate=miss_rate, start_col=0)
# Create data object - use DataDANN for GAIN-DANN models
data = DataDANN(dataset=df, miss_rate=miss_rate, start_col=0)
incomplete_data = data.dataset_missing

# Pad with NaNs for model compatibility
@@ -303,47 +303,73 @@ def impute(checkpoint_dir, input_file, output_file, miss_rate):
@click.option(
"--model-id",
default="QuantitativeBiology/GAIN_DANN_model",
help="HuggingFace model repository ID",
help="HuggingFace model repository ID (must be a GAIN-DANN model with modeling_gain_dann.py file)",
show_default=True,
)
def download(input_file, output_file, model_id):
"""
Download a pre-trained model from HuggingFace and perform imputation.
Download a pre-trained GAIN-DANN model from HuggingFace and perform imputation.

NOTE: This command is specifically designed for GAIN-DANN models. It expects:
- A config.json file with model configuration
- A pytorch_model.bin file with model weights
- A modeling_gain_dann.py file with the model architecture class

The model must implement the GainDANN interface with forward() returning
(x_reconstructed, x_domain) tuples. Other model types are not supported.

Example:

gainpro download --input data.csv --output imputed.csv
"""
logger.info(f"Downloading model from HuggingFace: {model_id}")
logger.info(f"Downloading GAIN-DANN model from HuggingFace: {model_id}")
logger.warning(
"NOTE: This command only works with GAIN-DANN models that include "
"modeling_gain_dann.py. Other model types are not supported."
)

save_dir = "./GAIN_DANN_model"
os.makedirs(save_dir, exist_ok=True)

# Download files from HuggingFace
# NOTE: This is specific to GAIN-DANN model structure
logger.info("Downloading model files...")
config_path = hf_hub_download(
repo_id=model_id,
filename="config.json",
cache_dir=save_dir
)
weights_path = hf_hub_download(
repo_id=model_id,
filename="pytorch_model.bin",
cache_dir=save_dir
)
model_path = hf_hub_download(
repo_id=model_id,
filename="modeling_gain_dann.py",
cache_dir=save_dir
)
try:
config_path = hf_hub_download(
repo_id=model_id,
filename="config.json",
cache_dir=save_dir
)
weights_path = hf_hub_download(
repo_id=model_id,
filename="pytorch_model.bin",
cache_dir=save_dir
)
model_path = hf_hub_download(
repo_id=model_id,
filename="modeling_gain_dann.py",
cache_dir=save_dir
)
except Exception as e:
raise click.ClickException(
f"Failed to download GAIN-DANN model files from {model_id}. "
f"Ensure the repository contains config.json, pytorch_model.bin, "
f"and modeling_gain_dann.py files. Error: {e}"
)

# Add directory to Python path to import the model
directory = os.path.dirname(model_path)
if directory not in sys.path:
sys.path.append(directory)

# Import model classes
from modeling_gain_dann import GainDANNConfig, GainDANN
# Import model classes (GAIN-DANN specific)
try:
from modeling_gain_dann import GainDANNConfig, GainDANN
except ImportError as e:
raise click.ClickException(
f"Failed to import GAIN-DANN model classes from modeling_gain_dann.py. "
f"This command only works with GAIN-DANN models. Error: {e}"
)

logger.info("Loading model configuration...")

@@ -383,10 +409,16 @@ def download(input_file, output_file, model_id):

x = torch.tensor(data_df.values, dtype=torch.float32)

# Perform imputation
# Perform imputation (GAIN-DANN specific: returns reconstructed data and domain predictions)
logger.info("Running imputation...")
with torch.no_grad():
x_reconstructed, x_domain = model(x)
try:
x_reconstructed, x_domain = model(x)
except (ValueError, TypeError) as e:
raise click.ClickException(
f"Model output format not recognized. This command expects GAIN-DANN models "
f"that return (x_reconstructed, x_domain) tuples. Error: {e}"
)

# Convert to DataFrame and save
result_df = pd.DataFrame(x_reconstructed.numpy(), columns=data_df.columns if hasattr(data_df, 'columns') else None)
@@ -397,7 +429,7 @@

result_df.to_csv(output_file, index=True)

# Optionally save domain predictions
# Save domain predictions (GAIN-DANN specific feature)
domain_file = output_file.replace(".csv", "_domain.csv")
pd.DataFrame(x_domain.numpy()).to_csv(domain_file, index=False)

@@ -790,8 +822,11 @@ def gain_main():

# Show deprecation warning
click.echo(
"⚠️ WARNING: The 'gain' command is deprecated. "
"Please use 'gainpro gain' instead.\n",
"⚠️ WARNING: The standalone 'gain' command is deprecated and will be "
"removed in version 0.3.0.\n"
" Please migrate to 'gainpro gain' instead. The functionality is identical.\n"
" For documentation and examples, see: "
"https://github.com/bigbio/GainPro/blob/main/README.md\n",
err=True
)

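Spelled out, the migration this warning asks for is just a prefix change; whatever options you pass stay the same (shown generically here):

```bash
# Deprecated standalone entry point, slated for removal in 0.3.0
gain [OPTIONS]

# Drop-in replacement with identical functionality
gainpro gain [OPTIONS]
```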
2 changes: 1 addition & 1 deletion use-case/1-pip_install/README.md
@@ -11,7 +11,7 @@ to perform imputation of missing values of proteomics' datasets.
It is currently based on the `Generative Adversarial Imputation Network (GAIN)` architecture.
To use the package, you need to have `Python 3.10` or `Python 3.11` on your system.
To do that, you can create a conda environment, for example.
The package is available on `PyPI` and can be installed using a `pip` command (gainpro 0.2.1).
The package is available on `PyPI` and can be installed using a `pip` command (gainpro 0.2.0).

```bash
pip install gainpro
```
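
One way to satisfy the Python 3.10/3.11 requirement mentioned above is a dedicated conda environment (the environment name below is arbitrary):

```bash
# Create and activate an isolated environment with a supported Python version
conda create -n gainpro-env python=3.10
conda activate gainpro-env
pip install gainpro
```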