diff --git a/src/simplefold/cli.py b/src/simplefold/cli.py index 99230fa..2068424 100644 --- a/src/simplefold/cli.py +++ b/src/simplefold/cli.py @@ -27,6 +27,7 @@ def main(): parser.add_argument("--plddt", action="store_true", help="Enable pLDDT prediction.") parser.add_argument("--output_format", type=str, default="mmcif", choices=["pdb", "mmcif"], help="Output file format.") parser.add_argument("--backend", type=str, default='torch', choices=['torch', 'mlx'], help="Backend to run inference either torch or mlx") + parser.add_argument("--cache", type=str, default=None, help="Specify the cache directory other than default.") parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.") parser.add_argument( "--version", diff --git a/src/simplefold/inference.py b/src/simplefold/inference.py index 3221b85..7b6533f 100644 --- a/src/simplefold/inference.py +++ b/src/simplefold/inference.py @@ -62,7 +62,7 @@ def initialize_folding_model(args): if not os.path.exists(ckpt_path): os.makedirs(ckpt_dir, exist_ok=True) os.system(f"curl -L {ckpt_url_dict[simplefold_model]} -o {ckpt_path}") - cfg_path = os.path.join("configs/model/architecture", f"foldingdit_{simplefold_model[11:]}.yaml") + cfg_path = os.path.join(Path(__file__).parents[2], "configs/model/architecture", f"foldingdit_{simplefold_model[11:]}.yaml") checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) @@ -100,7 +100,7 @@ def initialize_plddt_module(args, device): os.makedirs(args.ckpt_dir, exist_ok=True) os.system(f"curl -L {plddt_ckpt_url} -o {plddt_ckpt_path}") - plddt_module_path = "configs/model/architecture/plddt_module.yaml" + plddt_module_path = os.path.join(Path(__file__).parents[2], "configs/model/architecture/plddt_module.yaml") plddt_checkpoint = torch.load(plddt_ckpt_path, map_location="cpu", weights_only=False) if args.backend == "torch": @@ -128,7 +128,7 @@ def initialize_plddt_module(args, device): os.makedirs(args.ckpt_dir, exist_ok=True) os.system(f"curl -L {ckpt_url_dict['simplefold_1.6B']} -o {plddt_latent_ckpt_path}") - plddt_latent_config_path = "configs/model/architecture/foldingdit_1.6B.yaml" + plddt_latent_config_path = os.path.join(Path(__file__).parents[2], "configs/model/architecture/foldingdit_1.6B.yaml") plddt_latent_checkpoint = torch.load(plddt_latent_ckpt_path, map_location="cpu", weights_only=False) if args.backend == "torch": @@ -257,7 +257,7 @@ def predict_structures_from_fastas(args): output_dir.mkdir(parents=True, exist_ok=True) prediction_dir = output_dir / f"predictions_{args.simplefold_model}" prediction_dir.mkdir(parents=True, exist_ok=True) - cache = output_dir / "cache" + cache = output_dir / "cache" if not args.cache else Path(args.cache) cache.mkdir(parents=True, exist_ok=True) # set random seed for reproducibility diff --git a/src/simplefold/wrapper.py b/src/simplefold/wrapper.py index add3734..0481258 100644 --- a/src/simplefold/wrapper.py +++ b/src/simplefold/wrapper.py @@ -88,6 +88,7 @@ def from_pretrained_folding_model(self): # load model checkpoint cfg_path = os.path.join( + Path(__file__).parents[2], "configs/model/architecture", f"foldingdit_{simplefold_model[11:]}.yaml" ) if self.backend == "torch": @@ -126,7 +127,7 @@ def from_pretrained_plddt_model(self): if not os.path.exists(plddt_ckpt_path): os.system(f"curl -L -o {plddt_ckpt_path} {plddt_ckpt_url}") - plddt_module_path = "configs/model/architecture/plddt_module.yaml" + plddt_module_path = os.path.join(Path(__file__).parents[2], "configs/model/architecture/plddt_module.yaml") plddt_checkpoint = torch.load( plddt_ckpt_path, map_location="cpu", weights_only=False ) @@ -162,7 +163,7 @@ def from_pretrained_plddt_model(self): f"curl -L -o {plddt_latent_ckpt_path} {ckpt_url_dict['simplefold_1.6B']}" ) - plddt_latent_config_path = "configs/model/architecture/foldingdit_1.6B.yaml" + plddt_latent_config_path = os.path.join(Path(__file__).parents[2], "configs/model/architecture/foldingdit_1.6B.yaml") plddt_latent_checkpoint = torch.load( plddt_latent_ckpt_path, map_location="cpu", weights_only=False )