Skip to content

Commit 1e4f3f1

Browse files
author
ntnn19
committed
Merge branch 'fix-pull-request-sokrypton#633' of https://github.com/ntnn19/ColabFold into fix-pull-request-sokrypton#633
2 parents c0d371d + d7726ea commit 1e4f3f1

File tree

3 files changed

+132
-93
lines changed

3 files changed

+132
-93
lines changed

colabfold/mmseqs/search.py

Lines changed: 126 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""
22
Functionality for running mmseqs locally. Takes in a fasta file, outputs final.a3m
3-
4-
Note: Currently needs mmseqs compiled from source
53
"""
64

75
import logging
@@ -18,8 +16,28 @@
1816

1917
logger = logging.getLogger(__name__)
2018

19+
MODULE_OUTPUT_POS = {
20+
"align": 4,
21+
"convertalis": 4,
22+
"expandaln": 5,
23+
"filterresult": 4,
24+
"lndb": 2,
25+
"mergedbs": 2,
26+
"mvdb": 2,
27+
"pairaln": 4,
28+
"result2msa": 4,
29+
"search": 3,
30+
}
2131

2232
def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]):
33+
module = params[0]
34+
if module in MODULE_OUTPUT_POS:
35+
output_pos = MODULE_OUTPUT_POS[module]
36+
output_path = Path(params[output_pos]).with_suffix('.dbtype')
37+
if output_path.exists():
38+
logger.info(f"Skipping {module} because {output_path} already exists")
39+
return
40+
2341
params_log = " ".join(str(i) for i in params)
2442
logger.info(f"Running {mmseqs} {params_log}")
2543
# hide MMseqs2 verbose paramters list that clogs up the log
@@ -46,6 +64,7 @@ def mmseqs_search_monomer(
4664
s: float = 8,
4765
db_load_mode: int = 2,
4866
threads: int = 32,
67+
unpack: bool = True,
4968
):
5069
"""Run mmseqs with a local colabfold database set
5170
@@ -86,8 +105,6 @@ def mmseqs_search_monomer(
86105
dbSuffix2 = ".idx"
87106
dbSuffix3 = ".idx"
88107

89-
# fmt: off
90-
# @formatter:off
91108
search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", "10000"]
92109
search_param += ["--prefilter-mode", str(prefilter_mode)]
93110
if s is not None:
@@ -98,24 +115,27 @@ def mmseqs_search_monomer(
98115
filter_param = ["--filter-msa", str(filter), "--filter-min-enable", "1000", "--diff", str(diff), "--qid", "0.0,0.2,0.4,0.6,0.8,1.0", "--qsc", "0", "--max-seq-id", "0.95",]
99116
expand_param = ["--expansion-mode", "0", "-e", str(expand_eval), "--expand-filter-clusters", str(filter), "--max-seq-id", "0.95",]
100117

101-
run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(uniref_db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads)] + search_param)
102-
run_mmseqs(mmseqs, ["mvdb", base.joinpath("tmp/latest/profile_1"), base.joinpath("prof_res")])
103-
run_mmseqs(mmseqs, ["lndb", base.joinpath("qdb_h"), base.joinpath("prof_res_h")])
104-
run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{uniref_db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + expand_param)
105-
run_mmseqs(mmseqs, ["align", base.joinpath("prof_res"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", str(align_eval), "--max-accept", str(max_accept), "--threads", str(threads), "--alt-ali", "10", "-a"])
106-
run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"),
107-
base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_filter"), "--db-load-mode",
108-
str(db_load_mode), "--qid", "0", "--qsc", str(qsc), "--diff", "0", "--threads",
109-
str(threads), "--max-seq-id", "1.0", "--filter-min-enable", "100"])
110-
run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"),
111-
base.joinpath("res_exp_realign_filter"), base.joinpath("uniref.a3m"), "--msa-format-mode",
112-
"6", "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param)
113-
subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp_realign")])
114-
subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp")])
115-
subprocess.run([mmseqs] + ["rmdb", base.joinpath("res")])
116-
subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp_realign_filter")])
118+
if not base.joinpath("uniref.a3m").with_suffix('.a3m.dbtype').exists():
119+
run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(uniref_db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads)] + search_param)
120+
run_mmseqs(mmseqs, ["mvdb", base.joinpath("tmp/latest/profile_1"), base.joinpath("prof_res")])
121+
run_mmseqs(mmseqs, ["lndb", base.joinpath("qdb_h"), base.joinpath("prof_res_h")])
122+
run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{uniref_db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + expand_param)
123+
run_mmseqs(mmseqs, ["align", base.joinpath("prof_res"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", str(align_eval), "--max-accept", str(max_accept), "--threads", str(threads), "--alt-ali", "10", "-a"])
124+
run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"),
125+
base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_filter"), "--db-load-mode",
126+
str(db_load_mode), "--qid", "0", "--qsc", str(qsc), "--diff", "0", "--threads",
127+
str(threads), "--max-seq-id", "1.0", "--filter-min-enable", "100"])
128+
run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"),
129+
base.joinpath("res_exp_realign_filter"), base.joinpath("uniref.a3m"), "--msa-format-mode",
130+
"6", "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param)
131+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_filter")])
132+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign")])
133+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp")])
134+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")])
135+
else:
136+
logger.info(f"Skipping {uniref_db} search because uniref.a3m already exists")
117137

118-
if use_env:
138+
if use_env and not base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m").with_suffix('.a3m.dbtype').exists():
119139
run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), dbbase.joinpath(metagenomic_db), base.joinpath("res_env"),
120140
base.joinpath("tmp3"), "--threads", str(threads)] + search_param)
121141
run_mmseqs(mmseqs, ["expandaln", base.joinpath("prof_res"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), base.joinpath("res_env"),
@@ -133,45 +153,49 @@ def mmseqs_search_monomer(
133153
base.joinpath("res_env_exp_realign_filter"),
134154
base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m"), "--msa-format-mode", "6",
135155
"--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param)
136-
137156
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp_realign_filter")])
138157
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp_realign")])
139158
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp")])
140159
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env")])
160+
elif use_env:
161+
logger.info(f"Skipping {metagenomic_db} search because bfd.mgnify30.metaeuk30.smag30.a3m already exists")
141162

142-
run_mmseqs(mmseqs, ["mergedbs", base.joinpath("qdb"), base.joinpath("final.a3m"), base.joinpath("uniref.a3m"), base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
143-
run_mmseqs(mmseqs, ["rmdb", base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
144-
else:
145-
run_mmseqs(mmseqs, ["mvdb", base.joinpath("uniref.a3m"), base.joinpath("final.a3m")])
146-
147-
if use_templates:
163+
if use_templates and not base.joinpath("res_pdb.m8").with_suffix('.m8.dbtype').exists():
148164
run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), dbbase.joinpath(template_db), base.joinpath("res_pdb"),
149165
base.joinpath("tmp2"), "--db-load-mode", str(db_load_mode), "--threads", str(threads), "-s", "7.5", "-a", "-e", "0.1", "--prefilter-mode", str(prefilter_mode)])
150166
run_mmseqs(mmseqs, ["convertalis", base.joinpath("prof_res"), dbbase.joinpath(f"{template_db}{dbSuffix3}"), base.joinpath("res_pdb"),
151-
base.joinpath(f"{template_db}"), "--format-output",
167+
base.joinpath("res_pdb.m8"), "--format-output",
152168
"query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar",
153169
"--db-output", "1",
154170
"--db-load-mode", str(db_load_mode), "--threads", str(threads)])
155-
run_mmseqs(mmseqs, ["unpackdb", base.joinpath(f"{template_db}"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".m8"])
156171
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_pdb")])
157-
run_mmseqs(mmseqs, ["rmdb", base.joinpath(f"{template_db}")])
172+
elif use_templates:
173+
logger.info(f"Skipping {template_db} search because res_pdb.m8 already exists")
158174

159-
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("final.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".a3m"])
160-
run_mmseqs(mmseqs, ["rmdb", base.joinpath("final.a3m")])
161-
run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")])
162-
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")])
163-
# @formatter:on
164-
# fmt: on
175+
if use_env:
176+
run_mmseqs(mmseqs, ["mergedbs", base.joinpath("qdb"), base.joinpath("final.a3m"), base.joinpath("uniref.a3m"), base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
177+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
178+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")])
179+
else:
180+
run_mmseqs(mmseqs, ["mvdb", base.joinpath("uniref.a3m"), base.joinpath("final.a3m")])
181+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")])
165182

166-
for file in base.glob("prof_res*"):
167-
file.unlink()
183+
if unpack:
184+
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("final.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".a3m"])
185+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("final.a3m")])
186+
187+
if use_templates:
188+
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("res_pdb.m8"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".m8"])
189+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_pdb.m8")])
190+
191+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("prof_res")])
192+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("prof_res_h")])
168193
shutil.rmtree(base.joinpath("tmp"))
169194
if use_templates:
170195
shutil.rmtree(base.joinpath("tmp2"))
171196
if use_env:
172197
shutil.rmtree(base.joinpath("tmp3"))
173198

174-
175199
def mmseqs_search_pair(
176200
dbbase: Path,
177201
base: Path,
@@ -184,6 +208,7 @@ def mmseqs_search_pair(
184208
threads: int = 64,
185209
db_load_mode: int = 2,
186210
pairing_strategy: int = 0,
211+
unpack: bool = True,
187212
):
188213
if not dbbase.joinpath(f"{uniref_db}.dbtype").is_file():
189214
raise FileNotFoundError(f"Database {uniref_db} does not exist")
@@ -225,14 +250,15 @@ def mmseqs_search_pair(
225250
run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_exp_realign_pair"), base.joinpath("res_exp_realign_pair_bt"), "--db-load-mode", str(db_load_mode), "-e", "inf", "-a", "--threads", str(threads), ],)
226251
run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{db}"), base.joinpath("res_exp_realign_pair_bt"), base.joinpath("res_final"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "1", "--threads", str(threads),],)
227252
run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_final"), base.joinpath("pair.a3m"), "--db-load-mode", str(db_load_mode), "--msa-format-mode", "5", "--threads", str(threads),],)
228-
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", output,],)
253+
if unpack:
254+
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", output,],)
255+
run_mmseqs(mmseqs, ["rmdb", base.joinpath("pair.a3m")])
229256
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")])
230257
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp")])
231258
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign")])
232259
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_pair")])
233260
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_pair_bt")])
234261
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_final")])
235-
run_mmseqs(mmseqs, ["rmdb", base.joinpath("pair.a3m")])
236262
shutil.rmtree(base.joinpath("tmp"))
237263
# @formatter:on
238264
# fmt: on
@@ -340,6 +366,9 @@ def main():
340366
default=0,
341367
help="Database preload mode 0: auto, 1: fread, 2: mmap, 3: mmap+touch",
342368
)
369+
parser.add_argument(
370+
"--unpack", type=int, default=1, choices=[0, 1], help="Unpack results to loose files or keep MMseqs2 databases."
371+
)
343372
parser.add_argument(
344373
"--threads", type=int, default=64, help="Number of threads to use."
345374
)
@@ -416,6 +445,7 @@ def main():
416445
s=args.s,
417446
db_load_mode=args.db_load_mode,
418447
threads=args.threads,
448+
unpack=args.unpack,
419449
)
420450
if is_complex is True:
421451
mmseqs_search_pair(
@@ -429,6 +459,7 @@ def main():
429459
threads=args.threads,
430460
pairing_strategy=args.pairing_strategy,
431461
pair_env=False,
462+
unpack=args.unpack,
432463
)
433464
if args.use_env_pairing:
434465
mmseqs_search_pair(
@@ -443,63 +474,66 @@ def main():
443474
threads=args.threads,
444475
pairing_strategy=args.pairing_strategy,
445476
pair_env=True,
477+
unpack=args.unpack,
446478
)
447479

448-
id = 0
449-
for job_number, (
450-
raw_jobname,
451-
query_sequences,
452-
query_seqs_cardinality,
453-
) in enumerate(queries_unique):
454-
unpaired_msa = []
455-
paired_msa = None
456-
if len(query_seqs_cardinality) > 1:
457-
paired_msa = []
458-
for seq in query_sequences:
459-
with args.base.joinpath(f"{id}.a3m").open("r") as f:
460-
unpaired_msa.append(f.read())
461-
args.base.joinpath(f"{id}.a3m").unlink()
462-
463-
if args.use_env_pairing:
464-
with open(args.base.joinpath(f"{id}.paired.a3m"), 'a') as file_pair:
465-
with open(args.base.joinpath(f"{id}.env.paired.a3m"), 'r') as file_pair_env:
466-
while chunk := file_pair_env.read(10 * 1024 * 1024):
467-
file_pair.write(chunk)
468-
args.base.joinpath(f"{id}.env.paired.a3m").unlink()
469-
480+
if args.unpack:
481+
id = 0
482+
for job_number, (
483+
raw_jobname,
484+
query_sequences,
485+
query_seqs_cardinality,
486+
) in enumerate(queries_unique):
487+
unpaired_msa = []
488+
paired_msa = None
470489
if len(query_seqs_cardinality) > 1:
471-
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
472-
paired_msa.append(f.read())
473-
args.base.joinpath(f"{id}.paired.a3m").unlink()
474-
id += 1
475-
msa = msa_to_str(
476-
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
490+
paired_msa = []
491+
for seq in query_sequences:
492+
with args.base.joinpath(f"{id}.a3m").open("r") as f:
493+
unpaired_msa.append(f.read())
494+
args.base.joinpath(f"{id}.a3m").unlink()
495+
496+
if args.use_env_pairing:
497+
with open(args.base.joinpath(f"{id}.paired.a3m"), 'a') as file_pair:
498+
with open(args.base.joinpath(f"{id}.env.paired.a3m"), 'r') as file_pair_env:
499+
while chunk := file_pair_env.read(10 * 1024 * 1024):
500+
file_pair.write(chunk)
501+
args.base.joinpath(f"{id}.env.paired.a3m").unlink()
502+
503+
if len(query_seqs_cardinality) > 1:
504+
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
505+
paired_msa.append(f.read())
506+
args.base.joinpath(f"{id}.paired.a3m").unlink()
507+
id += 1
508+
msa = msa_to_str(
509+
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
510+
)
511+
args.base.joinpath(f"{job_number}.a3m").write_text(msa)
512+
513+
if args.unpack:
514+
# rename a3m files
515+
for job_number, (raw_jobname, query_sequences, query_seqs_cardinality) in enumerate(queries_unique):
516+
os.rename(
517+
args.base.joinpath(f"{job_number}.a3m"),
518+
args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"),
477519
)
478-
args.base.joinpath(f"{job_number}.a3m").write_text(msa)
479-
480-
# rename a3m files
481-
for job_number, (raw_jobname, query_sequences, query_seqs_cardinality) in enumerate(queries_unique):
482-
os.rename(
483-
args.base.joinpath(f"{job_number}.a3m"),
484-
args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"),
485-
)
486520

487-
# rename m8 files
488-
if args.use_templates:
489-
id = 0
490-
for raw_jobname, query_sequences, query_seqs_cardinality in queries_unique:
491-
with args.base.joinpath(f"{safe_filename(raw_jobname)}_{args.db2}.m8").open(
492-
"w"
493-
) as f:
494-
for _ in range(len(query_seqs_cardinality)):
495-
with args.base.joinpath(f"{id}.m8").open("r") as g:
496-
f.write(g.read())
497-
os.remove(args.base.joinpath(f"{id}.m8"))
498-
id += 1
521+
# rename m8 files
522+
if args.use_templates:
523+
id = 0
524+
for raw_jobname, query_sequences, query_seqs_cardinality in queries_unique:
525+
with args.base.joinpath(f"{safe_filename(raw_jobname)}_{args.db2}.m8").open(
526+
"w"
527+
) as f:
528+
for _ in range(len(query_seqs_cardinality)):
529+
with args.base.joinpath(f"{id}.m8").open("r") as g:
530+
f.write(g.read())
531+
os.remove(args.base.joinpath(f"{id}.m8"))
532+
id += 1
533+
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
534+
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])
499535

500536
query_file.unlink()
501-
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
502-
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])
503537

504538

505539
if __name__ == "__main__":

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ classifiers = [
1919
]
2020

2121
[tool.poetry.dependencies]
22-
python = ">=3.9,<3.12"
22+
python = ">=3.9"
2323
absl-py = "^1.0.0"
2424
jax = { version = "^0.4.20", optional = true }
2525
matplotlib = "^3.2.2"

tests/mock.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ def mock_run_mmseqs2(
159159
"pairing_strategy": pairing_strategy,
160160
}
161161

162+
# make pre env-pair test work again, this was always true previously
163+
# however didn't do anything
164+
if len(query) > 1:
165+
config["use_env"] = True
166+
162167
for saved_response in self.saved_responses:
163168
# backwards compatibility, remove after UPDATE_SNAPSHOTS
164169
if "pairing_strategy" not in saved_response["config"]:

0 commit comments

Comments
 (0)