@@ -39,19 +39,22 @@ def parse_fasta_sequences(fasta_path: Path) -> List[str]:
 
 def sample_peptides_from_fasta(fasta_path: Path, length: int, count: int) -> List[str]:
     sequences = parse_fasta_sequences(fasta_path)
-    all_subseqs = []
+    all_subseqs = set()  # Use set to automatically collapse duplicates
     for seq in sequences:
         if len(seq) >= length:
             for i in range(len(seq) - length + 1):
-                all_subseqs.append(seq[i:i + length])
+                all_subseqs.add(seq[i:i + length])
     if not all_subseqs:
         raise ValueError(f"No subsequences of length {length} found in {fasta_path}")
-    peptides = random.sample(all_subseqs, k=min(count, len(all_subseqs)))
+
+    # Convert set back to list for sampling
+    unique_subseqs = list(all_subseqs)
+    peptides = random.sample(unique_subseqs, k=min(count, len(unique_subseqs)))
     while len(peptides) < count:
-        peptides.append(random.choice(all_subseqs))
+        peptides.append(random.choice(unique_subseqs))
     return peptides[:count]
 
-def generate_llm_peptides(length: int, count: int, model_name: str = "protgpt2", temperature: float = 1.0, top_k: int = 950, top_p: float = 0.9, repetition_penalty: float = 1.2) -> List[str]:
+def generate_llm_peptides(length: int, count: int, model_name: str = "protgpt2", top_k: int = 950, top_p: float = 0.9, repetition_penalty: float = 1.2) -> List[str]:
     try:
         from transformers import pipeline, AutoTokenizer, AutoModel, AutoModelForMaskedLM
         import torch
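Aside from collapsing duplicate windows, the `list(all_subseqs)` conversion above matters for correctness: `random.sample` requires a sequence, and passing a set was deprecated in Python 3.9 and raises `TypeError` from 3.11 onward. A minimal standalone sketch of the dedup-then-sample pattern this hunk introduces, using a made-up toy sequence rather than anything from the repository:

```python
import random

def sample_unique_windows(seq: str, length: int, count: int) -> list:
    # Collapse all length-k windows into a set so repeated windows
    # cannot be drawn more often than unique ones.
    windows = {seq[i:i + length] for i in range(len(seq) - length + 1)}
    unique = list(windows)  # random.sample needs a sequence, not a set
    picks = random.sample(unique, k=min(count, len(unique)))
    # Top up with replacement when fewer unique windows exist than requested.
    while len(picks) < count:
        picks.append(random.choice(unique))
    return picks[:count]

print(sample_unique_windows("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", length=8, count=5))
```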
@@ -82,7 +85,7 @@ def generate_llm_peptides(length: int, count: int, model_name: str = "protgpt2",
         # ProtGPT2 pipeline approach
         llm_pipeline = pipeline('text-generation', model=model_id, framework="pt")
         while len(peptides) < count and tries < count * 10:
-            sequences = llm_pipeline(prompt, max_length=max_length, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty, num_return_sequences=min(count - len(peptides), batch_size), eos_token_id=0)
+            sequences = llm_pipeline(prompt, max_length=max_length, do_sample=True, top_k=top_k, top_p=top_p, temperature=1.0, repetition_penalty=repetition_penalty, num_return_sequences=min(count - len(peptides), batch_size), eos_token_id=0)
             if not sequences or not hasattr(sequences, '__iter__'):
                 tries += 1
                 continue
@@ -125,8 +128,8 @@ def generate_llm_peptides(length: int, count: int, model_name: str = "protgpt2",
             outputs = model(**inputs)
             predictions = outputs.logits
 
-            # Apply temperature
-            predictions = predictions / temperature
+            # Apply fixed temperature
+            predictions = predictions / 1.0
 
             # Find mask positions and generate amino acids
             mask_token_id = tokenizer.mask_token_id
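Both generation paths now pin temperature to 1.0, which is the identity: dividing logits by 1.0 leaves the softmax distribution unchanged, so sampling follows the model's raw probabilities filtered only by top_k/top_p and the repetition penalty. An illustrative sketch of what temperature scaling does in general (not code from this repo; the logit values are arbitrary):

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.1])
for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probs.tolist())
# temperature == 1.0 reproduces softmax(logits) exactly; values below 1
# sharpen the distribution, values above 1 flatten it.
```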
@@ -208,7 +211,7 @@ def main():
     parser.add_argument('--fasta_file', type=Path, help='Path to reference FASTA file (required if source is fasta)')
     parser.add_argument('--output', type=Path, default=Path('control_peptides.fasta'), help='Output FASTA file')
     parser.add_argument('--seed', type=int, help='Random seed for reproducibility (not used for llm models)')
-    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for LLM generation (higher = more random)')
+
     parser.add_argument('--top_k', type=int, default=950, help='Top-k sampling for LLM')
     parser.add_argument('--top_p', type=float, default=0.9, help='Top-p (nucleus) sampling for LLM')
     parser.add_argument('--repetition_penalty', type=float, default=1.2, help='Repetition penalty for LLM')
@@ -222,9 +225,7 @@ def main():
         print('Error: Count must be between 1 and 10,000,000 peptides', file=sys.stderr)
         sys.exit(1)
     if args.source == 'llm':
-        if args.temperature <= 0:
-            print('Error: Temperature must be positive', file=sys.stderr)
-            sys.exit(1)
+
         if args.top_k < 1:
             print('Error: Top-k must be at least 1', file=sys.stderr)
             sys.exit(1)
@@ -246,7 +247,7 @@ def main():
             sys.exit(1)
         peptides = sample_peptides_from_fasta(args.fasta_file, args.length, args.count)
     elif args.source == 'llm':
-        peptides = generate_llm_peptides(args.length, args.count, args.llm_model, args.temperature, args.top_k, args.top_p, args.repetition_penalty)
+        peptides = generate_llm_peptides(args.length, args.count, args.llm_model, args.top_k, args.top_p, args.repetition_penalty)
     else:
         print(f"Unknown source: {args.source}", file=sys.stderr)
         sys.exit(1)
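One thing worth watching in the call-site change above: `generate_llm_peptides` loses a parameter from the middle of its signature while the caller still passes everything positionally, so any future reordering would silently shift arguments into the wrong slots. A hypothetical keyword-based variant of the same call (not part of this commit) sidesteps that risk:

```python
# Hypothetical call-site sketch: keyword arguments guard against
# positional mix-ups after a mid-signature parameter is removed.
peptides = generate_llm_peptides(
    args.length,
    args.count,
    model_name=args.llm_model,
    top_k=args.top_k,
    top_p=args.top_p,
    repetition_penalty=args.repetition_penalty,
)
```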