@@ -42,7 +42,7 @@ def sample_peptides_from_fasta(fasta_path: Path, length: int, count: int) -> Lis
42
42
peptides .append (random .choice (all_subseqs ))
43
43
return peptides [:count ]
44
44
45
- def generate_protgpt2_peptides (length : int , count : int ) -> List [str ]:
45
+ def generate_protgpt2_peptides (length : int , count : int , temperature : float = 1.0 , top_k : int = 950 , top_p : float = 0.9 , repetition_penalty : float = 1.2 ) -> List [str ]:
46
46
try :
47
47
from transformers import pipeline
48
48
except ImportError :
@@ -53,8 +53,9 @@ def generate_protgpt2_peptides(length: int, count: int) -> List[str]:
53
53
protgpt2 = pipeline ('text-generation' , model = "nferruz/ProtGPT2" , framework = "pt" )
54
54
peptides = []
55
55
tries = 0
56
+ batch_size = min (50 , count ) # Generate more peptides per batch for efficiency
56
57
while len (peptides ) < count and tries < count * 10 :
57
- sequences = protgpt2 ("<|endoftext|>" , max_length = max_length , do_sample = True , top_k = 950 , repetition_penalty = 1.2 , num_return_sequences = min (count - len (peptides ), 10 ), eos_token_id = 0 )
58
+ sequences = protgpt2 ("<|endoftext|>" , max_length = max_length , do_sample = True , top_k = top_k , top_p = top_p , temperature = temperature , repetition_penalty = repetition_penalty , num_return_sequences = min (count - len (peptides ), batch_size ), eos_token_id = 0 )
58
59
if not sequences or not hasattr (sequences , '__iter__' ):
59
60
tries += 1
60
61
continue
@@ -86,8 +87,33 @@ def main():
86
87
parser .add_argument ('--fasta_file' , type = Path , help = 'Path to reference FASTA file (required if source is fasta)' )
87
88
parser .add_argument ('--output' , type = Path , default = Path ('control_peptides.fasta' ), help = 'Output FASTA file' )
88
89
parser .add_argument ('--seed' , type = int , help = 'Random seed for reproducibility (not used for protgpt2)' )
90
+ parser .add_argument ('--temperature' , type = float , default = 1.0 , help = 'Temperature for ProtGPT2 generation (higher = more random)' )
91
+ parser .add_argument ('--top_k' , type = int , default = 950 , help = 'Top-k sampling for ProtGPT2' )
92
+ parser .add_argument ('--top_p' , type = float , default = 0.9 , help = 'Top-p (nucleus) sampling for ProtGPT2' )
93
+ parser .add_argument ('--repetition_penalty' , type = float , default = 1.2 , help = 'Repetition penalty for ProtGPT2' )
89
94
args = parser .parse_args ()
90
95
96
+ # Input validation
97
+ if args .length < 1 or args .length > 50 :
98
+ print ('Error: Peptide length must be between 1 and 50 amino acids' , file = sys .stderr )
99
+ sys .exit (1 )
100
+ if args .count < 1 or args .count > 10000000 :
101
+ print ('Error: Count must be between 1 and 10,000,000 peptides' , file = sys .stderr )
102
+ sys .exit (1 )
103
+ if args .source == 'protgpt2' :
104
+ if args .temperature <= 0 :
105
+ print ('Error: Temperature must be positive' , file = sys .stderr )
106
+ sys .exit (1 )
107
+ if args .top_k < 1 :
108
+ print ('Error: Top-k must be at least 1' , file = sys .stderr )
109
+ sys .exit (1 )
110
+ if args .top_p <= 0 or args .top_p > 1 :
111
+ print ('Error: Top-p must be between 0 and 1' , file = sys .stderr )
112
+ sys .exit (1 )
113
+ if args .repetition_penalty <= 0 :
114
+ print ('Error: Repetition penalty must be positive' , file = sys .stderr )
115
+ sys .exit (1 )
116
+
91
117
if args .seed is not None and args .source != 'protgpt2' :
92
118
random .seed (args .seed )
93
119
@@ -99,7 +125,7 @@ def main():
99
125
sys .exit (1 )
100
126
peptides = sample_peptides_from_fasta (args .fasta_file , args .length , args .count )
101
127
elif args .source == 'protgpt2' :
102
- peptides = generate_protgpt2_peptides (args .length , args .count )
128
+ peptides = generate_protgpt2_peptides (args .length , args .count , args . temperature , args . top_k , args . top_p , args . repetition_penalty )
103
129
else :
104
130
print (f"Unknown source: { args .source } " , file = sys .stderr )
105
131
sys .exit (1 )
0 commit comments