-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathgogpt_api.py
More file actions
155 lines (125 loc) · 4.15 KB
/
gogpt_api.py
File metadata and controls
155 lines (125 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python
"""
GO-GPT API: Predict Gene Ontology (GO) terms from a protein sequence.
Usage (CLI):
python gogpt_api.py --sequence <protein_sequence>
python gogpt_api.py --sequence <protein_sequence> --organism "Organism name"
Usage (Import):
from gogpt_api import load_predictor, predict_and_format
predictor = load_predictor()
result = predict_and_format(predictor, sequence, organism)
Example:
python gogpt_api.py --sequence "MVLSPADKTN..."
python gogpt_api.py --sequence "MVLSPADKTN..." --organism "Mus musculus"
"""
import argparse
import sys
from pathlib import Path
from typing import Dict, List, Optional
# Add paths for imports
REPO_ROOT = Path(__file__).resolve().parent
sys.path.insert(0, str(REPO_ROOT / "gogpt" / "src"))
sys.path.insert(0, str(REPO_ROOT))
from gogpt import GOGPTPredictor
from bioreason2.dataset.cafa5.processor import _GO_INFO
def load_predictor(
    model_name: str = "wanglab/gogpt",
    cache_dir: Optional[str] = None
) -> GOGPTPredictor:
    """
    Build a GO-GPT predictor from pretrained weights (HuggingFace Hub).

    This is a thin wrapper around ``GOGPTPredictor.from_pretrained``.

    Args:
        model_name: HuggingFace model name, passed positionally.
        cache_dir: Cache directory for model weights, forwarded as the
            ``cache_dir`` keyword argument (``None`` by default).

    Returns:
        The object returned by ``GOGPTPredictor.from_pretrained``.
    """
    predictor = GOGPTPredictor.from_pretrained(model_name, cache_dir=cache_dir)
    return predictor
def predict_go_terms(
    predictor: GOGPTPredictor,
    sequence: str,
    organism: str = "Homo sapiens"
) -> Dict[str, List[str]]:
    """
    Run the predictor on a single protein sequence.

    Both values are handed to ``predictor.predict`` as keyword arguments.

    Args:
        predictor: Loaded GOGPTPredictor.
        sequence: Protein sequence (amino acids).
        organism: Organism name; defaults to "Homo sapiens".

    Returns:
        The result of ``predictor.predict``: a dict with keys "MF", "BP",
        "CC" containing lists of GO IDs.
    """
    go_terms = predictor.predict(sequence=sequence, organism=organism)
    return go_terms
def format_go_output(predictions: Dict[str, List[str]]) -> str:
    """
    Render GO predictions as one text line per aspect.

    Aspects are always emitted in the order MF, BP, CC. Each line has the
    form ``<aspect name> (<code>): <terms>``, where ``<terms>`` is a
    comma-separated list of ``GO_ID (name)`` entries, or ``None`` when the
    aspect is missing from ``predictions`` or has no terms.

    Args:
        predictions: Dict with "MF", "BP", "CC" keys mapping to GO IDs.

    Returns:
        The aspect lines joined with newlines.
    """
    aspect_table = (
        ("MF", "Molecular Function"),
        ("BP", "Biological Process"),
        ("CC", "Cellular Component"),
    )

    def _describe(go_id):
        # _GO_INFO values are unpacked as a pair; only the first element
        # (the term name) is shown. IDs not in _GO_INFO read "Unknown".
        term_name, _ = _GO_INFO.get(go_id, ("Unknown", ""))
        return f"{go_id} ({term_name})"

    lines = []
    for code, label in aspect_table:
        go_ids = predictions.get(code, [])
        if not go_ids:
            lines.append(f"{label} ({code}): None")
            continue
        described = [_describe(go_id) for go_id in go_ids]
        lines.append(f"{label} ({code}): {', '.join(described)}")
    return "\n".join(lines)
def predict_and_format(
    predictor: GOGPTPredictor,
    sequence: str,
    organism: str = "Homo sapiens"
) -> str:
    """
    Predict GO terms for a sequence and return them as readable text.

    Chains ``predict_go_terms`` and ``format_go_output``.

    Args:
        predictor: Loaded GOGPTPredictor.
        sequence: Protein sequence (amino acids).
        organism: Organism name; defaults to "Homo sapiens".

    Returns:
        Formatted string with GO terms and their names.
    """
    return format_go_output(predict_go_terms(predictor, sequence, organism))
def main():
    """
    Command-line entry point.

    Parses ``--sequence`` (required), ``--organism``, ``--model`` and
    ``--cache-dir``, loads the predictor, and prints the formatted GO term
    predictions to standard output.
    """
    cli = argparse.ArgumentParser(
        description="Predict GO terms from a protein sequence using GO-GPT."
    )
    cli.add_argument(
        "--sequence", required=True, help="Protein sequence (amino acids)"
    )
    cli.add_argument(
        "--organism",
        default="Homo sapiens",
        help="Organism name (default: 'Homo sapiens')",
    )
    cli.add_argument(
        "--model",
        default="wanglab/gogpt",
        help="HuggingFace model name (default: wanglab/gogpt)",
    )
    cli.add_argument(
        "--cache-dir",
        default=None,
        help="Cache directory for model weights (default: HuggingFace default)",
    )
    options = cli.parse_args()

    # Load the model first, then predict and print in one step.
    model = load_predictor(options.model, options.cache_dir)
    print(predict_and_format(model, options.sequence, options.organism))


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()