Skip to content

Commit b841914

Browse files
committed
Enhance peptide generator with improved functionality
- Add requirements.txt with correct dependencies
- Add comprehensive input validation for all parameters
- Add ProtGPT2 parameters (temperature, top_k, top_p, repetition_penalty)
- Optimize ProtGPT2 batch generation for better performance
- Improve error handling and user feedback
1 parent 13261d4 commit b841914

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

generate_control_peptides.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def sample_peptides_from_fasta(fasta_path: Path, length: int, count: int) -> Lis
4242
peptides.append(random.choice(all_subseqs))
4343
return peptides[:count]
4444

45-
def generate_protgpt2_peptides(length: int, count: int) -> List[str]:
45+
def generate_protgpt2_peptides(length: int, count: int, temperature: float = 1.0, top_k: int = 950, top_p: float = 0.9, repetition_penalty: float = 1.2) -> List[str]:
4646
try:
4747
from transformers import pipeline
4848
except ImportError:
@@ -53,8 +53,9 @@ def generate_protgpt2_peptides(length: int, count: int) -> List[str]:
5353
protgpt2 = pipeline('text-generation', model="nferruz/ProtGPT2", framework="pt")
5454
peptides = []
5555
tries = 0
56+
batch_size = min(50, count) # Generate more peptides per batch for efficiency
5657
while len(peptides) < count and tries < count * 10:
57-
sequences = protgpt2("<|endoftext|>", max_length=max_length, do_sample=True, top_k=950, repetition_penalty=1.2, num_return_sequences=min(count - len(peptides), 10), eos_token_id=0)
58+
sequences = protgpt2("<|endoftext|>", max_length=max_length, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty, num_return_sequences=min(count - len(peptides), batch_size), eos_token_id=0)
5859
if not sequences or not hasattr(sequences, '__iter__'):
5960
tries += 1
6061
continue
@@ -86,8 +87,33 @@ def main():
8687
parser.add_argument('--fasta_file', type=Path, help='Path to reference FASTA file (required if source is fasta)')
8788
parser.add_argument('--output', type=Path, default=Path('control_peptides.fasta'), help='Output FASTA file')
8889
parser.add_argument('--seed', type=int, help='Random seed for reproducibility (not used for protgpt2)')
90+
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for ProtGPT2 generation (higher = more random)')
91+
parser.add_argument('--top_k', type=int, default=950, help='Top-k sampling for ProtGPT2')
92+
parser.add_argument('--top_p', type=float, default=0.9, help='Top-p (nucleus) sampling for ProtGPT2')
93+
parser.add_argument('--repetition_penalty', type=float, default=1.2, help='Repetition penalty for ProtGPT2')
8994
args = parser.parse_args()
9095

96+
# Input validation
97+
if args.length < 1 or args.length > 50:
98+
print('Error: Peptide length must be between 1 and 50 amino acids', file=sys.stderr)
99+
sys.exit(1)
100+
if args.count < 1 or args.count > 10000000:
101+
print('Error: Count must be between 1 and 10,000,000 peptides', file=sys.stderr)
102+
sys.exit(1)
103+
if args.source == 'protgpt2':
104+
if args.temperature <= 0:
105+
print('Error: Temperature must be positive', file=sys.stderr)
106+
sys.exit(1)
107+
if args.top_k < 1:
108+
print('Error: Top-k must be at least 1', file=sys.stderr)
109+
sys.exit(1)
110+
if args.top_p <= 0 or args.top_p > 1:
111+
print('Error: Top-p must be between 0 and 1', file=sys.stderr)
112+
sys.exit(1)
113+
if args.repetition_penalty <= 0:
114+
print('Error: Repetition penalty must be positive', file=sys.stderr)
115+
sys.exit(1)
116+
91117
if args.seed is not None and args.source != 'protgpt2':
92118
random.seed(args.seed)
93119

@@ -99,7 +125,7 @@ def main():
99125
sys.exit(1)
100126
peptides = sample_peptides_from_fasta(args.fasta_file, args.length, args.count)
101127
elif args.source == 'protgpt2':
102-
peptides = generate_protgpt2_peptides(args.length, args.count)
128+
peptides = generate_protgpt2_peptides(args.length, args.count, args.temperature, args.top_k, args.top_p, args.repetition_penalty)
103129
else:
104130
print(f"Unknown source: {args.source}", file=sys.stderr)
105131
sys.exit(1)

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
PySimpleGUI>=4.60.0
2+
transformers>=4.20.0
3+
torch>=1.12.0

0 commit comments

Comments
 (0)