Skip to content

Commit 85de708

Browse files
committed
Add Foundry integration and tests
1 parent b377562 commit 85de708

File tree

5 files changed

+111
-7
lines changed

5 files changed

+111
-7
lines changed

assets/pixi/rfdiffusion.toml

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,29 @@ description = "Pixi environment for https://github.com/RosettaCommons/rc RFdiffu
55
platforms = ["linux-64"]
66
#platforms = ["linux-64", "osx-64", "osx-arm64", "win-64"]
77

8-
channels = ["pytorch", "nvidia", "conda-forge", "dglteam"]
8+
channels = ["conda-forge"]
99

1010
[dependencies]
1111
python = "==3.9"
1212
pip = "*"
1313

1414
numpy = "1.*"
1515

16-
pytorch = { version = "==1.12.1", channel = "pytorch" }
17-
pytorch-cuda = { version = "==11.6", channel = "pytorch" }
18-
"dgl-cuda11.6" = "*"
16+
cudatoolkit = "11.6.*"
1917

2018
mkl = "<2024.1"
2119

20+
[pypi-options]
21+
index-url = "https://pypi.org/simple"
22+
extra-index-urls = ["https://download.pytorch.org/whl/cu116"]
23+
24+
# DGL uses a "find-links" wheel repo (equivalent of pip -f ...)
25+
find-links = [{ url = "https://data.dgl.ai/wheels/cu116/repo.html" }]
26+
2227
[pypi-dependencies]
28+
torch = "==1.12.1+cu116"
29+
dgl = "==1.0.2+cu116"
30+
2331
e3nn = "==0.3.3"
2432
wandb = "==0.12.0"
2533
pynvml = "==11.0.0"
@@ -28,6 +36,7 @@ decorator = "==5.1.0"
2836
hydra-core = "==1.3.2"
2937
pyrsistent = "==0.19.3"
3038

39+
3140
[tasks]
3241
setup = { depends-on = ["install", "download_weights"] }
3342

@@ -43,15 +52,19 @@ cmd = """
4352
( rm -rf rfdiffusion-repo-clone/models ; mkdir -p rfdiffusion-repo-clone/models ) \
4453
&& wget -O rfdiffusion-repo-clone/models/Base_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/6f5902ac237024bdd0c176cb93063dc4/Base_ckpt.pt \
4554
&& wget -O rfdiffusion-repo-clone/models/Complex_base_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/e29311f6f1bf1af907f9ef9f44b8328b/Complex_base_ckpt.pt \
55+
&& wget -O rfdiffusion-repo-clone/models/Complex_beta_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/f572d396fae9206628714fb2ce00f72e/Complex_beta_ckpt.pt \
56+
\
4657
&& wget -O rfdiffusion-repo-clone/models/Complex_Fold_base_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt \
4758
&& wget -O rfdiffusion-repo-clone/models/InpaintSeq_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/74f51cfb8b440f50d70878e05361d8f0/InpaintSeq_ckpt.pt \
4859
&& wget -O rfdiffusion-repo-clone/models/InpaintSeq_Fold_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/76d00716416567174cdb7ca96e208296/InpaintSeq_Fold_ckpt.pt \
4960
&& wget -O rfdiffusion-repo-clone/models/ActiveSite_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/5532d2e1f3a4738decd58b19d633b3c3/ActiveSite_ckpt.pt \
5061
&& wget -O rfdiffusion-repo-clone/models/Base_epoch8_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/12fc204edeae5b57713c5ad7dcb97d39/Base_epoch8_ckpt.pt \
51-
&& wget -O rfdiffusion-repo-clone/models/Complex_beta_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/f572d396fae9206628714fb2ce00f72e/Complex_beta_ckpt.pt \
5262
&& wget -O rfdiffusion-repo-clone/models/RF_structure_prediction_weights.pt http://files.ipd.uw.edu/pub/RFdiffusion/1befcb9b28e2f778f53d47f18b7597fa/RF_structure_prediction_weights.pt \
5363
"""
5464

65+
[activation]
66+
env = { DGLBACKEND = "pytorch", LD_LIBRARY_PATH = "$CONDA_PREFIX/lib:$LD_LIBRARY_PATH" }
67+
5568
[tasks.execute]
5669
args = ["args"]
5770
cmd = "python scripts/run_inference.py {{args}}"

src/app.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod foundry;
12
mod ligandmpnn;
23
mod picap;
34
mod proteinmpnn;
@@ -41,6 +42,10 @@ pub enum App {
4142
#[value(aliases = ["LigandMPNN"])]
4243
Ligandmpnn,
4344

45+
/// Run the Foundry command https://github.com/RosettaCommons/foundry
46+
#[value(aliases = ["Foundry"])]
47+
Foundry,
48+
4449
/// Run the PiCAP/CAPSIF2 command https://github.com/Graylab/picap
4550
#[value(aliases = ["PiCAP", "CAPSIF2"])]
4651
Picap,
@@ -174,6 +179,7 @@ impl App {
174179
App::Proteinmpnn => proteinmpnn::container_spec(app_args),
175180
App::ProteinmpnnScript => proteinmpnn_script::container_spec(app_args),
176181
App::Ligandmpnn => ligandmpnn::container_spec(app_args),
182+
App::Foundry => foundry::container_spec(app_args),
177183
App::Picap => picap::container_spec(app_args),
178184
}
179185
}
@@ -187,7 +193,8 @@ impl App {
187193
App::Proteinmpnn => todo!("not implemented"), // proteinmpnn::native_spec(app_args),
188194
App::ProteinmpnnScript => todo!("not implemented"), // proteinmpnn_script::native_spec(app_args),
189195
App::Ligandmpnn => todo!("not implemented"), // ligandmpnn::native_spec(app_args),
190-
App::Picap => todo!("not implemented"), // picap::native_spec(app_args),
196+
App::Foundry => foundry::native_spec(app_args, working_dir),
197+
App::Picap => todo!("not implemented"), // picap::native_spec(app_args),
191198
}
192199
}
193200
}

src/app/foundry.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use std::path::Path;
2+
3+
use crate::{
4+
app::{ContainerRunSpec, NativeRunSpec},
5+
util::include_asset,
6+
};
7+
8+
pub fn container_spec(app_args: Vec<String>) -> ContainerRunSpec {
9+
dbg!(&app_args);
10+
assert!(
11+
!app_args.is_empty() || !app_args[0].starts_with("-"),
12+
"Foundry arguments must include a protocol name as first argument"
13+
);
14+
15+
// match app_args[0].as_str() {
16+
// "rfd3" => {
17+
// // case 1
18+
// // ckpt_path=
19+
// }
20+
// "mpnn" => {
21+
// // case 2
22+
// // --checkpoint_path
23+
// }
24+
// _ => {
25+
// // default case
26+
// }
27+
// }
28+
29+
// ContainerRunSpec::with_prefixed_args(
30+
// "rosettacommons/foundry:weights",
31+
// std::iter::empty::<&str>(),
32+
// app_args.clone(),
33+
// )
34+
35+
ContainerRunSpec::new("rosettacommons/foundry:weights", app_args).working_dir("/w")
36+
}
37+
38+
pub fn native_spec(app_args: Vec<String>, _working_dir: &Path) -> NativeRunSpec {
39+
NativeRunSpec::new(include_asset!("pixi/foundry.toml"), app_args)
40+
}

src/app/proteinmpnn_script.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::app::ContainerRunSpec;
22

33
pub fn container_spec(mut app_args: Vec<String>) -> ContainerRunSpec {
44
assert!(
5-
app_args.is_empty() || app_args[0].starts_with("-"),
5+
!app_args.is_empty() || !app_args[0].starts_with("-"),
66
"ProteinmpnnScript arguments must include a script name as first argument"
77
);
88

tests/foundry.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
use assert_cmd::{assert::OutputAssertExt, cargo::cargo_bin_cmd};
2+
use assert_fs::TempDir;
3+
4+
mod common;
5+
6+
common::engine_tests!(foundry);
7+
8+
fn foundry(engine: &str) {
9+
use assert_fs::assert::PathAssert;
10+
11+
let root = std::path::PathBuf::from("target/foundry").join(engine);
12+
std::fs::create_dir_all(&root).expect("create engine testing dir");
13+
let work_dir = TempDir::new_in(root).expect("create temp dir");
14+
15+
let json_rfd3_path = work_dir.join("test_rfd3.json");
16+
std::fs::write(json_rfd3_path, r#"{ "foundry": { "length": "10" } }"#)
17+
.expect("write test_rfd3.json");
18+
19+
for (i, c) in [
20+
"rfd3 out_dir=/w/rfd3_out/ inputs=/w/test_rfd3.json skip_existing=False prevalidate_inputs=True ckpt_path=/weights/rfd3_latest.ckpt n_batches=1 diffusion_batch_size=1 inference_sampler.num_timesteps=10 low_memory_mode=True global_prefix=test_",
21+
"mpnn --structure_path /w/rfd3_out/test_foundry_0_model_0.cif.gz --checkpoint_path /weights/ligandmpnn_v_32_010_25.pt --is_legacy_weights True --model_type ligand_mpnn --out_directory /w/mpnn_out",
22+
"rf3 fold inputs=/w/mpnn_out/test_foundry_0_model_0.cif_b0_d0.cif ckpt_path=/weights/rf3_foundry_01_24_latest_remapped.ckpt diffusion_batch_size=1 num_steps=10 out_dir=/w/rf3_out",
23+
].iter()
24+
.enumerate() {
25+
let cmd = cargo_bin_cmd!()
26+
.args([
27+
"run",
28+
"--container-engine",
29+
engine,
30+
"-w",
31+
work_dir.path().to_str().unwrap(),
32+
"foundry",
33+
])
34+
.args(c.split_ascii_whitespace())
35+
.unwrap();
36+
cmd.assert().success();
37+
38+
use assert_fs::prelude::PathChild;
39+
40+
work_dir
41+
.child(format!(".000{i}.rc.log"))
42+
.assert(predicates::path::exists());
43+
}
44+
}

0 commit comments

Comments
 (0)