diff --git a/confidence_threshold b/confidence_threshold new file mode 100644 index 000000000..e69de29bb diff --git a/openfold/config.py b/openfold/config.py index 7bf30e391..19a6480c2 100644 --- a/openfold/config.py +++ b/openfold/config.py @@ -364,7 +364,7 @@ def model_config( "same_prob": 0.1, "uniform_prob": 0.1, }, - "max_recycling_iters": 3, + "max_recycling_iters": 0, # changed from 3 "msa_cluster_features": True, "reduce_msa_clusters_by_max_templates": False, "resample_msa_in_recycling": True, @@ -420,13 +420,13 @@ def model_config( "fixed_size": True, "subsample_templates": False, # We want top templates. "block_delete_msa": False, - "masked_msa_replace_fraction": 0.15, + "masked_msa_replace_fraction": 0.0, #from 0.15 "max_msa_clusters": 128, "max_extra_msa": 1024, "max_template_hits": 4, "max_templates": 4, - "crop": False, - "crop_size": None, + "crop": True, + "crop_size": 150, "spatial_crop_prob": None, "interface_threshold": None, "supervised": True, @@ -436,14 +436,14 @@ def model_config( "fixed_size": True, "subsample_templates": True, "block_delete_msa": True, - "masked_msa_replace_fraction": 0.15, + "masked_msa_replace_fraction": 0.0, #from 0.15 "max_msa_clusters": 128, "max_extra_msa": 1024, "max_template_hits": 4, "max_templates": 4, "shuffle_top_k_prefiltered": 20, "crop": True, - "crop_size": 256, + "crop_size": 150, # TODO: change back to 256 ? "spatial_crop_prob": 0., "interface_threshold": None, "supervised": True, @@ -792,7 +792,7 @@ def model_config( ], "true_msa": [NUM_MSA_SEQ, NUM_RES] }, - "max_recycling_iters": 20, # For training, value is 3 + "max_recycling_iters": 0, # changed from 20 # For training, value is 3 "unsupervised_features": [ "aatype", "residue_index", diff --git a/openfold/data/data_modules.py b/openfold/data/data_modules.py index c6b5bc0cb..94d10c0e1 100644 --- a/openfold/data/data_modules.py +++ b/openfold/data/data_modules.py @@ -21,7 +21,7 @@ from openfold.utils.tensor_utils import ( tensor_tree_map, ) - +import pandas as pd class OpenFoldSingleDataset(torch.utils.data.Dataset): def __init__(self, @@ -87,6 +87,9 @@ def __init__(self, super(OpenFoldSingleDataset, self).__init__() self.data_dir = data_dir + print("dataloader getting reloaded !") + + self.chain_data_cache = None if chain_data_cache_path is not None: with open(chain_data_cache_path, "r") as fp: @@ -107,17 +110,46 @@ def __init__(self, if mode not in valid_modes: raise ValueError(f'mode must be one of {valid_modes}') + self.is_esm = True + + self.df = None + if mode == "train": + # load this csv /home/j-quentin/openfold_small_data/df_train.csv in a dataframe + self.df = pd.read_csv("/home/j-quentin/openfold_small_data/df_train.csv", index_col=0) + # self.df = pd.read_csv("/home/j-quentin/openfold_small_data/df_val.csv", index_col=0) + + elif mode == "eval": + self.df = pd.read_csv("/home/j-quentin/openfold_small_data/df_val.csv", index_col=0) + else: + raise NotImplementedError("mode not implemented") + if template_release_dates_cache_path is None: logging.warning( "Template release dates cache does not exist. 
Remember to run " "scripts/generate_mmcif_cache.py before running OpenFold" ) + # import pdb;pdb.set_trace() + if alignment_index is not None: self._chain_ids = list(alignment_index.keys()) else: self._chain_ids = list(os.listdir(alignment_dir)) + # import pdb; pdb.set_trace() + filtered_ids = self.df.index.tolist() + self._chain_ids = [k for k in self._chain_ids if k in filtered_ids] + print(f"Loaded {len(self._chain_ids)} chains in {mode} mode") + + # we shuffle to sample afterwards + # import random + # random.shuffle(self._chain_ids) + + # import pdb; pdb.set_trace() + + # self._chain_ids = self._chain_ids[:500] + # print(self._chain_ids) + if filter_path is not None: with open(filter_path, "r") as f: chains_to_include = set([l.strip() for l in f.readlines()]) @@ -173,6 +205,101 @@ def __init__(self, if not self._output_raw: self.feature_pipeline = feature_pipeline.FeaturePipeline(config) + from transformers import AutoTokenizer, EsmForProteinFolding, BitsAndBytesConfig + self.tokenizer_fold = AutoTokenizer.from_pretrained("facebook/esmfold_v1") + + + # import pdb; pdb.set_trace() + # idx=0 + # name = self.idx_to_chain_id(idx) + # print(name) + # alignment_dir = os.path.join(self.alignment_dir, name) + + # alignment_index = None + # if self.alignment_index is not None: + # alignment_dir = self.alignment_dir + # alignment_index = self.alignment_index[name] + + # if self.mode == 'train' or self.mode == 'eval': + # spl = name.rsplit('_', 1) + # if len(spl) == 2: + # file_id, chain_id = spl + # else: + # file_id, = spl + # chain_id = None + + # path = os.path.join(self.data_dir, file_id) + # if self._structure_index is not None: + # structure_index_entry = self._structure_index[name] + # assert (len(structure_index_entry["files"]) == 1) + # filename, _, _ = structure_index_entry["files"][0] + # ext = os.path.splitext(filename)[1] + # else: + # ext = None + # for e in self.supported_exts: + # if os.path.exists(path + e): + # ext = e + # break + + # if ext is None: + # raise ValueError("Invalid file type") + + # path += ext + # if ext == ".cif": + # data = self._parse_mmcif( + # path, file_id, chain_id, alignment_dir, alignment_index, + # ) + # elif ext == ".core": + # data = self.data_pipeline.process_core( + # path, alignment_dir, alignment_index, + # seqemb_mode=self.config.seqemb_mode.enabled, + # ) + # elif ext == ".pdb": + # structure_index = None + # if self._structure_index is not None: + # structure_index = self._structure_index[name] + # data = self.data_pipeline.process_pdb( + # pdb_path=path, + # alignment_dir=alignment_dir, + # is_distillation=self.treat_pdb_as_distillation, + # chain_id=chain_id, + # alignment_index=alignment_index, + # _structure_index=structure_index, + # seqemb_mode=self.config.seqemb_mode.enabled, + # ) + # else: + # raise ValueError("Extension branch missing") + # else: + # path = os.path.join(name, name + ".fasta") + # data = self.data_pipeline.process_fasta( + # fasta_path=path, + # alignment_dir=alignment_dir, + # alignment_index=alignment_index, + # seqemb_mode=self.config.seqemb_mode.enabled, + # ) + + # if self._output_raw: + # return data + + # feats = self.feature_pipeline.process_features( + # data, self.mode + # ) + + # feats["batch_idx"] = torch.tensor( + # [idx for _ in range(feats["aatype"].shape[-1])], + # dtype=torch.int64, + # device=feats["aatype"].device) + + # encoding = 'utf-8' + # seq = str( data["sequence"][0], encoding) + # tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, 
padding=True)['input_ids'] + # tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, truncation=True, max_length=120)['input_ids'] + + # factor = feats["use_clamped_fape"].shape[0] + # feats["sequence"] = tokenized_input.unsqueeze(-1).repeat(1,1,factor) + + + def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index): with open(path, 'r') as f: mmcif_string = f.read() @@ -204,8 +331,109 @@ def chain_id_to_idx(self, chain_id): def idx_to_chain_id(self, idx): return self._chain_ids[idx] + def custom_getitem(self, idx): + name = self.idx_to_chain_id(idx) + print("starting get item") + print(name) + alignment_dir = os.path.join(self.alignment_dir, name) + + alignment_index = None + if self.alignment_index is not None: + alignment_dir = self.alignment_dir + alignment_index = self.alignment_index[name] + + if self.mode == 'train' or self.mode == 'eval': + spl = name.rsplit('_', 1) + if len(spl) == 2: + file_id, chain_id = spl + else: + file_id, = spl + chain_id = None + + path = os.path.join(self.data_dir, file_id) + if self._structure_index is not None: + structure_index_entry = self._structure_index[name] + assert (len(structure_index_entry["files"]) == 1) + filename, _, _ = structure_index_entry["files"][0] + ext = os.path.splitext(filename)[1] + else: + ext = None + for e in self.supported_exts: + if os.path.exists(path + e): + ext = e + break + + if ext is None: + raise ValueError("Invalid file type") + + path += ext + if ext == ".cif": + data = self._parse_mmcif( + path, file_id, chain_id, alignment_dir, alignment_index, + ) + elif ext == ".core": + data = self.data_pipeline.process_core( + path, alignment_dir, alignment_index, + seqemb_mode=self.config.seqemb_mode.enabled, + ) + elif ext == ".pdb": + structure_index = None + if self._structure_index is not None: + structure_index = self._structure_index[name] + data = self.data_pipeline.process_pdb( + pdb_path=path, + alignment_dir=alignment_dir, + is_distillation=self.treat_pdb_as_distillation, + chain_id=chain_id, + alignment_index=alignment_index, + _structure_index=structure_index, + seqemb_mode=self.config.seqemb_mode.enabled, + ) + else: + raise ValueError("Extension branch missing") + else: + path = os.path.join(name, name + ".fasta") + data = self.data_pipeline.process_fasta( + fasta_path=path, + alignment_dir=alignment_dir, + alignment_index=alignment_index, + seqemb_mode=self.config.seqemb_mode.enabled, + ) + + if self._output_raw: + return data + + feats = self.feature_pipeline.process_features( + data, self.mode + ) + + + feats["batch_idx"] = torch.tensor( + [idx for _ in range(feats["aatype"].shape[-1])], + dtype=torch.int64, + device=feats["aatype"].device) + + # new + encoding = 'utf-8' + seq = str( data["sequence"][0], encoding) + print(len(seq)) + # tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, padding='max_length', max_length=120)['input_ids'] + # tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, padding='max_length', max_length=120, truncation=True)['input_ids'] + + # # quick fix to bypass recycling + # factor = feats["use_clamped_fape"].shape[0] + # feats["sequence"] = tokenized_input.unsqueeze(-1).repeat(1,1,factor) + + # print("finishing get item") + + return len(seq) + + def __getitem__(self, idx): + name = self.idx_to_chain_id(idx) + print("starting get item") + print(name) alignment_dir = os.path.join(self.alignment_dir, name) alignment_index = None @@ -277,12 +505,43 @@ def 
__getitem__(self, idx): feats = self.feature_pipeline.process_features( data, self.mode ) + feats["batch_idx"] = torch.tensor( [idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device) + # sequence = self.df.iloc[0]["seq"] + # print(f"overuled get item idx=0") + # print(sequence) + # modeling worked on : TRDQNGTWEMESNENFEGYMKALDIDFATRKIAVRLTQTLVIDQDGDNFKVKTTSTFFNYDVDFTVGVEFDEYTKSLDNRHVKALVTWEGDVLVCVQKGEKENRGWKKWIEGDKLYLELTCGDQVCRQVFKKK + + if self.is_esm: + # new + encoding = 'utf-8' + seq = str( data["sequence"][0], encoding) + print(seq) + # print(f"overuled get item idx=0") + # seq = sequence + + + if self.df[self.df.index==name].iloc[0]["seq"] != seq: + print("wrong sequence!") + # import pdb;pdb.set_trace() + else: + print("right sequence!") + + print(len(seq)) + # tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, padding='max_length', max_length=120)['input_ids'] + tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, padding='max_length', max_length=150, truncation=True)['input_ids'] + + # quick fix to bypass recycling + factor = feats["use_clamped_fape"].shape[0] + feats["sequence"] = tokenized_input.unsqueeze(-1).repeat(1,1,factor) + + print("finishing get item") + return feats def __len__(self): @@ -450,6 +709,7 @@ def idx_to_mmcif_id(self, idx): return self._mmcifs[idx] def __getitem__(self, idx): + mmcif_id = self.idx_to_mmcif_id(idx) alignment_index = None @@ -837,10 +1097,12 @@ def _add_batch_properties(self, batch): def __iter__(self): it = super().__iter__() - + print("starting iter OpenFoldDataLoader") def _batch_prop_gen(iterator): + print("called _batch_prop_gen") for batch in iterator: yield self._add_batch_properties(batch) + print("stopping iter OpenFoldDataLoader") return _batch_prop_gen(it) diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py index 393c1cef3..c33fb4adb 100644 --- a/openfold/data/data_pipeline.py +++ b/openfold/data/data_pipeline.py @@ -927,6 +927,8 @@ def process_mmcif( If chain_id is None, it is assumed that there is only one chain in the object. Otherwise, a ValueError is thrown. """ + + # import pdb; pdb.set_trace() if chain_id is None: chains = mmcif.structure.get_chains() chain = next(chains, None) @@ -936,7 +938,12 @@ def process_mmcif( mmcif_feats = make_mmcif_features(mmcif, chain_id) - input_sequence = mmcif.chain_to_seqres[chain_id] + input_sequence = mmcif.chain_to_seqres[chain_id] + #AQVINTFDGVADYLQTYHKLPDNYITKSEAQALGWVASKGNLADVAPGKSIGGDIFSNREGKLPGKSGRTWREADINYTSGFRNSDRILYSSDWLAYKTTDHYQTFTKIR + + #MPVRKAKAVWEGGLRQGKGVMELQSQAFQGPYSYPSRFEEGEGTNPEELIAAAHAGCFSMALAASLEREGFPPKRVSTEARVHLEVVDGKPTLTRIELLTEAEVPGISSEKFLEIAEAAKEGCPVSRALAGVKEVVLTARLV + #MPVRKAKAVWEGGLRQGKGVMELQSQAFQGPYSYPSRFEEGEGTNPEELIAAAHAGCFSMALAASLEREGFPPKRVSTEARVHLEVVDGKPTLTRIELLTEAEVPGISSEKFLEIAEAAKEGCPVSRALAGVKEVVLTARLV + hits = self._parse_template_hit_files( alignment_dir=alignment_dir, input_sequence=input_sequence, diff --git a/openfold/utils/exponential_moving_average.py b/openfold/utils/exponential_moving_average.py index 731615262..e2d5d520f 100644 --- a/openfold/utils/exponential_moving_average.py +++ b/openfold/utils/exponential_moving_average.py @@ -55,7 +55,8 @@ def update(self, model: torch.nn.Module) -> None: module. The module should have the same structure as that used to initialize the ExponentialMovingAverage object. 
""" - self._update_state_dict_(model.state_dict(), self.params) + # self._update_state_dict_(model.state_dict(), self.params) + print("bypass_upadte_moving_average") def load_state_dict(self, state_dict: OrderedDict) -> None: for k in state_dict["params"].keys(): diff --git a/openfold/utils/loss.py b/openfold/utils/loss.py index 395e34753..1287c3d00 100644 --- a/openfold/utils/loss.py +++ b/openfold/utils/loss.py @@ -1192,16 +1192,17 @@ def find_structural_violations( ) -> Dict[str, torch.Tensor]: """Computes several checks for structural violations.""" + print("bb1") # Compute between residue backbone violations of bonds and angles. connection_violations = between_residue_bond_loss( pred_atom_positions=atom14_pred_positions, - pred_atom_mask=batch["atom14_atom_exists"], - residue_index=batch["residue_index"], - aatype=batch["aatype"], + pred_atom_mask=batch["atom14_atom_exists"], #torch.Size([1, 92, 14]) + residue_index=batch["residue_index"], #torch.Size([1, 92]) + aatype=batch["aatype"], #torch.Size([1, 92]) tolerance_factor_soft=violation_tolerance_factor, tolerance_factor_hard=violation_tolerance_factor, ) - + print("bb2") # Compute the Van der Waals radius for every atom # (the first letter of the atom name is the element type). # Shape: (N, 14). @@ -1209,9 +1210,9 @@ def find_structural_violations( residue_constants.van_der_waals_radius[name[0]] for name in residue_constants.atom_types ] - + print("bb3") atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius) - + print("bb4") # TODO: Consolidate monomer/multimer modes asym_id = batch.get("asym_id") if asym_id is not None: @@ -1227,7 +1228,7 @@ def find_structural_violations( batch["atom14_atom_exists"] * atomtype_radius[batch["residx_atom14_to_atom37"]] ) - + print("bb5") # Compute the between residue clash loss. between_residue_clashes = between_residue_clash_loss( atom14_pred_positions=atom14_pred_positions, @@ -1238,13 +1239,14 @@ def find_structural_violations( overlap_tolerance_soft=clash_overlap_tolerance, overlap_tolerance_hard=clash_overlap_tolerance, ) - + print("bb6") # Compute all within-residue violations (clashes, # bond length and angle violations). restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( overlap_tolerance=clash_overlap_tolerance, bond_length_tolerance_factor=violation_tolerance_factor, ) + print("bb7") atom14_atom_exists = batch["atom14_atom_exists"] atom14_dists_lower_bound = atom14_pred_positions.new_tensor( restype_atom14_bounds["lower_bound"] @@ -1259,7 +1261,7 @@ def find_structural_violations( atom14_dists_upper_bound=atom14_dists_upper_bound, tighten_bounds_for_loss=0.0, ) - + print("bb8") # Combine them to a single per-residue violation mask (used later for LDDT). per_residue_violations_mask = torch.max( torch.stack( @@ -1492,7 +1494,7 @@ def compute_renamed_ground_truth( after renaming swaps are performed. renamed_atom14_gt_exists: Mask after renaming swap is performed. 
""" - + print("pp1") pred_dists = torch.sqrt( eps + torch.sum( @@ -1504,7 +1506,7 @@ def compute_renamed_ground_truth( dim=-1, ) ) - + print("pp2") atom14_gt_positions = batch["atom14_gt_positions"] gt_dists = torch.sqrt( eps @@ -1517,7 +1519,7 @@ def compute_renamed_ground_truth( dim=-1, ) ) - + print("pp3") atom14_alt_gt_positions = batch["atom14_alt_gt_positions"] alt_gt_dists = torch.sqrt( eps @@ -1530,10 +1532,10 @@ def compute_renamed_ground_truth( dim=-1, ) ) - + print("pp4") lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2) alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2) - + print("pp5") atom14_gt_exists = batch["atom14_gt_exists"] atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"] mask = ( @@ -1542,7 +1544,7 @@ def compute_renamed_ground_truth( * atom14_gt_exists[..., None, :, None, :] * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :]) ) - + print("pp6") per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3)) alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3)) @@ -1688,72 +1690,74 @@ class AlphaFoldLoss(nn.Module): def __init__(self, config): super(AlphaFoldLoss, self).__init__() self.config = config + print("hello jq") def loss(self, out, batch, _return_breakdown=False): """ Rename previous forward() as loss() so that can be reused in the subclass """ + # import pdb; pdb.set_trace() + print("hello loss") + print("bp1") if "violation" not in out.keys(): out["violation"] = find_structural_violations( batch, - out["sm"]["positions"][-1], + out["sm"]["positions"][-1], #torch.Size([1, 120, 14, 3]) **self.config.violation, ) - + print("bp2") if "renamed_atom14_gt_positions" not in out.keys(): - batch.update( - compute_renamed_ground_truth( - batch, - out["sm"]["positions"][-1], - ) - ) + batch.update(compute_renamed_ground_truth(batch,out["sm"]["positions"][-1],)) + + print("bp3") loss_fns = { "distogram": lambda: distogram_loss( logits=out["distogram_logits"], **{**batch, **self.config.distogram}, ), - "experimentally_resolved": lambda: experimentally_resolved_loss( - logits=out["experimentally_resolved_logits"], - **{**batch, **self.config.experimentally_resolved}, - ), + # "experimentally_resolved": lambda: experimentally_resolved_loss( + # logits=out["experimentally_resolved_logits"], + # **{**batch, **self.config.experimentally_resolved}, + # ), "fape": lambda: fape_loss( out, batch, self.config.fape, ), - "plddt_loss": lambda: lddt_loss( - logits=out["lddt_logits"], - all_atom_pred_pos=out["final_atom_positions"], - **{**batch, **self.config.plddt_loss}, - ), - "masked_msa": lambda: masked_msa_loss( - logits=out["masked_msa_logits"], - **{**batch, **self.config.masked_msa}, - ), - "supervised_chi": lambda: supervised_chi_loss( - out["sm"]["angles"], - out["sm"]["unnormalized_angles"], - **{**batch, **self.config.supervised_chi}, - ), + # "plddt_loss": lambda: lddt_loss( # TODO: Fix this + # logits=out["lddt_logits"], + # all_atom_pred_pos=out["final_atom_positions"], + # **{**batch, **self.config.plddt_loss}, + # ), + # "masked_msa": lambda: masked_msa_loss( + # logits=out["masked_msa_logits"], + # **{**batch, **self.config.masked_msa}, + # ), + # "supervised_chi": lambda: supervised_chi_loss( + # out["sm"]["angles"], + # out["sm"]["unnormalized_angles"], + # **{**batch, **self.config.supervised_chi}, + # ), "violation": lambda: violation_loss( out["violation"], **{**batch, **self.config.violation}, ), } - if self.config.tm.enabled: - loss_fns["tm"] = lambda: tm_loss( - logits=out["tm_logits"], - **{**batch, **out, **self.config.tm}, 
- ) + # if self.config.tm.enabled: + # loss_fns["tm"] = lambda: tm_loss( + # logits=out["tm_logits"], + # **{**batch, **out, **self.config.tm}, + # ) - if self.config.chain_center_of_mass.enabled: - loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss( - all_atom_pred_pos=out["final_atom_positions"], - **{**batch, **self.config.chain_center_of_mass}, - ) + # if self.config.chain_center_of_mass.enabled: + # loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss( + # all_atom_pred_pos=out["final_atom_positions"], + # **{**batch, **self.config.chain_center_of_mass}, + # ) + print("bp4") cum_loss = 0. losses = {} @@ -1770,6 +1774,7 @@ def loss(self, out, batch, _return_breakdown=False): cum_loss = cum_loss + weight * loss losses[loss_name] = loss.detach().clone() losses["unscaled_loss"] = cum_loss.detach().clone() + print("bp5") # Scale the loss by the square root of the minimum of the crop size and # the (average) sequence length. See subsection 1.9. diff --git a/train_openfold.py b/train_openfold.py index 168a4b43f..48adae2b5 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -41,37 +41,103 @@ import_openfold_weights_ ) from openfold.utils.logger import PerformanceLoggingCallback +from transformers import AutoTokenizer, EsmForProteinFolding, BitsAndBytesConfig +# from transformers import AutoModelForMaskedLM +# from peft import PeftModel +from openfold.utils.feats import atom14_to_atom37 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class OpenFoldWrapper(pl.LightningModule): def __init__(self, config): super(OpenFoldWrapper, self).__init__() self.config = config - self.model = AlphaFold(config) + + self.use_esm = True + + if self.use_esm: + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # no training + model_fold = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=False) + + # far end of training + # model_fold = EsmForProteinFolding.from_pretrained("/home/j-quentin/bnt-vhh-modelling/notebooks/full_training_clustered_esm3B/merged_weights-checkpoint-118560/", low_cpu_mem_usage=False) + + # far end of training + # model_fold = EsmForProteinFolding.from_pretrained("/home/j-quentin/bnt-vhh-modelling/notebooks/full_training_clustered_esm3B/merged_weights-checkpoint-3120/", low_cpu_mem_usage=False) + + self.tokenizer_fold = AutoTokenizer.from_pretrained("facebook/esmfold_v1") + + model_fold = model_fold.to(device) + + self.model = model_fold + + # just to check + self.model.requires_grad_(True) + self.model.esm.requires_grad_(False) + + else: + self.model = AlphaFold(config) + # self.model = None + + self.is_multimer = self.config.globals.is_multimer + # import pdb;pdb.set_trace() self.loss = AlphaFoldLoss(config.loss) self.ema = ExponentialMovingAverage( model=self.model, decay=config.ema.decay ) + # self.ema=None self.cached_weights = None self.last_lr_step = -1 self.save_hyperparameters() def forward(self, batch): - return self.model(batch) + + if self.use_esm: + + # if batch["sequence"].shape!=torch.Size([1, 1, 120, 1]): + # raise NotImplementedError + + + outputs = self.model(batch["sequence"].squeeze(-1).squeeze(0)) + outputs["sm"] = {} + # outputs["sm"]["frames"] = outputs["frames"].repeat(1, 1, 8, 1) + # outputs["sm"]["sidechain_frames"] = outputs["sidechain_frames"].repeat(1, 1, 8, 1, 1, 1) + # outputs["sm"]["positions"] = outputs["positions"].repeat(1, 1, 8, 1, 1) + + outputs["sm"]["frames"] = outputs["frames"] #torch.Size([8, 1, 120, 7]) + 
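+ # AlphaFoldLoss indexes out["sm"]["positions"][-1] (and the other "sm" tensors), so the
+ # ESMFold head outputs are exposed here with their leading dimension of 8 structure-module
+ # blocks kept intact; [-1] then selects the final block, matching what the native
+ # OpenFold model would return.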
outputs["sm"]["sidechain_frames"] = outputs["sidechain_frames"] #torch.Size([8, 1, 120, 8, 4, 4]) + outputs["sm"]["positions"] = outputs["positions"] #torch.Size([8, 1, 120, 14, 3]) + outputs["sm"]["angles"] = outputs["angles"] + outputs["sm"]["unnormalized_angles"] = outputs["unnormalized_angles"] + + + # outputs["unnormalized_angles"] #torch.Size([8, 1, 120, 7, 2]) + # outputs["unnormalized_angles"] + # outputs["angles"] #torch.Size([8, 1, 120, 7, 2]) + # outputs["lddt_head"] #torch.Size([8, 1, 120, 37, 50]) + return outputs + + else: + return self.model(batch) + def _log(self, loss_breakdown, batch, outputs, train=True): phase = "train" if train else "val" for loss_name, indiv_loss in loss_breakdown.items(): + print(f"{phase}/{loss_name}_epoch {indiv_loss}") + self.log( - f"{phase}/{loss_name}", - indiv_loss, - prog_bar=(loss_name == 'loss'), - on_step=train, on_epoch=(not train), logger=True, sync_dist=False, - ) + f"{phase}/{loss_name}", + indiv_loss, + prog_bar=(loss_name == 'loss'), + on_step=train, on_epoch=(not train), logger=True, sync_dist=False, + ) if(train): self.log( @@ -80,12 +146,16 @@ def _log(self, loss_breakdown, batch, outputs, train=True): on_step=False, on_epoch=True, logger=True, sync_dist=False, ) + with torch.no_grad(): + print("running other metrics") other_metrics = self._compute_validation_metrics( batch, outputs, superimposition_metrics=(not train) ) + print("ran other metrics") + for k,v in other_metrics.items(): self.log( @@ -94,32 +164,87 @@ def _log(self, loss_breakdown, batch, outputs, train=True): prog_bar = (k == 'loss'), on_step=False, on_epoch=True, logger=True, sync_dist=False, ) + print(f"{phase}/{k} {torch.mean(v)}") def training_step(self, batch, batch_idx): + + + #batch["all_atom_mask"][0,:92,0,0] # test + #batch["seq_length"] + # batch["true_msa"][0,0,:100,0] + + + + # TODO: + # add violations and other loss back + # figure out why sequence input changes + # quick fix + # print(batch["sequence"].shape) + # if batch["sequence"].shape!=torch.Size([1, 1, 120, 1]): + # return None + + # test = next(self.model.distogram_head.parameters()) + # print("disto params") + # print(test[1,:3]) + # check that model params are being updated + + + # import pdb; pdb.set_trace() + + # for param in self.model.esm.contact_head.regression.parameters(): + # print(param.requires_grad) + # self.model.esm.requires_grad_(True) + # for param in self.model.esm.contact_head.regression.parameters(): + # print(param.requires_grad) + + if(self.ema.device != batch["aatype"].device): self.ema.to(batch["aatype"].device) ground_truth = batch.pop('gt_features', None) - # Run the model - outputs = self(batch) + # need to add padding to 256 here + if self.use_esm: + + outputs = self(batch) + + feats = tensor_tree_map(lambda t: t[..., -1], batch) # we don't use recylcing so we just take last batch input + outputs["final_atom_positions"] = atom14_to_atom37(outputs["sm"]["positions"][-1], outputs) + + # batch["sequence"].shape # torch.Size([1, 1, 120]) + # outputs["sm"]["positions"][-1] # torch.Size([1, 120, 14, 3]) + + + else: + outputs = self(batch) + # outputs["sm"]["frames"].shape gives torch.Size([8, 1, 120, 7]) + # outputs["sm"]["sidechain_frames"].shape gives + # outputs["sm"]["positions"].shape gives + + # outputs["lddt_logits"].shape gives torch.Size([1, 120, 50]) + # outputs["masked_msa_logits"] # torch.Size([1, 128, 120, 23]) # we don't care for ESMFOLD + # outputs["experimentally_resolved_logits"].shape gives torch.Size([1, 120, 37]) + # Remove the recycling dimension batch 
= tensor_tree_map(lambda t: t[..., -1], batch) - if self.is_multimer: - batch = multi_chain_permutation_align(out=outputs, - features=batch, - ground_truth=ground_truth) + # if self.is_multimer: + # batch = multi_chain_permutation_align(out=outputs, + # features=batch, + # ground_truth=ground_truth) # Compute loss - loss, loss_breakdown = self.loss( - outputs, batch, _return_breakdown=True - ) - + loss, loss_breakdown = self.loss(outputs, batch, _return_breakdown=True) + # import pdb; pdb.set_trace() # Log it + # print(loss) + + # TODO: add later self._log(loss_breakdown, batch, outputs) + # import pdb; pdb.set_trace() + return loss def on_before_zero_grad(self, *args, **kwargs): @@ -127,26 +252,47 @@ def on_before_zero_grad(self, *args, **kwargs): def validation_step(self, batch, batch_idx): # At the start of validation, load the EMA weights - if(self.cached_weights is None): - # model.state_dict() contains references to model weights rather - # than copies. Therefore, we need to clone them before calling - # load_state_dict(). - clone_param = lambda t: t.detach().clone() - self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict()) - self.model.load_state_dict(self.ema.state_dict()["params"]) + print("validation step called") + print(batch["atom14_gt_positions"].shape) + # batch.update(compute_renamed_ground_truth(batch,outputs["sm"]["positions"][-1],)) + # batch["atom14_gt_positions"].shape gives torch.Size([1, 92, 14, 3]) + # batch["atom14_alt_gt_positions"].shape gives torch.Size([1, 92, 14, 3]) - ground_truth = batch.pop('gt_features', None) + # if(self.cached_weights is None): + # # model.state_dict() contains references to model weights rather + # # than copies. Therefore, we need to clone them before calling + # # load_state_dict(). + # clone_param = lambda t: t.detach().clone() + # self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict()) + # self.model.load_state_dict(self.ema.state_dict()["params"]) + + # import pdb; pdb.set_trace() + # batch["sequence"].shape is torch.Size([1, 1, 120, 1]) different + + + + ground_truth = batch.pop('gt_features', None) + # Run the model - outputs = self(batch) + if self.use_esm: + + outputs = self(batch) + + feats = tensor_tree_map(lambda t: t[..., -1], batch) # we don't use recylcing so we just take last batch input + outputs["final_atom_positions"] = atom14_to_atom37(outputs["sm"]["positions"][-1], outputs) + + else: + outputs = self(batch) + batch = tensor_tree_map(lambda t: t[..., -1], batch) batch["use_clamped_fape"] = 0. 
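Note on the recycling dimension: with max_recycling_iters set to 0 in the config, every feature still carries a trailing recycling axis of size 1, and tensor_tree_map(lambda t: t[..., -1], batch) simply strips it; the tokenized "sequence" feature is repeated along that same axis in the dataset ("quick fix to bypass recycling") so that it survives this step. A minimal, self-contained sketch with toy shapes (the keys and sizes below are illustrative, not the real feature set):

    import torch
    from openfold.utils.tensor_utils import tensor_tree_map

    # Toy feature dict: the last dimension is the recycling axis, which has size 1
    # when max_recycling_iters == 0.
    feats = {
        "aatype": torch.zeros(1, 150, 1, dtype=torch.long),
        "sequence": torch.zeros(1, 1, 150, 1, dtype=torch.long),  # repeated along the last dim
    }

    # tensor_tree_map applies the lambda to every tensor leaf, dropping the recycling axis.
    feats = tensor_tree_map(lambda t: t[..., -1], feats)
    assert feats["aatype"].shape == (1, 150)
    assert feats["sequence"].shape == (1, 1, 150)
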
- if self.is_multimer: - batch = multi_chain_permutation_align(out=outputs, - features=batch, - ground_truth=ground_truth) + # if self.is_multimer: + # batch = multi_chain_permutation_align(out=outputs, + # features=batch, + # ground_truth=ground_truth) # Compute loss and other metrics _, loss_breakdown = self.loss( @@ -154,11 +300,16 @@ def validation_step(self, batch, batch_idx): ) self._log(loss_breakdown, batch, outputs, train=False) + + print("validation step done") def on_validation_epoch_end(self): # Restore the model weights to normal - self.model.load_state_dict(self.cached_weights) - self.cached_weights = None + + # TODO: look into this + # self.model.load_state_dict(self.cached_weights) + # self.cached_weights = None + return def _compute_validation_metrics(self, batch, @@ -218,6 +369,9 @@ def configure_optimizers(self, learning_rate: float = 1e-3, eps: float = 1e-5, ) -> torch.optim.Adam: + + # learning_rate=0.01 + # Ignored as long as a DeepSpeed optimizer is configured optimizer = torch.optim.Adam( self.model.parameters(), @@ -318,7 +472,7 @@ def main(args): # Loading from pre-trained model sd = {'model.'+k: v for k, v in sd.items()} import_openfold_weights_(model=model_module, state_dict=sd) - logging.info("Successfully loaded model weights...") + # logging.info("Successfully loaded model weights...") else: # Loads a checkpoint to start from a specific time step if os.path.isdir(args.resume_from_ckpt): @@ -327,7 +481,7 @@ def main(args): sd = torch.load(args.resume_from_ckpt) last_global_step = int(sd['global_step']) model_module.resume_last_lr_step(last_global_step) - logging.info("Successfully loaded last lr step...") + # logging.info("Successfully loaded last lr step...") if args.resume_from_jax_params: model_module.load_from_jax(args.resume_from_jax_params) @@ -409,6 +563,33 @@ def main(args): **{"entity": args.wandb_entity} ) loggers.append(wdb_logger) + else: + # from pytorch_lightning import Trainer + from pytorch_lightning.loggers import NeptuneLogger + + # arguments made to NeptuneLogger are passed on to the neptune.experiments.Experiment class + # We are using an api_key for the anonymous user "neptuner" but you can use your own. 
+ neptune_logger = NeptuneLogger( + api_key='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI0NGQwNTE3My02ZTZlLTRiYzAtODBkZi1mZjA5ZWQ2ZTNlYzYifQ', + project_name='InstaDeep/vhh', + # # experiment_name='default', # Optional, + # # params={'max_epochs': 10}, # Optional, + # # tags=['pytorch-lightning', 'mlp'] # Optional, + ) + # trainer = Trainer(max_epochs=10, logger=neptune_logger) + + # neptune_callback = NeptuneCallback( + # project="InstaDeep/vhh", + # api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI0NGQwNTE3My02ZTZlLTRiYzAtODBkZi1mZjA5ZWQ2ZTNlYzYifQ==", + # ) + # loggers.append(neptune_logger) + + + import logging + + logger = logging.getLogger("root_experiment") + logger.setLevel(logging.DEBUG) + loggers.append(logger) cluster_environment = MPIEnvironment() if args.mpi_plugin else None if(args.deepspeed_config_path is not None): @@ -435,7 +616,7 @@ def main(args): trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws} trainer_args.update({ 'default_root_dir': args.output_dir, - 'strategy': strategy, + 'strategy': "auto", #changed here 'callbacks': callbacks, 'logger': loggers, }) @@ -450,9 +631,22 @@ def main(args): trainer.fit( model_module, datamodule=data_module, - ckpt_path=ckpt_path, + ckpt_path=ckpt_path ) + outputdir = args.output_dir + # create a folder in outputdir based on date + # save the model in that folder + import datetime + date = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + outputdir = os.path.join(outputdir, date) + os.makedirs(outputdir, exist_ok=True) + + model_module.model.save_pretrained(os.path.join(outputdir, "last_checkpoint"), from_pt=True) + # del model_module # to free up memory + + # model_fold = EsmForProteinFolding.from_pretrained(os.path.join(outputdir, "last_checkpoint"), low_cpu_mem_usage=False) + def bool_type(bool_str: str): bool_str_lower = bool_str.lower()
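
For reference, the date-stamped folder written by save_pretrained() at the end of main() can be reloaded later with EsmForProteinFolding.from_pretrained(), as hinted at by the commented-out line above. A minimal sketch of that round trip (the checkpoint path and the test sequence are hypothetical):

    import os
    import torch
    from transformers import AutoTokenizer, EsmForProteinFolding

    # Hypothetical output of a previous run: <output_dir>/<date>/last_checkpoint
    ckpt_dir = os.path.join("output", "2024-01-01-00-00-00", "last_checkpoint")

    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model = EsmForProteinFolding.from_pretrained(ckpt_dir, low_cpu_mem_usage=False)
    model.eval()

    seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # toy sequence
    inputs = tokenizer(seq, return_tensors="pt", add_special_tokens=False)

    with torch.no_grad():
        outputs = model(inputs["input_ids"])

    # outputs["positions"] holds atom14 coordinates per structure-module block;
    # the last block, outputs["positions"][-1], is the final prediction.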