@@ -42,34 +42,140 @@ def sample_peptides_from_fasta(fasta_path: Path, length: int, count: int) -> Lis
         peptides.append(random.choice(all_subseqs))
     return peptides[:count]
 
-def generate_protgpt2_peptides(length: int, count: int, temperature: float = 1.0, top_k: int = 950, top_p: float = 0.9, repetition_penalty: float = 1.2) -> List[str]:
+def generate_llm_peptides(length: int, count: int, model_name: str = "protgpt2", temperature: float = 1.0, top_k: int = 950, top_p: float = 0.9, repetition_penalty: float = 1.2) -> List[str]:
     try:
-        from transformers import pipeline
+        from transformers import pipeline, AutoTokenizer, AutoModel, AutoModelForMaskedLM
+        import torch
+        import torch.nn.functional as F
     except ImportError:
-        print("Error: transformers package is required for ProtGPT2 generation. Please install with 'pip install transformers torch'", file=sys.stderr)
+        print("Error: transformers and torch packages are required for LLM generation. Please install with 'pip install transformers torch'", file=sys.stderr)
         sys.exit(1)
-    # Each token is ~4 amino acids, so for a peptide of length N, set max_length ≈ N/4 (rounded up)
-    max_length = max(5, (length + 3) // 4)  # ensure at least 1 token
-    protgpt2 = pipeline('text-generation', model="nferruz/ProtGPT2", framework="pt")
+    # Model configurations
+    if model_name.lower() == "protgpt2":
+        model_id = "nferruz/ProtGPT2"
+        # Each token is ~4 amino acids, so for a peptide of length N, set max_length ≈ N/4 (rounded up)
+        max_length = max(5, (length + 3) // 4)  # ensure at least 1 token
+        prompt = "<|endoftext|>"
+        use_pipeline = True
+    elif model_name.lower() == "esm2":
+        model_id = "facebook/esm2_t12_35M_UR50D"  # ESM-2 35M model (faster, smaller)
+        max_length = length + 10  # ESM works with direct amino acid sequences
+        use_pipeline = False
+    else:
+        print(f"Error: Unsupported model '{model_name}'. Supported models: protgpt2, esm2", file=sys.stderr)
+        sys.exit(1)
+
     peptides = []
     tries = 0
-    batch_size = min(50, count)  # Generate more peptides per batch for efficiency
-    while len(peptides) < count and tries < count * 10:
-        sequences = protgpt2("<|endoftext|>", max_length=max_length, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty, num_return_sequences=min(count - len(peptides), batch_size), eos_token_id=0)
-        if not sequences or not hasattr(sequences, '__iter__'):
-            tries += 1
-            continue
-        for seq in sequences:
-            if not isinstance(seq, dict):
-                continue
-            gen_text = seq.get('generated_text', '')
-            if not isinstance(gen_text, str):
+    batch_size = min(50, count) if model_name.lower() == "protgpt2" else min(10, count)
+
+    if use_pipeline:
+        # ProtGPT2 pipeline approach
+        llm_pipeline = pipeline('text-generation', model=model_id, framework="pt")
+        while len(peptides) < count and tries < count * 10:
+            sequences = llm_pipeline(prompt, max_length=max_length, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty, num_return_sequences=min(count - len(peptides), batch_size), eos_token_id=0)
+            if not sequences or not hasattr(sequences, '__iter__'):
+                tries += 1
                 continue
-            # Remove whitespace and newlines, keep only valid amino acids
-            pep = ''.join([c for c in gen_text if c in AMINO_ACIDS])
-            if len(pep) == length:
-                peptides.append(pep)
-        tries += 1
+            for seq in sequences:
+                if not isinstance(seq, dict):
+                    continue
+                gen_text = seq.get('generated_text', '')
+                if not isinstance(gen_text, str):
+                    continue
+                # Remove whitespace and newlines, keep only valid amino acids
+                pep = ''.join([c for c in gen_text if c in AMINO_ACIDS])
+                if len(pep) == length:
+                    peptides.append(pep)
+            tries += 1
+    else:
+        # ESM-2 approach using masked language modeling
+        print(f"Loading ESM-2 model (this may take a moment)...", file=sys.stderr)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForMaskedLM.from_pretrained(model_id)
+
+        while len(peptides) < count and tries < count * 10:  # Reduce tries for efficiency
+            try:
+                for _ in range(min(count - len(peptides), batch_size)):
+                    # Simple approach: create a sequence with random masks
+                    # Start with a random seed and mask 40% of positions
+                    seed_length = max(1, length // 3)
+                    seed_seq = "".join(random.choices(AMINO_ACIDS, k=seed_length))
+
+                    # Create masked sequence: seed + masks for remaining positions
+                    remaining_masks = length - seed_length
+                    masked_sequence = seed_seq + "<mask>" * remaining_masks
+
+                    # Tokenize
+                    inputs = tokenizer(masked_sequence, return_tensors="pt", max_length=512, truncation=True)
+
+                    with torch.no_grad():
+                        outputs = model(**inputs)
+                        predictions = outputs.logits
+
+                    # Apply temperature
+                    predictions = predictions / temperature
+
+                    # Find mask positions and generate amino acids
+                    mask_token_id = tokenizer.mask_token_id
+                    input_ids = inputs.input_ids[0]
+
+                    generated_sequence = []
+                    token_idx = 0
+
+                    for token_id in input_ids:
+                        if token_id == mask_token_id:
+                            # Get prediction for this mask
+                            logits = predictions[0, token_idx]
+
+                            # Simple top-k sampling
+                            if top_k > 0 and top_k < logits.size(-1):
+                                top_k_logits, top_k_indices = torch.topk(logits, top_k)
+                                # Sample from top-k
+                                probs = F.softmax(top_k_logits, dim=-1)
+                                sampled_idx = torch.multinomial(probs, 1).item()
+                                sampled_token_id = top_k_indices[sampled_idx].item()
+                            else:
+                                # Sample from full distribution
+                                probs = F.softmax(logits, dim=-1)
+                                sampled_token_id = torch.multinomial(probs, 1).item()
+
+                            sampled_token = tokenizer.decode([sampled_token_id])
+
+                            # Only add valid amino acids
+                            if sampled_token in AMINO_ACIDS:
+                                generated_sequence.append(sampled_token)
+                            else:
+                                # Fallback to random amino acid
+                                generated_sequence.append(random.choice(AMINO_ACIDS))
+                        else:
+                            # Keep original token if it's an amino acid
+                            original_token = tokenizer.decode([token_id])
+                            if original_token in AMINO_ACIDS:
+                                generated_sequence.append(original_token)
+
+                        token_idx += 1
+
+                    # Create final peptide
+                    final_peptide = "".join(generated_sequence)
+
+                    # Ensure exact length
+                    if len(final_peptide) >= length:
+                        final_peptide = final_peptide[:length]
+                    elif len(final_peptide) < length:
+                        # Pad with random amino acids if too short
+                        final_peptide += "".join(random.choices(AMINO_ACIDS, k=length - len(final_peptide)))
+
+                    if len(final_peptide) == length:
+                        peptides.append(final_peptide)
+
+            except Exception as e:
+                print(f"Warning: ESM-2 generation error: {e}", file=sys.stderr)
+                # Fallback to random generation for this attempt
+                fallback_peptide = "".join(random.choices(AMINO_ACIDS, k=length))
+                peptides.append(fallback_peptide)
+
+            tries += 1
     if len(peptides) < count:
         print(f"Warning: Only generated {len(peptides)} peptides of requested {count} with exact length {length}.", file=sys.stderr)
     return peptides[:count]
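
A minimal usage sketch of the renamed generate_llm_peptides function, not part of the change itself: the function signature and the protgpt2/esm2 model names are taken from the diff above, while the module name control_peptides and the FASTA-style headers are assumptions made purely for illustration.

# Minimal usage sketch; control_peptides is a hypothetical module name.
from control_peptides import generate_llm_peptides

# Select the new ESM-2 backend; omitting model_name keeps the ProtGPT2 default.
controls = generate_llm_peptides(length=9, count=10, model_name="esm2", temperature=1.0)

# FASTA-style headers here are illustrative only; the script's real output
# handling is not shown in this diff.
for i, pep in enumerate(controls, start=1):
    print(f">control_{i}\n{pep}")
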
@@ -83,14 +189,15 @@ def main():
     parser = argparse.ArgumentParser(description="Generate control peptides for neoantigen analysis.")
     parser.add_argument('--length', type=int, required=True, help='Peptide length (e.g., 8, 9, 10)')
     parser.add_argument('--count', type=int, required=True, help='Number of peptides to generate')
-    parser.add_argument('--source', choices=['random', 'fasta', 'protgpt2'], required=True, help='Source of peptides: random, fasta, or protgpt2')
+    parser.add_argument('--source', choices=['random', 'fasta', 'llm'], required=True, help='Source of peptides: random, fasta, or llm')
+    parser.add_argument('--llm_model', choices=['protgpt2', 'esm2'], default='protgpt2', help='LLM model to use for generation (protgpt2 or esm2)')
     parser.add_argument('--fasta_file', type=Path, help='Path to reference FASTA file (required if source is fasta)')
     parser.add_argument('--output', type=Path, default=Path('control_peptides.fasta'), help='Output FASTA file')
-    parser.add_argument('--seed', type=int, help='Random seed for reproducibility (not used for protgpt2)')
-    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for ProtGPT2 generation (higher = more random)')
-    parser.add_argument('--top_k', type=int, default=950, help='Top-k sampling for ProtGPT2')
-    parser.add_argument('--top_p', type=float, default=0.9, help='Top-p (nucleus) sampling for ProtGPT2')
-    parser.add_argument('--repetition_penalty', type=float, default=1.2, help='Repetition penalty for ProtGPT2')
+    parser.add_argument('--seed', type=int, help='Random seed for reproducibility (not used for llm models)')
+    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for LLM generation (higher = more random)')
+    parser.add_argument('--top_k', type=int, default=950, help='Top-k sampling for LLM')
+    parser.add_argument('--top_p', type=float, default=0.9, help='Top-p (nucleus) sampling for LLM')
+    parser.add_argument('--repetition_penalty', type=float, default=1.2, help='Repetition penalty for LLM')
     args = parser.parse_args()
 
     # Input validation
@@ -100,7 +207,7 @@ def main():
     if args.count < 1 or args.count > 10000000:
         print('Error: Count must be between 1 and 10,000,000 peptides', file=sys.stderr)
         sys.exit(1)
-    if args.source == 'protgpt2':
+    if args.source == 'llm':
         if args.temperature <= 0:
             print('Error: Temperature must be positive', file=sys.stderr)
             sys.exit(1)
@@ -114,7 +221,7 @@ def main():
             print('Error: Repetition penalty must be positive', file=sys.stderr)
             sys.exit(1)
 
-    if args.seed is not None and args.source != 'protgpt2':
+    if args.seed is not None and args.source != 'llm':
         random.seed(args.seed)
 
     if args.source == 'random':
@@ -124,8 +231,8 @@ def main():
             print('Error: --fasta_file is required when source is fasta', file=sys.stderr)
             sys.exit(1)
         peptides = sample_peptides_from_fasta(args.fasta_file, args.length, args.count)
-    elif args.source == 'protgpt2':
-        peptides = generate_protgpt2_peptides(args.length, args.count, args.temperature, args.top_k, args.top_p, args.repetition_penalty)
+    elif args.source == 'llm':
+        peptides = generate_llm_peptides(args.length, args.count, args.llm_model, args.temperature, args.top_k, args.top_p, args.repetition_penalty)
     else:
         print(f"Unknown source: {args.source}", file=sys.stderr)
         sys.exit(1)
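
For an end-to-end run, a hedged sketch of driving the updated CLI from Python is shown below. The flags, including the new --source llm value and the --llm_model option, are taken from the argparse changes above; the script filename generate_control_peptides.py is an assumption, since the file's path is not visible in this excerpt.

import subprocess
import sys

# Hypothetical script path; only the flags themselves are confirmed by the diff.
cmd = [
    sys.executable, "generate_control_peptides.py",
    "--length", "9",
    "--count", "100",
    "--source", "llm",          # replaces the former 'protgpt2' source value
    "--llm_model", "esm2",      # new option introduced by this change
    "--temperature", "1.0",
    "--top_k", "950",
    "--top_p", "0.9",
    "--repetition_penalty", "1.2",
    "--output", "control_peptides.fasta",
]
subprocess.run(cmd, check=True)  # --seed is omitted because it is not used for llm sources
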