From 74f8056a8df6b9d48e2d201b63c29bcf73f4e0b2 Mon Sep 17 00:00:00 2001 From: Jiaming Zhang Date: Mon, 29 Sep 2025 12:20:52 +0200 Subject: [PATCH 1/2] Added handling of --- src/simplefold/cli.py | 1 + src/simplefold/inference.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/simplefold/cli.py b/src/simplefold/cli.py index 172b554..413dfd0 100644 --- a/src/simplefold/cli.py +++ b/src/simplefold/cli.py @@ -27,6 +27,7 @@ def main(): parser.add_argument("--plddt", action="store_true", help="Enable pLDDT prediction.") parser.add_argument("--output_format", type=str, default="mmcif", choices=["pdb", "mmcif"], help="Output file format.") parser.add_argument("--backend", type=str, default='torch', choices=['torch', 'mlx'], help="Backend to run inference either torch or mlx") + parser.add_argument("--cache", type=str, default=None, help="Specify the cache directory other than default.") parser.add_argument( "--version", action="version", diff --git a/src/simplefold/inference.py b/src/simplefold/inference.py index ca099c7..917ac11 100644 --- a/src/simplefold/inference.py +++ b/src/simplefold/inference.py @@ -256,7 +256,7 @@ def predict_structures_from_fastas(args): output_dir.mkdir(parents=True, exist_ok=True) prediction_dir = output_dir / f"predictions_{args.simplefold_model}" prediction_dir.mkdir(parents=True, exist_ok=True) - cache = output_dir / "cache" + cache = output_dir / "cache" if not args.cache else Path(args.cache) cache.mkdir(parents=True, exist_ok=True) if args.backend == "mlx" and not MLX_AVAILABLE: From 51b1e181889d372e30476b7de6564970533ec7ee Mon Sep 17 00:00:00 2001 From: Jiaming Zhang Date: Mon, 29 Sep 2025 20:26:21 +0200 Subject: [PATCH 2/2] Fix an issue causing file not found err when the model is ran outside of the repo --- src/simplefold/inference.py | 6 +++--- src/simplefold/wrapper.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/simplefold/inference.py b/src/simplefold/inference.py index 917ac11..4de35e0 100644 --- a/src/simplefold/inference.py +++ b/src/simplefold/inference.py @@ -61,7 +61,7 @@ def initialize_folding_model(args): if not os.path.exists(ckpt_path): os.makedirs(ckpt_dir, exist_ok=True) os.system(f"curl -L {ckpt_url_dict[simplefold_model]} -o {ckpt_path}") - cfg_path = os.path.join("configs/model/architecture", f"foldingdit_{simplefold_model[11:]}.yaml") + cfg_path = os.path.join(Path(__file__).parents[2], "configs/model/architecture", f"foldingdit_{simplefold_model[11:]}.yaml") checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) @@ -99,7 +99,7 @@ def initialize_plddt_module(args, device): os.makedirs(args.ckpt_dir, exist_ok=True) os.system(f"curl -L {plddt_ckpt_url} -o {plddt_ckpt_path}") - plddt_module_path = "configs/model/architecture/plddt_module.yaml" + plddt_module_path = os.path.join(Path(__file__).parents[2], "configs/model/architecture/plddt_module.yaml") plddt_checkpoint = torch.load(plddt_ckpt_path, map_location="cpu", weights_only=False) if args.backend == "torch": @@ -127,7 +127,7 @@ def initialize_plddt_module(args, device): os.makedirs(args.ckpt_dir, exist_ok=True) os.system(f"curl -L {ckpt_url_dict['simplefold_1.6B']} -o {plddt_latent_ckpt_path}") - plddt_latent_config_path = "configs/model/architecture/foldingdit_1.6B.yaml" + plddt_latent_config_path = os.path.join(Path(__file__).parents[2], "configs/model/architecture/foldingdit_1.6B.yaml") plddt_latent_checkpoint = torch.load(plddt_latent_ckpt_path, map_location="cpu", weights_only=False) if args.backend == "torch": diff --git a/src/simplefold/wrapper.py b/src/simplefold/wrapper.py index add3734..0481258 100644 --- a/src/simplefold/wrapper.py +++ b/src/simplefold/wrapper.py @@ -88,6 +88,7 @@ def from_pretrained_folding_model(self): # load model checkpoint cfg_path = os.path.join( + Path(__file__).parents[2], "configs/model/architecture", f"foldingdit_{simplefold_model[11:]}.yaml" ) if self.backend == "torch": @@ -126,7 +127,7 @@ def from_pretrained_plddt_model(self): if not os.path.exists(plddt_ckpt_path): os.system(f"curl -L -o {plddt_ckpt_path} {plddt_ckpt_url}") - plddt_module_path = "configs/model/architecture/plddt_module.yaml" + plddt_module_path = os.path.join(Path(__file__).parents[2], "configs/model/architecture/plddt_module.yaml") plddt_checkpoint = torch.load( plddt_ckpt_path, map_location="cpu", weights_only=False ) @@ -162,7 +163,7 @@ def from_pretrained_plddt_model(self): f"curl -L -o {plddt_latent_ckpt_path} {ckpt_url_dict['simplefold_1.6B']}" ) - plddt_latent_config_path = "configs/model/architecture/foldingdit_1.6B.yaml" + plddt_latent_config_path = os.path.join(Path(__file__).parents[2], "configs/model/architecture/foldingdit_1.6B.yaml") plddt_latent_checkpoint = torch.load( plddt_latent_ckpt_path, map_location="cpu", weights_only=False )