
Commit f31593c

Implement PI-requested changes: update Dockerfile, fix duplicate peptides, remove temperature option
- Change Dockerfile to end with CMD ["/bin/bash"] instead of ENTRYPOINT
- Fix sample_peptides_from_fasta to collapse duplicate peptides using set()
- Remove temperature parameter from LLM generation (fixed at 1.0)
- Remove --temperature command-line argument and all related validation
1 parent 5b591ed commit f31593c

File tree

- Dockerfile
- scripts/generation/generate_control_peptides.py

2 files changed: +16 -16 lines

Dockerfile

Lines changed: 2 additions & 3 deletions
@@ -27,6 +27,5 @@ RUN pip install --upgrade pip \
 # Copy rest of the project
 COPY . .
 
-# Default entrypoint shows help for the generator script
-ENTRYPOINT ["python", "generate_control_peptides.py"]
-CMD ["--help"]
+# Default command opens bash shell
+CMD ["/bin/bash"]

scripts/generation/generate_control_peptides.py

Lines changed: 14 additions & 13 deletions
@@ -39,19 +39,22 @@ def parse_fasta_sequences(fasta_path: Path) -> List[str]:
 
 def sample_peptides_from_fasta(fasta_path: Path, length: int, count: int) -> List[str]:
     sequences = parse_fasta_sequences(fasta_path)
-    all_subseqs = []
+    all_subseqs = set()  # Use set to automatically collapse duplicates
     for seq in sequences:
         if len(seq) >= length:
             for i in range(len(seq) - length + 1):
-                all_subseqs.append(seq[i:i+length])
+                all_subseqs.add(seq[i:i+length])
     if not all_subseqs:
         raise ValueError(f"No subsequences of length {length} found in {fasta_path}")
-    peptides = random.sample(all_subseqs, k=min(count, len(all_subseqs)))
+
+    # Convert set back to list for sampling
+    unique_subseqs = list(all_subseqs)
+    peptides = random.sample(unique_subseqs, k=min(count, len(unique_subseqs)))
     while len(peptides) < count:
-        peptides.append(random.choice(all_subseqs))
+        peptides.append(random.choice(unique_subseqs))
     return peptides[:count]
 
-def generate_llm_peptides(length: int, count: int, model_name: str = "protgpt2", temperature: float = 1.0, top_k: int = 950, top_p: float = 0.9, repetition_penalty: float = 1.2) -> List[str]:
+def generate_llm_peptides(length: int, count: int, model_name: str = "protgpt2", top_k: int = 950, top_p: float = 0.9, repetition_penalty: float = 1.2) -> List[str]:
     try:
         from transformers import pipeline, AutoTokenizer, AutoModel, AutoModelForMaskedLM
         import torch
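
Why the set matters: random.sample draws uniformly over list entries, so a subsequence occurring at many positions in the source proteome would be over-represented among the controls. A minimal sketch of the new behaviour, using a toy sequence that is not from the repository:

import random

# "AAAA" contains three identical 2-mers; a list keeps all three copies,
# tripling the weight of "AA" under sampling, while a set keeps one entry.
seq, length = "AAAA", 2
subseqs = {seq[i:i + length] for i in range(len(seq) - length + 1)}
print(subseqs)  # {'AA'} -- duplicates collapsed

# random.sample() rejects sets on Python 3.11+ (deprecated since 3.9),
# which is why the diff converts the set back to a list before sampling.
peptides = random.sample(list(subseqs), k=min(3, len(subseqs)))

One caveat preserved from the original logic: the while-loop top-up still uses random.choice with replacement whenever count exceeds the number of unique subsequences, so duplicates can reappear in the output in that edge case.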
@@ -82,7 +85,7 @@ def generate_llm_peptides(length: int, count: int, model_name: str = "protgpt2",
         # ProtGPT2 pipeline approach
         llm_pipeline = pipeline('text-generation', model=model_id, framework="pt")
         while len(peptides) < count and tries < count * 10:
-            sequences = llm_pipeline(prompt, max_length=max_length, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty, num_return_sequences=min(count - len(peptides), batch_size), eos_token_id=0)
+            sequences = llm_pipeline(prompt, max_length=max_length, do_sample=True, top_k=top_k, top_p=top_p, temperature=1.0, repetition_penalty=repetition_penalty, num_return_sequences=min(count - len(peptides), batch_size), eos_token_id=0)
             if not sequences or not hasattr(sequences, '__iter__'):
                 tries += 1
                 continue
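
Pinning temperature=1.0 matches the transformers default (generation uses temperature 1.0 when the argument is omitted), so peptide diversity from the pipeline is now governed solely by top_k, top_p, and repetition_penalty.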
@@ -125,8 +128,8 @@ def generate_llm_peptides(length: int, count: int, model_name: str = "protgpt2",
         outputs = model(**inputs)
         predictions = outputs.logits
 
-        # Apply temperature
-        predictions = predictions / temperature
+        # Apply fixed temperature
+        predictions = predictions / 1.0
 
         # Find mask positions and generate amino acids
         mask_token_id = tokenizer.mask_token_id
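
Dividing the logits by 1.0 is an identity operation, so the kept line now documents the fixed setting rather than changing anything. For reference, a self-contained sketch of what temperature scaling does to a softmax (values illustrative):

import math

def softmax_with_temperature(logits, t=1.0):
    # p_i is proportional to exp(logit_i / t): t < 1 sharpens, t > 1 flattens
    scaled = [x / t for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]

logits = [2.0, 1.0, 0.0]
print(softmax_with_temperature(logits, 1.0))  # the model's own distribution (the fixed setting)
print(softmax_with_temperature(logits, 0.5))  # sharper, what a lower temperature would give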
@@ -208,7 +211,7 @@ def main():
     parser.add_argument('--fasta_file', type=Path, help='Path to reference FASTA file (required if source is fasta)')
     parser.add_argument('--output', type=Path, default=Path('control_peptides.fasta'), help='Output FASTA file')
     parser.add_argument('--seed', type=int, help='Random seed for reproducibility (not used for llm models)')
-    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for LLM generation (higher = more random)')
+
     parser.add_argument('--top_k', type=int, default=950, help='Top-k sampling for LLM')
     parser.add_argument('--top_p', type=float, default=0.9, help='Top-p (nucleus) sampling for LLM')
     parser.add_argument('--repetition_penalty', type=float, default=1.2, help='Repetition penalty for LLM')
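
With the argument removed, argparse rejects stale invocations outright: any wrapper still passing --temperature now exits with "error: unrecognized arguments: --temperature" rather than having the flag silently ignored.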
@@ -222,9 +225,7 @@ def main():
         print('Error: Count must be between 1 and 10,000,000 peptides', file=sys.stderr)
         sys.exit(1)
     if args.source == 'llm':
-        if args.temperature <= 0:
-            print('Error: Temperature must be positive', file=sys.stderr)
-            sys.exit(1)
+
         if args.top_k < 1:
             print('Error: Top-k must be at least 1', file=sys.stderr)
             sys.exit(1)
@@ -246,7 +247,7 @@ def main():
             sys.exit(1)
         peptides = sample_peptides_from_fasta(args.fasta_file, args.length, args.count)
     elif args.source == 'llm':
-        peptides = generate_llm_peptides(args.length, args.count, args.llm_model, args.temperature, args.top_k, args.top_p, args.repetition_penalty)
+        peptides = generate_llm_peptides(args.length, args.count, args.llm_model, args.top_k, args.top_p, args.repetition_penalty)
     else:
         print(f"Unknown source: {args.source}", file=sys.stderr)
         sys.exit(1)
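
Worth noting for reviewers: the call site passes every argument positionally, so the signature and the call had to change in lockstep. Had only the signature changed, args.temperature would have bound to top_k and every later argument would have shifted by one parameter, silently corrupting the sampling settings.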
