
Commit 195da2c: Added ESM-2 language model for larger peptide generation
Parent: fb4cb28
9 files changed, +265 / -39 lines

.DS_Store
Binary file (0 Bytes) not shown.

CLAUDE.md

Lines changed: 99 additions & 0 deletions (new file)

# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

This is a bioinformatics tool for generating control peptides used in neoantigen analysis and benchmarking with pVACtools. The project targets cancer immunotherapy research by providing high-quality reference peptide sets.

## Core Architecture

The project consists of two main components:

1. **CLI Tool** (`generate_control_peptides.py`) - Main peptide generation engine
2. **GUI Interface** (`peptide_gui.py`) - User-friendly wrapper using PySimpleGUI

### Peptide Generation Methods

Three generation strategies are implemented as separate functions:

- `generate_random_peptides()` - Random amino acid sequences from the 20 standard amino acids (see the sketch after this list)
- `sample_peptides_from_fasta()` - Extracts subsequences from existing protein sequences
- `generate_llm_peptides()` - Uses HuggingFace protein language models (ProtGPT2 or ESM-2) for biologically plausible sequences
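The random strategy is the simplest of the three. A minimal sketch of what it amounts to, assuming the 20-letter amino acid alphabet the CLI keeps in its `AMINO_ACIDS` constant (the committed function may differ in details):

```python
import random
from typing import List

# The CLI defines a constant listing the 20 standard residues; shown inline here.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

def generate_random_peptides(length: int, count: int) -> List[str]:
    # Draw each residue uniformly at random from the 20-letter alphabet.
    return ["".join(random.choices(AMINO_ACIDS, k=length)) for _ in range(count)]

# generate_random_peptides(9, 3) -> e.g. ['MKVLWAGHT', 'QPNDSYRCE', 'TTIEGFHLA']
```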
### Key Design Patterns

- **Input validation**: Comprehensive parameter bounds checking (length 1-50, count 1-10M)
- **Error handling**: Graceful degradation with informative error messages
- **Reproducibility**: Random seed support for deterministic results
- **Batch processing**: Optimized generation for large peptide counts

## Development Commands

### Setup

```bash
pip install -r requirements.txt
```

### Testing the CLI

```bash
# Test random generation
python generate_control_peptides.py --length 9 --count 10 --source random --output test.fasta --seed 42

# Test FASTA sampling (requires protein sequences in data/)
python generate_control_peptides.py --length 9 --count 10 --source fasta --fasta_file data/GCF_000001405.40/protein.faa --output test.fasta

# Test ProtGPT2 generation (good for short peptides, 8-12 AA)
python generate_control_peptides.py --length 9 --count 5 --source llm --llm_model protgpt2 --temperature 1.2 --output test.fasta

# Test ESM-2 generation (better for longer peptides, 10+ AA)
python generate_control_peptides.py --length 15 --count 5 --source llm --llm_model esm2 --temperature 1.0 --output test.fasta
```

### Testing the GUI

```bash
python peptide_gui.py
```

### Code Validation

```bash
# Check syntax
python3 -m py_compile generate_control_peptides.py
python3 -m py_compile peptide_gui.py

# View help
python generate_control_peptides.py --help
```

## Important Implementation Details

### Multi-LLM Architecture
- **ProtGPT2**: 738M parameter model, token-based (~4 AA per token), good for short peptides (8-12 AA)
- **ESM-2**: 35M parameter model (`facebook/esm2_t12_35M_UR50D`), amino acid-level tokenization, better for longer peptides (10+ AA)
- Modular design allows easy addition of new protein language models
- Uses batch generation for efficiency (ProtGPT2: 50/batch, ESM-2: 10/batch)
- Includes retry logic with model-specific attempt limits
- Filters generated text to valid amino acids only (see the sketch after this list)
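To make the token budget and the filtering step concrete, a small worked sketch (the raw string below is illustrative, not real model output):

```python
# ProtGPT2 tokens cover roughly 4 amino acids each, so the token budget for an
# N-residue peptide is about N/4, with a floor of 5 tokens.
length = 9
max_length = max(5, (length + 3) // 4)   # (9 + 3) // 4 = 3, so the floor of 5 applies

# Generated text is reduced to the 20 standard residues before the length check;
# whitespace, newlines and non-standard letters are dropped.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
gen_text = "M K\nVLW AGXZ"               # illustrative raw output
pep = "".join(c for c in gen_text if c in AMINO_ACIDS)
print(pep)                               # MKVLWAG -> kept only if len(pep) == length
```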
### GUI-CLI Communication

The GUI builds the command-line arguments and executes the CLI tool via subprocess, streaming its output to the GUI text box in real time.
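`peptide_gui.py` is not part of this diff, so the following is only a sketch of that pattern; the `run_cli()` helper and the `append_output` callback are hypothetical names:

```python
import subprocess
import sys

def run_cli(cli_args, append_output):
    """Run the peptide CLI and stream its output to a GUI callback line by line."""
    cmd = [sys.executable, "generate_control_peptides.py"] + list(cli_args)
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave warnings with regular output
        text=True,
    )
    for line in proc.stdout:       # lines arrive as the CLI prints them
        append_output(line.rstrip())
    return proc.wait()

# e.g. run_cli(["--source", "random", "--length", "9", "--count", "10",
#               "--output", "test.fasta", "--seed", "42"], print)
```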
### Input Validation Bounds

- Peptide length: 1-50 amino acids
- Count: 1-10,000,000 peptides
- LLM parameters: temperature > 0, top_k ≥ 1, top_p in (0, 1], repetition_penalty > 0
- Model selection: protgpt2 (recommended for ≤12 AA), esm2 (recommended for ≥10 AA)
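These bounds are enforced in `main()`. Condensed into one helper for reference (a restatement, not the committed code, which prints to stderr and calls `sys.exit(1)`):

```python
def validate(args):
    # Mirrors the documented bounds; argparse has already type-converted the values.
    if not 1 <= args.length <= 50:
        raise SystemExit("Error: Peptide length must be between 1 and 50")
    if not 1 <= args.count <= 10_000_000:
        raise SystemExit("Error: Count must be between 1 and 10,000,000 peptides")
    if args.source == "llm":
        if args.temperature <= 0:
            raise SystemExit("Error: Temperature must be positive")
        if args.top_k < 1:
            raise SystemExit("Error: top_k must be at least 1")
        if not 0 < args.top_p <= 1:
            raise SystemExit("Error: top_p must be in (0, 1]")
        if args.repetition_penalty <= 0:
            raise SystemExit("Error: Repetition penalty must be positive")
```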
### File Formats

- Input: FASTA format (`.fasta`, `.faa`) for protein sequences
- Output: FASTA format with headers like `>peptide_1`, `>peptide_2`, etc.
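A minimal sketch of the output convention (the actual writer in the CLI may differ in details):

```python
def write_fasta(peptides, path):
    # One record per peptide, with headers numbered sequentially from 1.
    with open(path, "w") as handle:
        for i, pep in enumerate(peptides, start=1):
            handle.write(f">peptide_{i}\n{pep}\n")

# write_fasta(["ACDEFGHIK", "LMNPQRSTV"], "control_peptides.fasta")
```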
## Dependencies

Core dependencies are specified in `requirements.txt`:

- PySimpleGUI (GUI framework)
- transformers (protein language models)
- torch (PyTorch backend)
- sentencepiece (tokenization support)

Beyond these requirements, the project uses only the Python standard library.
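A `requirements.txt` consistent with this list would look roughly like the following; any version pins used in the repository are not shown in this commit:

```
PySimpleGUI
transformers
torch
sentencepiece
```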

README.md

Lines changed: 12 additions & 2 deletions

@@ -9,7 +9,7 @@ This project provides a robust, reproducible tool for generating large, high-qua
 - **Three peptide generation modes:**
   - Random (from 20 amino acids)
   - Sampled from a user-supplied FASTA file
-  - Generated by a protein language model (ProtGPT2 via HuggingFace Transformers)
+  - Generated by protein language models (ProtGPT2 or ESM-2 via HuggingFace Transformers)
 - **FASTA output** compatible with pVACtools
 - **Command-line interface** for batch processing
 - **Simple GUI** (using PySimpleGUI) for non-technical users

@@ -32,7 +32,17 @@ pip install -r requirements.txt
 ### Command Line

 ```bash
-python generate_control_peptides.py --mode [random|fasta|protGPT2] --length 9 --count 1000000 --output peptides.fasta [--fasta_file input.fasta] [--temperature 1.0] [--top_k 50]
+# Random generation
+python generate_control_peptides.py --source random --length 9 --count 1000 --output peptides.fasta --seed 42
+
+# FASTA sampling
+python generate_control_peptides.py --source fasta --length 9 --count 1000 --fasta_file input.fasta --output peptides.fasta
+
+# LLM generation (ProtGPT2 - good for short peptides)
+python generate_control_peptides.py --source llm --llm_model protgpt2 --length 9 --count 100 --output peptides.fasta
+
+# LLM generation (ESM-2 - better for longer peptides)
+python generate_control_peptides.py --source llm --llm_model esm2 --length 15 --count 100 --output peptides.fasta --temperature 1.2
 ```

 ### GUI
Two binary files changed (16.7 KB and 5.26 KB); contents not shown.

generate_control_peptides.py

Lines changed: 139 additions & 32 deletions

@@ -42,34 +42,140 @@ def sample_peptides_from_fasta(fasta_path: Path, length: int, count: int) -> Lis
         peptides.append(random.choice(all_subseqs))
     return peptides[:count]

-def generate_protgpt2_peptides(length: int, count: int, temperature: float = 1.0, top_k: int = 950, top_p: float = 0.9, repetition_penalty: float = 1.2) -> List[str]:
+def generate_llm_peptides(length: int, count: int, model_name: str = "protgpt2", temperature: float = 1.0, top_k: int = 950, top_p: float = 0.9, repetition_penalty: float = 1.2) -> List[str]:
     try:
-        from transformers import pipeline
+        from transformers import pipeline, AutoTokenizer, AutoModel, AutoModelForMaskedLM
+        import torch
+        import torch.nn.functional as F
     except ImportError:
-        print("Error: transformers package is required for ProtGPT2 generation. Please install with 'pip install transformers torch'", file=sys.stderr)
+        print("Error: transformers and torch packages are required for LLM generation. Please install with 'pip install transformers torch'", file=sys.stderr)
         sys.exit(1)
-    # Each token is ~4 amino acids, so for a peptide of length N, set max_length ≈ N/4 (rounded up)
-    max_length = max(5, (length + 3) // 4)  # ensure at least 1 token
-    protgpt2 = pipeline('text-generation', model="nferruz/ProtGPT2", framework="pt")
+    # Model configurations
+    if model_name.lower() == "protgpt2":
+        model_id = "nferruz/ProtGPT2"
+        # Each token is ~4 amino acids, so for a peptide of length N, set max_length ≈ N/4 (rounded up)
+        max_length = max(5, (length + 3) // 4)  # ensure at least 1 token
+        prompt = "<|endoftext|>"
+        use_pipeline = True
+    elif model_name.lower() == "esm2":
+        model_id = "facebook/esm2_t12_35M_UR50D"  # ESM-2 35M model (faster, smaller)
+        max_length = length + 10  # ESM works with direct amino acid sequences
+        use_pipeline = False
+    else:
+        print(f"Error: Unsupported model '{model_name}'. Supported models: protgpt2, esm2", file=sys.stderr)
+        sys.exit(1)
+
     peptides = []
     tries = 0
-    batch_size = min(50, count)  # Generate more peptides per batch for efficiency
-    while len(peptides) < count and tries < count * 10:
-        sequences = protgpt2("<|endoftext|>", max_length=max_length, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty, num_return_sequences=min(count - len(peptides), batch_size), eos_token_id=0)
-        if not sequences or not hasattr(sequences, '__iter__'):
-            tries += 1
-            continue
-        for seq in sequences:
-            if not isinstance(seq, dict):
-                continue
-            gen_text = seq.get('generated_text', '')
-            if not isinstance(gen_text, str):
+    batch_size = min(50, count) if model_name.lower() == "protgpt2" else min(10, count)
+
+    if use_pipeline:
+        # ProtGPT2 pipeline approach
+        llm_pipeline = pipeline('text-generation', model=model_id, framework="pt")
+        while len(peptides) < count and tries < count * 10:
+            sequences = llm_pipeline(prompt, max_length=max_length, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty, num_return_sequences=min(count - len(peptides), batch_size), eos_token_id=0)
+            if not sequences or not hasattr(sequences, '__iter__'):
+                tries += 1
                 continue
-            # Remove whitespace and newlines, keep only valid amino acids
-            pep = ''.join([c for c in gen_text if c in AMINO_ACIDS])
-            if len(pep) == length:
-                peptides.append(pep)
-        tries += 1
+            for seq in sequences:
+                if not isinstance(seq, dict):
+                    continue
+                gen_text = seq.get('generated_text', '')
+                if not isinstance(gen_text, str):
+                    continue
+                # Remove whitespace and newlines, keep only valid amino acids
+                pep = ''.join([c for c in gen_text if c in AMINO_ACIDS])
+                if len(pep) == length:
+                    peptides.append(pep)
+            tries += 1
+    else:
+        # ESM-2 approach using masked language modeling
+        print(f"Loading ESM-2 model (this may take a moment)...", file=sys.stderr)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForMaskedLM.from_pretrained(model_id)
+
+        while len(peptides) < count and tries < count * 10:  # Reduce tries for efficiency
+            try:
+                for _ in range(min(count - len(peptides), batch_size)):
+                    # Simple approach: create a sequence with random masks
+                    # Start with a random seed and mask 40% of positions
+                    seed_length = max(1, length // 3)
+                    seed_seq = "".join(random.choices(AMINO_ACIDS, k=seed_length))
+
+                    # Create masked sequence: seed + masks for remaining positions
+                    remaining_masks = length - seed_length
+                    masked_sequence = seed_seq + "<mask>" * remaining_masks
+
+                    # Tokenize
+                    inputs = tokenizer(masked_sequence, return_tensors="pt", max_length=512, truncation=True)
+
+                    with torch.no_grad():
+                        outputs = model(**inputs)
+                        predictions = outputs.logits
+
+                    # Apply temperature
+                    predictions = predictions / temperature
+
+                    # Find mask positions and generate amino acids
+                    mask_token_id = tokenizer.mask_token_id
+                    input_ids = inputs.input_ids[0]
+
+                    generated_sequence = []
+                    token_idx = 0
+
+                    for token_id in input_ids:
+                        if token_id == mask_token_id:
+                            # Get prediction for this mask
+                            logits = predictions[0, token_idx]
+
+                            # Simple top-k sampling
+                            if top_k > 0 and top_k < logits.size(-1):
+                                top_k_logits, top_k_indices = torch.topk(logits, top_k)
+                                # Sample from top-k
+                                probs = F.softmax(top_k_logits, dim=-1)
+                                sampled_idx = torch.multinomial(probs, 1).item()
+                                sampled_token_id = top_k_indices[sampled_idx].item()
+                            else:
+                                # Sample from full distribution
+                                probs = F.softmax(logits, dim=-1)
+                                sampled_token_id = torch.multinomial(probs, 1).item()
+
+                            sampled_token = tokenizer.decode([sampled_token_id])
+
+                            # Only add valid amino acids
+                            if sampled_token in AMINO_ACIDS:
+                                generated_sequence.append(sampled_token)
+                            else:
+                                # Fallback to random amino acid
+                                generated_sequence.append(random.choice(AMINO_ACIDS))
+                        else:
+                            # Keep original token if it's an amino acid
+                            original_token = tokenizer.decode([token_id])
+                            if original_token in AMINO_ACIDS:
+                                generated_sequence.append(original_token)
+
+                        token_idx += 1
+
+                    # Create final peptide
+                    final_peptide = "".join(generated_sequence)
+
+                    # Ensure exact length
+                    if len(final_peptide) >= length:
+                        final_peptide = final_peptide[:length]
+                    elif len(final_peptide) < length:
+                        # Pad with random amino acids if too short
+                        final_peptide += "".join(random.choices(AMINO_ACIDS, k=length - len(final_peptide)))
+
+                    if len(final_peptide) == length:
+                        peptides.append(final_peptide)
+
+            except Exception as e:
+                print(f"Warning: ESM-2 generation error: {e}", file=sys.stderr)
+                # Fallback to random generation for this attempt
+                fallback_peptide = "".join(random.choices(AMINO_ACIDS, k=length))
+                peptides.append(fallback_peptide)
+
+            tries += 1
     if len(peptides) < count:
         print(f"Warning: Only generated {len(peptides)} peptides of requested {count} with exact length {length}.", file=sys.stderr)
     return peptides[:count]
@@ -83,14 +189,15 @@ def main():
     parser = argparse.ArgumentParser(description="Generate control peptides for neoantigen analysis.")
     parser.add_argument('--length', type=int, required=True, help='Peptide length (e.g., 8, 9, 10)')
     parser.add_argument('--count', type=int, required=True, help='Number of peptides to generate')
-    parser.add_argument('--source', choices=['random', 'fasta', 'protgpt2'], required=True, help='Source of peptides: random, fasta, or protgpt2')
+    parser.add_argument('--source', choices=['random', 'fasta', 'llm'], required=True, help='Source of peptides: random, fasta, or llm')
+    parser.add_argument('--llm_model', choices=['protgpt2', 'esm2'], default='protgpt2', help='LLM model to use for generation (protgpt2 or esm2)')
     parser.add_argument('--fasta_file', type=Path, help='Path to reference FASTA file (required if source is fasta)')
     parser.add_argument('--output', type=Path, default=Path('control_peptides.fasta'), help='Output FASTA file')
-    parser.add_argument('--seed', type=int, help='Random seed for reproducibility (not used for protgpt2)')
-    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for ProtGPT2 generation (higher = more random)')
-    parser.add_argument('--top_k', type=int, default=950, help='Top-k sampling for ProtGPT2')
-    parser.add_argument('--top_p', type=float, default=0.9, help='Top-p (nucleus) sampling for ProtGPT2')
-    parser.add_argument('--repetition_penalty', type=float, default=1.2, help='Repetition penalty for ProtGPT2')
+    parser.add_argument('--seed', type=int, help='Random seed for reproducibility (not used for llm models)')
+    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for LLM generation (higher = more random)')
+    parser.add_argument('--top_k', type=int, default=950, help='Top-k sampling for LLM')
+    parser.add_argument('--top_p', type=float, default=0.9, help='Top-p (nucleus) sampling for LLM')
+    parser.add_argument('--repetition_penalty', type=float, default=1.2, help='Repetition penalty for LLM')
     args = parser.parse_args()

     # Input validation

@@ -100,7 +207,7 @@ def main():
     if args.count < 1 or args.count > 10000000:
         print('Error: Count must be between 1 and 10,000,000 peptides', file=sys.stderr)
         sys.exit(1)
-    if args.source == 'protgpt2':
+    if args.source == 'llm':
         if args.temperature <= 0:
             print('Error: Temperature must be positive', file=sys.stderr)
             sys.exit(1)

@@ -114,7 +221,7 @@ def main():
         print('Error: Repetition penalty must be positive', file=sys.stderr)
         sys.exit(1)

-    if args.seed is not None and args.source != 'protgpt2':
+    if args.seed is not None and args.source != 'llm':
         random.seed(args.seed)

     if args.source == 'random':

@@ -124,8 +231,8 @@ def main():
             print('Error: --fasta_file is required when source is fasta', file=sys.stderr)
             sys.exit(1)
         peptides = sample_peptides_from_fasta(args.fasta_file, args.length, args.count)
-    elif args.source == 'protgpt2':
-        peptides = generate_protgpt2_peptides(args.length, args.count, args.temperature, args.top_k, args.top_p, args.repetition_penalty)
+    elif args.source == 'llm':
+        peptides = generate_llm_peptides(args.length, args.count, args.llm_model, args.temperature, args.top_k, args.top_p, args.repetition_penalty)
     else:
         print(f"Unknown source: {args.source}", file=sys.stderr)
         sys.exit(1)
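The renamed function keeps the old call shape and adds a `model_name` argument ahead of the sampling parameters. Called directly (outside the CLI), usage would look roughly like this:

```python
from generate_control_peptides import generate_llm_peptides

# ProtGPT2: short peptides via the text-generation pipeline
nonamers = generate_llm_peptides(length=9, count=5, model_name="protgpt2", temperature=1.2)

# ESM-2: longer peptides via iterative mask filling
fifteen_mers = generate_llm_peptides(length=15, count=5, model_name="esm2", temperature=1.0)
```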
