
esmfold training adaptation of openfold #1


Draft · wants to merge 6 commits into base: develop
Empty file added confidence_threshold
14 changes: 7 additions & 7 deletions openfold/config.py
@@ -364,7 +364,7 @@ def model_config(
"same_prob": 0.1,
"uniform_prob": 0.1,
},
"max_recycling_iters": 3,
"max_recycling_iters": 0, # changed from 3
"msa_cluster_features": True,
"reduce_msa_clusters_by_max_templates": False,
"resample_msa_in_recycling": True,
@@ -420,13 +420,13 @@ def model_config(
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"block_delete_msa": False,
"masked_msa_replace_fraction": 0.15,
"masked_msa_replace_fraction": 0.0, #from 0.15
"max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
"crop_size": None,
"crop": True,
"crop_size": 150,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": True,
@@ -436,14 +436,14 @@ def model_config(
"fixed_size": True,
"subsample_templates": True,
"block_delete_msa": True,
"masked_msa_replace_fraction": 0.15,
"masked_msa_replace_fraction": 0.0, #from 0.15
"max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"shuffle_top_k_prefiltered": 20,
"crop": True,
"crop_size": 256,
"crop_size": 150, # TODO: change back to 256 ?
"spatial_crop_prob": 0.,
"interface_threshold": None,
"supervised": True,
@@ -792,7 +792,7 @@ def model_config(
],
"true_msa": [NUM_MSA_SEQ, NUM_RES]
},
"max_recycling_iters": 20, # For training, value is 3
"max_recycling_iters": 0, # changed from 20 # For training, value is 3
"unsupervised_features": [
"aatype",
"residue_index",
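The config changes above disable recycling and MSA masking and switch both the train and eval pipelines to fixed 150-residue crops. A minimal sketch of reading these values back, assuming OpenFold's standard `model_config` entry point and the usual `data.common` / `data.train` / `data.eval` layout (not part of this diff):

```python
# Sketch only: check the ESMFold-style overrides after editing config.py.
# Assumes openfold.config.model_config and the usual data.{common,train,eval} layout.
from openfold.config import model_config

cfg = model_config("model_1", train=True)

print(cfg.data.common.max_recycling_iters)          # expected 0   (was 3)
print(cfg.data.train.masked_msa_replace_fraction)   # expected 0.0 (was 0.15)
print(cfg.data.train.crop_size)                     # expected 150 (was 256)
print(cfg.data.eval.crop_size)                      # expected 150 (eval cropping newly enabled)
```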
266 changes: 264 additions & 2 deletions openfold/data/data_modules.py
@@ -21,7 +21,7 @@
from openfold.utils.tensor_utils import (
tensor_tree_map,
)

import pandas as pd

class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
@@ -87,6 +87,9 @@ def __init__(self,
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir

print("dataloader getting reloaded !")


self.chain_data_cache = None
if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp:
@@ -107,17 +110,46 @@ def __init__(self,
if mode not in valid_modes:
raise ValueError(f'mode must be one of {valid_modes}')

self.is_esm = True

self.df = None
if mode == "train":
# Load the training split CSV (indexed by chain id, with a "seq" column) into a dataframe
self.df = pd.read_csv("/home/j-quentin/openfold_small_data/df_train.csv", index_col=0)
# self.df = pd.read_csv("/home/j-quentin/openfold_small_data/df_val.csv", index_col=0)

elif mode == "eval":
self.df = pd.read_csv("/home/j-quentin/openfold_small_data/df_val.csv", index_col=0)
else:
raise NotImplementedError("mode not implemented")

if template_release_dates_cache_path is None:
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
)

# import pdb;pdb.set_trace()

if alignment_index is not None:
self._chain_ids = list(alignment_index.keys())
else:
self._chain_ids = list(os.listdir(alignment_dir))

# import pdb; pdb.set_trace()
filtered_ids = self.df.index.tolist()
self._chain_ids = [k for k in self._chain_ids if k in filtered_ids]
print(f"Loaded {len(self._chain_ids)} chains in {mode} mode")

# we shuffle to sample afterwards
# import random
# random.shuffle(self._chain_ids)

# import pdb; pdb.set_trace()

# self._chain_ids = self._chain_ids[:500]
# print(self._chain_ids)

if filter_path is not None:
with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()])
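The dataframe loaded above is only ever queried through its index and a `seq` column, so the CSV layout is presumably one row per chain, indexed by chain id and holding the one-letter sequence. A hypothetical sketch of producing such a file (paths, ids, and sequences below are illustrative, not from this PR):

```python
# Hypothetical sketch: build a train-split CSV in the layout the dataset expects
# (index = chain id matching the alignment_dir entries, "seq" = amino-acid sequence).
import pandas as pd

df_train = pd.DataFrame(
    {"seq": ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "GSHMLEDPVAGHKLV"]},  # made-up sequences
    index=["1abc_A", "2xyz_B"],                                         # made-up chain ids
)
df_train.to_csv("df_train.csv")

# Round-trips through the same call used in __init__:
reloaded = pd.read_csv("df_train.csv", index_col=0)
assert list(reloaded.index) == ["1abc_A", "2xyz_B"]
assert reloaded.loc["1abc_A", "seq"].startswith("MKTAYI")
```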
@@ -173,6 +205,101 @@ def __init__(self,
if not self._output_raw:
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

from transformers import AutoTokenizer, EsmForProteinFolding, BitsAndBytesConfig
self.tokenizer_fold = AutoTokenizer.from_pretrained("facebook/esmfold_v1")


# import pdb; pdb.set_trace()
# idx=0
# name = self.idx_to_chain_id(idx)
# print(name)
# alignment_dir = os.path.join(self.alignment_dir, name)

# alignment_index = None
# if self.alignment_index is not None:
# alignment_dir = self.alignment_dir
# alignment_index = self.alignment_index[name]

# if self.mode == 'train' or self.mode == 'eval':
# spl = name.rsplit('_', 1)
# if len(spl) == 2:
# file_id, chain_id = spl
# else:
# file_id, = spl
# chain_id = None

# path = os.path.join(self.data_dir, file_id)
# if self._structure_index is not None:
# structure_index_entry = self._structure_index[name]
# assert (len(structure_index_entry["files"]) == 1)
# filename, _, _ = structure_index_entry["files"][0]
# ext = os.path.splitext(filename)[1]
# else:
# ext = None
# for e in self.supported_exts:
# if os.path.exists(path + e):
# ext = e
# break

# if ext is None:
# raise ValueError("Invalid file type")

# path += ext
# if ext == ".cif":
# data = self._parse_mmcif(
# path, file_id, chain_id, alignment_dir, alignment_index,
# )
# elif ext == ".core":
# data = self.data_pipeline.process_core(
# path, alignment_dir, alignment_index,
# seqemb_mode=self.config.seqemb_mode.enabled,
# )
# elif ext == ".pdb":
# structure_index = None
# if self._structure_index is not None:
# structure_index = self._structure_index[name]
# data = self.data_pipeline.process_pdb(
# pdb_path=path,
# alignment_dir=alignment_dir,
# is_distillation=self.treat_pdb_as_distillation,
# chain_id=chain_id,
# alignment_index=alignment_index,
# _structure_index=structure_index,
# seqemb_mode=self.config.seqemb_mode.enabled,
# )
# else:
# raise ValueError("Extension branch missing")
# else:
# path = os.path.join(name, name + ".fasta")
# data = self.data_pipeline.process_fasta(
# fasta_path=path,
# alignment_dir=alignment_dir,
# alignment_index=alignment_index,
# seqemb_mode=self.config.seqemb_mode.enabled,
# )

# if self._output_raw:
# return data

# feats = self.feature_pipeline.process_features(
# data, self.mode
# )

# feats["batch_idx"] = torch.tensor(
# [idx for _ in range(feats["aatype"].shape[-1])],
# dtype=torch.int64,
# device=feats["aatype"].device)

# encoding = 'utf-8'
# seq = str( data["sequence"][0], encoding)
# tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, padding=True)['input_ids']
# tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, truncation=True, max_length=120)['input_ids']

# factor = feats["use_clamped_fape"].shape[0]
# feats["sequence"] = tokenized_input.unsqueeze(-1).repeat(1,1,factor)



def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
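Only the ESMFold tokenizer is kept in `__init__`; `EsmForProteinFolding` and `BitsAndBytesConfig` are imported but unused here. For context, a minimal sketch of how the ESMFold model itself is typically loaded and run alongside this tokenizer via the transformers API (not part of this diff; requires downloading facebook/esmfold_v1, and the example sequence is arbitrary):

```python
# Sketch only: pairing the tokenizer loaded above with the ESMFold model.
import torch
from transformers import AutoTokenizer, EsmForProteinFolding

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
model.eval()

seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # made-up example sequence
inputs = tokenizer(seq, return_tensors="pt", add_special_tokens=False)

with torch.no_grad():
    outputs = model(**inputs)

# outputs.positions holds predicted atom coordinates, outputs.plddt the per-residue confidence.
print(outputs.positions.shape)
```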
@@ -204,8 +331,109 @@ def chain_id_to_idx(self, chain_id):
def idx_to_chain_id(self, idx):
return self._chain_ids[idx]

def custom_getitem(self, idx):
name = self.idx_to_chain_id(idx)
print("starting get item")
print(name)
alignment_dir = os.path.join(self.alignment_dir, name)

alignment_index = None
if self.alignment_index is not None:
alignment_dir = self.alignment_dir
alignment_index = self.alignment_index[name]

if self.mode == 'train' or self.mode == 'eval':
spl = name.rsplit('_', 1)
if len(spl) == 2:
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None

path = os.path.join(self.data_dir, file_id)
if self._structure_index is not None:
structure_index_entry = self._structure_index[name]
assert (len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
ext = None
for e in self.supported_exts:
if os.path.exists(path + e):
ext = e
break

if ext is None:
raise ValueError("Invalid file type")

path += ext
if ext == ".cif":
data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir, alignment_index,
)
elif ext == ".core":
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)
elif ext == ".pdb":
structure_index = None
if self._structure_index is not None:
structure_index = self._structure_index[name]
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
alignment_index=alignment_index,
_structure_index=structure_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)
else:
raise ValueError("Extension branch missing")
else:
path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
alignment_index=alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)

if self._output_raw:
return data

feats = self.feature_pipeline.process_features(
data, self.mode
)


feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)

# new
encoding = 'utf-8'
seq = str( data["sequence"][0], encoding)
print(len(seq))
# tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, padding='max_length', max_length=120)['input_ids']
# tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, padding='max_length', max_length=120, truncation=True)['input_ids']

# # quick fix to bypass recycling
# factor = feats["use_clamped_fape"].shape[0]
# feats["sequence"] = tokenized_input.unsqueeze(-1).repeat(1,1,factor)

# print("finishing get item")

return len(seq)


def __getitem__(self, idx):

name = self.idx_to_chain_id(idx)
print("starting get item")
print(name)
alignment_dir = os.path.join(self.alignment_dir, name)

alignment_index = None
@@ -277,12 +505,43 @@ def __getitem__(self, idx):
feats = self.feature_pipeline.process_features(
data, self.mode
)


feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)

# sequence = self.df.iloc[0]["seq"]
# print(f"overuled get item idx=0")
# print(sequence)
# modeling worked on : TRDQNGTWEMESNENFEGYMKALDIDFATRKIAVRLTQTLVIDQDGDNFKVKTTSTFFNYDVDFTVGVEFDEYTKSLDNRHVKALVTWEGDVLVCVQKGEKENRGWKKWIEGDKLYLELTCGDQVCRQVFKKK

if self.is_esm:
# new
encoding = 'utf-8'
seq = str( data["sequence"][0], encoding)
print(seq)
# print(f"overuled get item idx=0")
# seq = sequence


if self.df[self.df.index==name].iloc[0]["seq"] != seq:
print("wrong sequence!")
# import pdb;pdb.set_trace()
else:
print("right sequence!")

print(len(seq))
# tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, padding='max_length', max_length=120)['input_ids']
tokenized_input = self.tokenizer_fold(seq, return_tensors="pt", add_special_tokens=False, padding='max_length', max_length=150, truncation=True)['input_ids']

# quick fix to bypass recycling
factor = feats["use_clamped_fape"].shape[0]
feats["sequence"] = tokenized_input.unsqueeze(-1).repeat(1,1,factor)

print("finishing get item")

return feats

def __len__(self):
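In `__getitem__`, the sequence is padded or truncated to 150 tokens (matching the new `crop_size`) and then repeated along a trailing dimension sized like `use_clamped_fape`, which here appears to stand in for the recycling axis; with `max_recycling_iters` set to 0 that factor is 1. A standalone sketch of just this tokenization step (the example sequence is arbitrary):

```python
# Standalone sketch of the tokenization used in __getitem__ above.
# max_length=150 mirrors the crop_size set in config.py; factor stands in for
# feats["use_clamped_fape"].shape[0] (1 when max_recycling_iters == 0).
import torch
from transformers import AutoTokenizer

tokenizer_fold = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # arbitrary example sequence
tokenized_input = tokenizer_fold(
    seq,
    return_tensors="pt",
    add_special_tokens=False,
    padding="max_length",
    max_length=150,
    truncation=True,
)["input_ids"]                                    # shape [1, 150]

factor = 1
sequence_feat = tokenized_input.unsqueeze(-1).repeat(1, 1, factor)
print(sequence_feat.shape)                        # torch.Size([1, 150, 1])
```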
@@ -450,6 +709,7 @@ def idx_to_mmcif_id(self, idx):
return self._mmcifs[idx]

def __getitem__(self, idx):

mmcif_id = self.idx_to_mmcif_id(idx)

alignment_index = None
@@ -837,10 +1097,12 @@ def _add_batch_properties(self, batch):

def __iter__(self):
it = super().__iter__()

print("starting iter OpenFoldDataLoader")
def _batch_prop_gen(iterator):
print("called _batch_prop_gen")
for batch in iterator:
yield self._add_batch_properties(batch)
print("stopping iter OpenFoldDataLoader")

return _batch_prop_gen(it)
