From 14cab04dd3872f30f3ae4e94f3a6404b8a343615 Mon Sep 17 00:00:00 2001 From: carrascomj Date: Thu, 6 Nov 2025 13:28:24 +0100 Subject: [PATCH] Remove inference_multiplicity Remove inference_multiplicity altogether --- configs/experiment/test.yaml | 3 - src/simplefold/inference.py | 86 ++++++++++++------- src/simplefold/processor/protein_processor.py | 6 +- src/simplefold/wrapper.py | 85 +++++++++++------- 4 files changed, 107 insertions(+), 73 deletions(-) diff --git a/configs/experiment/test.yaml b/configs/experiment/test.yaml index 890c017..f8793e6 100644 --- a/configs/experiment/test.yaml +++ b/configs/experiment/test.yaml @@ -25,6 +25,3 @@ model: tau: 0.01 log_timesteps: True w_cutoff: 0.99 - - processor: - inference_multiplicity: 1 diff --git a/src/simplefold/inference.py b/src/simplefold/inference.py index 3221b85..e1b701a 100644 --- a/src/simplefold/inference.py +++ b/src/simplefold/inference.py @@ -183,7 +183,6 @@ def initialize_others(args, device): scale=16.0, ref_scale=5.0, multiplicity=1, - inference_multiplicity=args.nsample_per_protein, backend=args.backend, ) @@ -210,44 +209,65 @@ def generate_structure( model, plddt_latent_module, plddt_out_module, device ): # run inference for target protein - if args.backend == "torch": - noise = torch.randn_like(batch['coords']).to(device) - elif args.backend == "mlx": - noise = mx.random.normal(batch['coords'].shape) - out_dict = sampler.sample(model, flow, noise, batch) - - if args.plddt: + coord_samples = [] + pad_mask = batch["atom_pad_mask"] + if args.plddt and plddt_latent_module is not None and plddt_out_module is not None: + compute_plddt = True + plddt_samples = [] + else: + compute_plddt = False + plddt_samples = None + plddts = None + for _ in range(args.nsample_per_protein): if args.backend == "torch": - t = torch.ones(batch['coords'].shape[0], device=device) - # use unscaled coords to extract latent for pLDDT prediction - out_feat = plddt_latent_module( - out_dict["denoised_coords"].detach(), t, batch) 
- plddt_out_dict = plddt_out_module( - out_feat["latent"].detach(), - batch, - ) + noise = torch.randn_like(batch["coords"]).to(device) elif args.backend == "mlx": - t = mx.ones(batch['coords'].shape[0]) - # use unscaled coords to extract latent for pLDDT prediction - out_feat = plddt_latent_module( - out_dict["denoised_coords"], t, batch) - plddt_out_dict = plddt_out_module( - out_feat["latent"], - batch, - ) - # scale pLDDT to [0, 100] - plddts = plddt_out_dict["plddt"] * 100.0 - else: - plddts = None + noise = mx.random.normal(batch["coords"].shape) + out_dict = sampler.sample(model, flow, noise, batch) + + if compute_plddt: + if args.backend == "torch": + t = torch.ones(batch['coords'].shape[0], device=device) + out_feat = plddt_latent_module( + out_dict["denoised_coords"].detach(), t, batch + ) + plddt_out_dict = plddt_out_module( + out_feat["latent"].detach(), + batch, + ) + elif args.backend == "mlx": + t = mx.ones(batch['coords'].shape[0]) + out_feat = plddt_latent_module( + out_dict["denoised_coords"], t, batch + ) + plddt_out_dict = plddt_out_module( + out_feat["latent"], + batch, + ) + # scale pLDDT to [0, 100] + plddt_samples.append(plddt_out_dict["plddt"] * 100.0) + + out_dict = processor.postprocess(out_dict, batch) + if args.backend == "torch": + coord_samples.append(out_dict["denoised_coords"].detach()) + else: + coord_samples.append(out_dict["denoised_coords"]) - out_dict = processor.postprocess(out_dict, batch) - # sampled_coord = out_dict['denoised_coords'].detach() if args.backend == "torch": - sampled_coord = out_dict['denoised_coords'].detach() + sampled_coord = torch.cat(coord_samples, dim=0) + pad_mask = pad_mask.detach().repeat_interleave( + args.nsample_per_protein, dim=0 + ) + if compute_plddt: + plddts = torch.cat(plddt_samples, dim=0).detach() else: - sampled_coord = out_dict['denoised_coords'] + sampled_coord = mx.concatenate(coord_samples, axis=0) + pad_mask = mx.concatenate( + [pad_mask] * args.nsample_per_protein, axis=0 + ) + if 
compute_plddt: + plddts = mx.concatenate(plddt_samples, axis=0) - pad_mask = batch['atom_pad_mask'] return sampled_coord, pad_mask, plddts diff --git a/src/simplefold/processor/protein_processor.py b/src/simplefold/processor/protein_processor.py index 2de580a..a7bc7ff 100644 --- a/src/simplefold/processor/protein_processor.py +++ b/src/simplefold/processor/protein_processor.py @@ -26,7 +26,6 @@ def __init__( scale=16.0, ref_scale=5.0, multiplicity=1, - inference_multiplicity=1, backend="torch", ): self.device = device @@ -34,7 +33,6 @@ def __init__( self.ref_scale = ref_scale # if multiplicity > 1, effective batch size is multiplicity * batch_size self.multiplicity = multiplicity - self.inference_multiplicity = inference_multiplicity self.backend = backend if self.backend == "mlx": self.center_random_fn = mlx_center_random @@ -69,7 +67,7 @@ def process_esm( esmaa = af2_idx_to_esm_idx(aatype, mask, af2_to_esm) - multiplicity = self.multiplicity if not inference else self.inference_multiplicity + multiplicity = self.multiplicity if not inference else 1 esm_s_, _ = compute_language_model_representations( esmaa, esm_model, esm_dict, backend=self.backend @@ -167,7 +165,7 @@ def preprocess_inference(self, batch, esm_model=None, esm_dict=None, af2_to_esm= batch_size, -1) batch['mol_index'] = mol_index - batch = self.batch_to_device(batch, multiplicity=self.inference_multiplicity) + batch = self.batch_to_device(batch) if esm_model is not None and batch.get('esm_s', None) is None: print("Processing ESM features for inference...") diff --git a/src/simplefold/wrapper.py b/src/simplefold/wrapper.py index add3734..e99be01 100644 --- a/src/simplefold/wrapper.py +++ b/src/simplefold/wrapper.py @@ -270,7 +270,6 @@ def initialize_others(self): scale=16.0, ref_scale=5.0, multiplicity=1, - inference_multiplicity=self.nsample_per_protein, backend=self.backend, ) @@ -320,49 +319,69 @@ def process_input(self, aa_seq): def run_inference(self, batch, model, plddt_model, device): # run 
inference for target protein - if self.backend == "torch": - noise = torch.randn_like(batch["coords"]).to(device) - elif self.backend == "mlx": - noise = mx.random.normal(batch["coords"].shape) - out_dict = self.sampler.sample(model, self.flow, noise, batch) - plddt_out_module = plddt_model["plddt_out_module"] plddt_latent_module = plddt_model["plddt_latent_module"] - + coord_samples = [] if plddt_latent_module is None or plddt_out_module is None: - plddts = None + compute_plddt = False + plddt_samples = None else: + compute_plddt = True + plddt_samples = [] + pad_mask = batch["atom_pad_mask"] + plddts = None + + for _ in range(self.nsample_per_protein): if self.backend == "torch": - t = torch.ones(batch["coords"].shape[0], device=device) - # use unscaled coords to extract latent for pLDDT prediction - out_feat = plddt_latent_module( - out_dict["denoised_coords"].detach(), t, batch - ) - plddt_out_dict = plddt_out_module( - out_feat["latent"].detach(), - batch, - ) + noise = torch.randn_like(batch["coords"]).to(device) elif self.backend == "mlx": - t = mx.ones(batch["coords"].shape[0]) - # use unscaled coords to extract latent for pLDDT prediction - out_feat = plddt_latent_module(out_dict["denoised_coords"], t, batch) - plddt_out_dict = plddt_out_module( - out_feat["latent"], - batch, - ) - # scale pLDDT to [0, 100] - plddts = plddt_out_dict["plddt"] * 100.0 - - out_dict = self.processor.postprocess(out_dict, batch) - # sampled_coord = out_dict['denoised_coords'].detach() + noise = mx.random.normal(batch["coords"].shape) + out_dict = self.sampler.sample(model, self.flow, noise, batch) + + if compute_plddt: + if self.backend == "torch": + t = torch.ones(batch["coords"].shape[0], device=device) + out_feat = plddt_latent_module( + out_dict["denoised_coords"].detach(), t, batch + ) + plddt_out_dict = plddt_out_module( + out_feat["latent"].detach(), + batch, + ) + elif self.backend == "mlx": + t = mx.ones(batch["coords"].shape[0]) + out_feat = 
plddt_latent_module(out_dict["denoised_coords"], t, batch) + plddt_out_dict = plddt_out_module( + out_feat["latent"], + batch, + ) + # scale pLDDT to [0, 100] + plddt_samples.append(plddt_out_dict["plddt"] * 100.0) + + out_dict = self.processor.postprocess(out_dict, batch) + if self.backend == "torch": + coord_samples.append(out_dict["denoised_coords"].detach()) + else: + coord_samples.append(out_dict["denoised_coords"]) + if self.backend == "torch": - sampled_coord = out_dict["denoised_coords"].detach() + sampled_coord = torch.cat(coord_samples, dim=0) + pad_mask = pad_mask.detach().repeat_interleave( + self.nsample_per_protein, dim=0 + ) + if compute_plddt: + plddts = torch.cat(plddt_samples, dim=0).detach() else: - sampled_coord = out_dict["denoised_coords"] + sampled_coord = mx.concatenate(coord_samples, axis=0) + pad_mask = mx.concatenate( + [pad_mask] * self.nsample_per_protein, axis=0 + ) + if compute_plddt: + plddts = mx.concatenate(plddt_samples, axis=0) return { "sampled_coord": sampled_coord, - "pad_mask": batch["atom_pad_mask"], + "pad_mask": pad_mask, "plddts": plddts, }