diff --git a/.gitignore b/.gitignore
index bbd6044..e039886 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,12 @@ dev_iDeLUCS/*
 .DS_Store
 Results/
 */__pycache__/*
+idelucs/kmers.c
+idelucs/kmers.cpython-*
+*.ipynb*
+idelucs.egg-info/*
+ALL_RESULTS.tsv
+build/
+*.tfrecord
+s_50*
+Example/*_shuffled.fas
\ No newline at end of file
diff --git a/idelucs/PytorchUtils.py b/idelucs/PytorchUtils.py
index 92c5ca5..fc81b12 100644
--- a/idelucs/PytorchUtils.py
+++ b/idelucs/PytorchUtils.py
@@ -35,13 +35,19 @@ def __init__(self, n_input, n_output):
         super(NetLinear, self).__init__()
         self.n_input = n_input
         self.layers = nn.Sequential(
-            nn.Linear(n_input, 512),
             nn.ReLU(),
             nn.Dropout(p=0.5),
             nn.Linear(512, 64)
         )
 
+        self.fine_tune_layer = nn.Sequential(
+            nn.ReLU(),
+            nn.Dropout(p=0.5),
+            nn.Linear(64, n_output),
+            nn.Softmax(dim=1)
+        )
+
         self.classifier = nn.Sequential(
             nn.ReLU(),
             nn.Dropout(p=0.5),
@@ -54,6 +60,12 @@ def forward(self, x):
         latent = self.layers(x)
         out = self.classifier(latent)
         return out, latent
+
+    def fine_tune(self, x):
+        x = x.view(-1, self.n_input)
+        x = self.layers(x)
+        out = self.fine_tune_layer(x)
+        return out
 
 
 class myDataset(Dataset):
diff --git a/idelucs/__main__.py b/idelucs/__main__.py
index 3877aeb..c9d1cb3 100644
--- a/idelucs/__main__.py
+++ b/idelucs/__main__.py
@@ -8,7 +8,7 @@
 
 from idelucs.utils import SummaryFasta, plot_confusion_matrix, \
-                          label_features, compute_results
+                          label_features, compute_results, cluster_acc
 
 from idelucs import models
 
 
@@ -110,6 +110,7 @@ def run(args):
     model_min_loss=np.inf
 
     for i in range(args['n_epochs']):
+        print(f"EPOCH {i}")
        loss = model.contrastive_training_epoch()
        model_min_loss = min(model_min_loss, loss)
        model_loss.append(loss)
@@ -156,16 +157,43 @@ def run(args):
     if args['GT_file'] != None:
         unique_labels = list(np.unique(model.GT))
         numClasses = len(unique_labels)
+        predictions = []
         y = np.array(list(map(lambda x: unique_labels.index(x), model.GT)))
+        np.set_printoptions(threshold=np.inf)
+
         results, ind = compute_results(y_pred, latent, y)
+        print("Without fine-tuning result:", results)
+        if args["fine_tune"]:
+            ind, acc = cluster_acc(y, y_pred)
+
+            # print("Y-pred: ", y_pred)
+            for i in range(10):
+                model.fine_tune(ind)
+                y_pred, probabilities, latent = model.predict()
+                y_pred = y_pred.astype(np.int32)
+
+                d = {}
+                count = 0
+                for i in range(y_pred.shape[0]):
+                    if y_pred[i] in d:
+                        y_pred[i] = d[y_pred[i]]
+                    else:
+                        d[y_pred[i]] = count
+                        y_pred[i] = count
+                        count += 1
+                predictions.append(y_pred)
+            y_pred, probabilities = label_features(np.array(predictions), args['n_clusters'])
+            results, ind = compute_results(y_pred, latent, y)
+            # print(results, ind, y, y_pred)
 
         d = {}
         for i, j in ind:
             d[i] = j
+        # d maps predicted to truth
 
         if -1 in y_pred:
             d[-1] = 0
-        
+
         w = np.zeros((numClasses, max(max(y_pred) + 1, max(y) + 1)), dtype=np.int64)
         clustered = np.zeros_like(y, dtype=bool)
         for i in range(y.shape[0]):
@@ -190,6 +218,8 @@ def run(args):
 
     #clustered = (probabilities >= 0.9)
     clustered = (probabilities >= 0.0)
+
+    print(results)
 
     sys.stdout.write(f"\r........ Saving Results ..............")
     sys.stdout.flush()
@@ -258,6 +288,8 @@ def run(args):
 
 def main():
     parser= argparse.ArgumentParser()
     parser.add_argument('--sequence_file', action='store',type=str)
+    parser.add_argument('--fine_tune', action='store',type=bool, default=False)
+
     parser.add_argument('--n_clusters', action='store',type=int,default=0, help='Expected or maximum number of clusters to find. \n'
                                                      'It should be equal or greater than n_true_clusters \n'
diff --git a/idelucs/models.py b/idelucs/models.py
index e8af9a4..08ef497 100644
--- a/idelucs/models.py
+++ b/idelucs/models.py
@@ -12,7 +12,7 @@
 from .LossFunctions import IID_loss, info_nce_loss
 from .PytorchUtils import NetLinear, myNet
 from .ResNet import ResNet18
-from .utils import SequenceDataset, create_dataloader
+from .utils import SequenceDataset, create_dataloader, generate_dataloader, generate_dataloader_tfrecord, generate_finetune_dataloader
 
 # Random Seeds for reproducibility.
 torch.manual_seed(0)
@@ -48,7 +48,6 @@ def __init__(self, args: dict):
 
         self.sequence_file = args['sequence_file']
         self.GT_file = args['GT_file']
-        self.n_clusters = args['n_clusters']
 
         self.k = args['k']
 
@@ -57,7 +56,7 @@ def __init__(self, args: dict):
             self.net = NetLinear(self.n_features, args['n_clusters'])
             self.reduce = False
 
-        elif args['model_size'] == 'small':
+        elif args['model_size'] == 'small':
             if self.k % 2 == 0: n_in = (4**self.k + 4**(self.k//2))//2
             else: n_in = (4**self.k)//2
             #d = {4: 135, 5: 511, 6: 2079}
@@ -102,18 +101,76 @@ def build_dataloader(self):
         #Data Files
         data_path = self.sequence_file
         GT_file = self.GT_file
-
-        self.dataloader = create_dataloader(data_path,
+
+        if self.sequence_file.endswith('.tfrecord'):
+            self.dataloader = generate_dataloader_tfrecord(data_path,
                                             self.n_mimics,
                                             k=self.k,
                                             batch_size=self.batch_sz,
-                                            GT_file=GT_file,
                                             reduce=self.reduce)
+        else:
+            self.dataloader = create_dataloader(data_path,
+                                                self.n_mimics,
+                                                k=self.k,
+                                                batch_size=self.batch_sz,
+                                                GT_file=GT_file,
+                                                reduce=self.reduce)
+
+    def fine_tune(self, ind):
+        # maps true labels to new labels
+        d = {}
+        for j, i in ind:
+            d[i] = j
+        self.net.train()
+        dataloader = generate_finetune_dataloader(
+            d,
+            self.GT_file,
+            self.sequence_file,
+            self.k,
+            self.batch_sz,
+            self.reduce
+        )
+        loss_func = nn.CrossEntropyLoss()
+        running_loss = 0.0  # accumulated fine-tuning loss, used by the 'Plateau' scheduler step below
+        for _, sample_batched in enumerate(dataloader):
+            sample = sample_batched['true'].view(-1, 1, self.n_features).type(dtype)
+            labels = torch.nn.functional.one_hot(sample_batched['label'], num_classes=self.n_clusters).type(dtype)
+            # print(labels)
+
+            self.optimizer.zero_grad()
+
+            pred_result = self.net.fine_tune(sample)
+            # print(pred_result)
+
+            loss = loss_func(pred_result, labels)
+            # print(loss)
+            loss.backward()
+            self.optimizer.step()
+            running_loss += loss.item()
+
+        if self.schedule == 'Plateau':
+            self.scheduler.step(running_loss)
+        elif self.schedule == 'Triangle':
+            self.scheduler.step()
+
 
     def contrastive_training_epoch(self):
         self.net.train()
         running_loss = 0.0
-
+        # dataloader = generate_dataloader(
+        #     self.sequence_file,
+        #     self.n_mimics,
+        #     self.k,
+        #     self.batch_sz,
+        #     self.reduce
+        # )
+        # dataloader = generate_dataloader_tfrecord(
+        #     self.sequence_file,
+        #     self.n_mimics,
+        #     self.k,
+        #     self.batch_sz,
+        #     self.reduce
+        # )
        for i_batch, sample_batched in enumerate(self.dataloader):
            sample = sample_batched['true'].view(-1, 1, self.n_features).type(dtype)
            modified_sample = sample_batched['modified'].view(-1, 1, self.n_features).type(dtype)
@@ -145,12 +202,19 @@ def contrastive_training_epoch(self):
 
     def predict(self, data=None):
         n_features = self.n_features
+        # if self.sequence_file.endswith('.tfrecord'):
+        #     test_dataloader = generate_dataloader_tfrecord(self.sequence_file,
+        #                                                    self.n_mimics,
+        #                                                    k=self.k,
+        #                                                    batch_size=self.batch_sz,
+        #                                                    reduce=self.reduce)
+        # else:
         test_dataset = SequenceDataset(self.sequence_file, k=self.k, transform=None, GT_file=self.GT_file, reduce=self.reduce)
         test_dataloader = DataLoader(test_dataset,
-                                        batch_size=self.batch_sz,
-                                        shuffle=False,
-                                        num_workers=0,
-                                        drop_last=False)
+                                     batch_size=self.batch_sz,
+                                     shuffle=False,
+                                     num_workers=0,
+                                     drop_last=False)
 
         y_pred = []
         probabilities = []
         latent = []
diff --git a/idelucs/utils.py b/idelucs/utils.py
index 7ae14be..b9c3125 100644
--- a/idelucs/utils.py
+++ b/idelucs/utils.py
@@ -13,11 +13,12 @@
 
 import torch
 from torch.utils.data import Dataset, DataLoader
-from torch.utils.data import DataLoader
 
 from scipy.optimize import linear_sum_assignment
 from sklearn.preprocessing import StandardScaler
 
+from tfrecord.torch.dataset import TFRecordDataset
+
 import matplotlib.pyplot as plt
 from colorsys import hsv_to_rgb
@@ -145,43 +146,57 @@ def SummaryFasta(fname, GT_file=None):
         GT_dict = dict(zip(df.sequence_id, df.cluster_id))
         cluster_dis = df['cluster_id'].value_counts().to_dict()
 
-    for line in open(fname, "rb"):
+    if fname.endswith('.tfrecord'):
+        reader = TFRecordDataset(fname, None)
+        for cur_entry in reader:
+            seq_id, seq = cur_entry["sequence_name"].decode(), cur_entry["sequence"]
+            if (GT_file and not seq_id in GT_dict):
+                raise ValueError('Check GT for sequence {}'.format(seq_id))
+            seq = check_sequence(seq_id, seq)
+            names.append(seq_id)
+            lengths.append(len(seq))
+            if GT_file:
+                ground_truth.append(GT_dict[seq_id])
 
-        if line.startswith(b'#'):
-            pass
+    else:
+        reader = open(fname, "rb")
+        for line in reader:
 
-        elif line.startswith(b'>'):
-            if seq_id != "":
-                seq = bytearray().join(lines)
+            if line.startswith(b'#'):
+                pass
 
-                if (GT_file and not seq_id in GT_dict):
-                    raise ValueError('Check GT for sequence {}'.format(seq_id))
+            elif line.startswith(b'>'):
+                if seq_id != "":
+                    seq = bytearray().join(lines)
 
-                seq = check_sequence(seq_id, seq)
-                names.append(seq_id)
-                lengths.append(len(seq))
+                    if (GT_file and not seq_id in GT_dict):
+                        raise ValueError('Check GT for sequence {}'.format(seq_id))
 
-                if GT_file:
-                    ground_truth.append(GT_dict[seq_id])
+                    seq = check_sequence(seq_id, seq)
+                    names.append(seq_id)
+                    lengths.append(len(seq))
 
-            lines = []
-            seq_id = line[1:-1].decode() # Modify this according to your labels.
+                    if GT_file:
+                        ground_truth.append(GT_dict[seq_id])
+
+                lines = []
+                seq_id = line[1:-1].decode() # Modify this according to your labels.
+
+                seq_id = line[1:-1].decode()
 
-            seq_id = line[1:-1].decode()
-
-        else:
-            lines += [line.strip()]
-
-    if (GT_file and not seq_id in GT_dict):
-        raise ValueError('Check GT for sequence {}'.format(seq_id))
+            else:
+                lines += [line.strip()]
+
+        if (GT_file and not seq_id in GT_dict):
+            raise ValueError('Check GT for sequence {}'.format(seq_id))
 
-    seq = bytearray().join(lines)
-    seq = check_sequence(seq_id, seq)
-    names.append(seq_id)
-    lengths.append(len(seq))
+        seq = bytearray().join(lines)
+        seq = check_sequence(seq_id, seq)
+        names.append(seq_id)
+        lengths.append(len(seq))
 
-    if GT_file:
-        ground_truth.append(GT_dict[seq_id])
+        if GT_file:
+            ground_truth.append(GT_dict[seq_id])
 
     return names, lengths, ground_truth, cluster_dis
@@ -224,51 +239,59 @@ def kmersFasta(fname, k=6, transform=None, reduce=False):
     seq_id = ""
     names, kmers = [], []
 
-    for line in open(fname, "rb"):
-        if line.startswith(b'#'):
-            pass
+    if fname.endswith('.tfrecord'):
+        reader = TFRecordDataset(fname, None)
+        for cur_entry in reader:
+            seq_id, seq = cur_entry["sequence_name"].decode(), bytearray(cur_entry["sequence"])
+            names.append(seq_id)
+            if transform:
+                transform(seq)
+            counts = np.ones(4**k, dtype=np.int32)
+            kmer_counts(seq, k, counts)
+            if reduce:
+                counts = kmer_rev_comp(counts,k)
+            kmers.append(counts / np.sum(counts))
 
-        elif line.startswith(b'>'):
-            if seq_id != "":
-                seq = bytearray().join(lines)
-                names.append(seq_id)
-
-                if transform:
-                    transform(seq)
-
-                counts = np.ones(4**k, dtype=np.int32)
-                kmer_counts(seq, k, counts)
-                #cgr(seq, k, counts)
-
-                if reduce:
-                    counts = kmer_rev_comp(counts,k)
-
-
-                kmers.append(counts / np.sum(counts))
-
-            lines = []
-            seq_id = line[1:-1].decode() # Modify this according to your labels.
-            seq_id = line[1:-1].decode()
-
-        else:
-            lines += [line.strip()]
-
-    seq = bytearray().join(lines)
-    names.append(seq_id)
-    if transform:
-        transform(seq)
-
-    counts = np.ones(4**k, dtype=np.int32)
-    kmer_counts(seq, k, counts)
-    #cgr(seq, k, counts)
-    if reduce:
-        counts = kmer_rev_comp(counts,k)
-    kmers.append(counts / np.sum(counts))
+    else:
+        for line in open(fname, "rb"):
+            if line.startswith(b'#'):
+                pass
+
+            elif line.startswith(b'>'):
+                if seq_id != "":
+                    seq = bytearray().join(lines)
+                    names.append(seq_id)
+
+                    if transform:
+                        transform(seq)
+
+                    counts = np.ones(4**k, dtype=np.int32)
+                    kmer_counts(seq, k, counts)
+                    #cgr(seq, k, counts)
+
+                    if reduce:
+                        counts = kmer_rev_comp(counts,k)
+
+
+                    kmers.append(counts / np.sum(counts))
+                lines = []
+                seq_id = line[1:-1].decode() # Modify this according to your labels.
+                seq_id = line[1:-1].decode()
+
+            else:
+                lines += [line.strip()]
 
-    #if reduce:
-    #    K_file = np.load(open(f'kernels/kernel{k}.npz','rb'))
-    #    KERNEL = K_file['arr_0']
-    #    return names, np.dot(np.array(kmers), KERNEL)
+        seq = bytearray().join(lines)
+        names.append(seq_id)
+        if transform:
+            transform(seq)
+
+        counts = np.ones(4**k, dtype=np.int32)
+        kmer_counts(seq, k, counts)
+        #cgr(seq, k, counts)
+        if reduce:
+            counts = kmer_rev_comp(counts,k)
+        kmers.append(counts / np.sum(counts))
 
     return names, np.array(kmers)
@@ -363,6 +386,28 @@ def AugmentFasta(sequence_file, n_mimics, k=6, reduce=False):
 
     return x_train
 
+class FinetuneDataset(Dataset):
+    def __init__(self, true_to_generated_labels, gt_file, sequence_file, k=6, transform=None, reduce=None) -> None:
+        _, _, labels, _ = SummaryFasta(sequence_file, gt_file)
+        _, self.data = kmersFasta(sequence_file, k, transform, reduce=reduce)
+
+        unique_labels = list(np.unique(labels))
+        self.labels = np.array(list(map(lambda x: true_to_generated_labels[unique_labels.index(x)], labels)))
+
+        # print("LABELS: " ,self.labels)
+
+    def __len__(self):
+        return int(self.data.shape[0]/10) # fine tune using 1/10th
+
+    def __getitem__(self, idx):
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+
+        sample = {'true': self.data[idx], 'label': self.labels[idx]}
+        return sample
+
+
+
 class AugmentedDataset(Dataset):
     """
     Dataset creation directly from fasta file.
@@ -618,3 +663,137 @@ def compute_results(y_pred, data, y_true=None):
 
     return d, None
 
+from torch.utils.data import IterableDataset
+class NewAugmentedDataset(IterableDataset):
+    def __init__(self, sequence_file, n_mimics, k=6, reduce=False):
+        self.file = open(sequence_file, 'r+b')
+        self.n_mimics = n_mimics
+        self.reduce = reduce
+        self.k = k
+        self.scaler = self.generate_scaler()
+
+    def generate_kmer(self, sequence, k, transform=None, reduce=False, scale=True):
+        if transform:
+            transform(sequence)
+        counts = np.ones(4**k, dtype=np.int32)
+        kmer_counts(sequence, k, counts)
+
+        if reduce:
+            counts = kmer_rev_comp(counts, k)
+
+        if scale:
+            return self.scaler.transform([counts / np.sum(counts)])[0]
+        return counts / np.sum(counts)
+
+    def generate_scaler(self):
+        scaler = StandardScaler()
+        for line in self.file:
+            if line.startswith(b'>'):
+                continue
+            normal = self.generate_kmer(bytearray(line), self.k, transition_transversion(1e-2, 0.5e-2), self.reduce, scale=False)
+            scaler.partial_fit([normal])
+        self.file.seek(0)
+        return scaler
+
+    def __iter__(self):
+        for line in self.file:
+            if line.startswith(b'>'):
+                continue
+            # generating the normal
+            normal = self.generate_kmer(bytearray(line), self.k, transition_transversion(1e-2, 0.5e-2), self.reduce)
+            # mutated #1
+            mutated1 = self.generate_kmer(bytearray(line), self.k, transition(1e-2), self.reduce)
+            # mutated #2
+            mutated2 = self.generate_kmer(bytearray(line), self.k, transversion(0.5e-2), self.reduce)
+            yield {"true": normal, "modified": mutated1}
+            yield {"true": normal, "modified": mutated2}
+
+            for _ in range(self.n_mimics - 2):
+                mutated = self.generate_kmer(bytearray(line), self.k, Random_N(20), self.reduce)
+                yield {"true": normal, "modified": mutated}
+
+import torch
+class ShuffleDataset(IterableDataset):
+    def __init__(self, dataset, buffer_size):
+        super().__init__()
+        self.dataset = dataset
+        self.buffer_size = buffer_size
+
+    def yield_from_shuffled(self, arr):
+        for index in torch.randperm(len(arr)).tolist():
+            yield arr[index]
+
+    def __iter__(self):
+        count = 0
+        shufbuf = []
+        for entry in self.dataset:
+            count += 1
+            if len(shufbuf) < self.buffer_size:
+                shufbuf.append(entry)
+            else:
+                yield from self.yield_from_shuffled(shufbuf)
+                shufbuf = []
+        yield from self.yield_from_shuffled(shufbuf)
+
+class TfrecordAugmentedDataset(IterableDataset):
+    def __init__(self, sequence_file, n_mimics, k=6, reduce=False):
+        self.file = TFRecordDataset(sequence_file, None, shuffle_queue_size=1024)
+        self.n_mimics = n_mimics
+        self.reduce = reduce
+        self.k = k
+        self.scaler = self.generate_scaler()
+
+    def generate_kmer(self, sequence, k, transform=None, reduce=False, scale=True):
+        if transform:
+            transform(sequence)
+        counts = np.ones(4**k, dtype=np.int32)
+        kmer_counts(sequence, k, counts)
+
+        if reduce:
+            counts = kmer_rev_comp(counts, k)
+
+        if scale:
+            return self.scaler.transform([counts / np.sum(counts)])[0]
+        return counts / np.sum(counts)
+
+    def generate_scaler(self):
+        print("Building the scaler for tfrecord file")
+        scaler = StandardScaler()
+        for entry in self.file:
+            seq_name, sequence = entry["sequence_name"], entry["sequence"]
+            normal = self.generate_kmer(bytearray(sequence), self.k, transition_transversion(1e-2, 0.5e-2), self.reduce, scale=False)
+            scaler.partial_fit([normal])
+        return scaler
+
+    def __iter__(self):
+        cur_id = ""
+        for entry in self.file:
+            seq_name, sequence = entry["sequence_name"], entry["sequence"]
+            # generating the normal
+            normal = self.generate_kmer(bytearray(sequence), self.k, transition_transversion(1e-2, 0.5e-2), self.reduce)
+            # mutated #1
+            mutated1 = self.generate_kmer(bytearray(sequence), self.k, transition(1e-2), self.reduce)
+            # mutated #2
+            mutated2 = self.generate_kmer(bytearray(sequence), self.k, transversion(0.5e-2), self.reduce)
+            yield {"true": normal, "modified": mutated1}
+            yield {"true": normal, "modified": mutated2}
+
+            for _ in range(self.n_mimics-2):
+                mutated = self.generate_kmer(bytearray(sequence), self.k, Random_N(20), self.reduce)
+                yield {"true": normal, "modified": mutated}
+
+
+def generate_dataloader_tfrecord(data_path, n_mimics, k=6, batch_size=512, reduce=False):
+    dataset = TfrecordAugmentedDataset(data_path, n_mimics, k, reduce)
+    dataloader = DataLoader(dataset, batch_size=batch_size)
+    return dataloader
+
+def generate_dataloader(data_path, n_mimics, k=6, batch_size=512, reduce=False):
+    dataset = ShuffleDataset(NewAugmentedDataset(data_path, n_mimics, k, reduce), batch_size * 8)
+    dataloader = DataLoader(dataset, batch_size=batch_size)
+    return dataloader
+
+def generate_finetune_dataloader(true_to_generated_labels, gt_file, sequence_file, k=6, batch_size=512, reduce=False):
+    dataset = FinetuneDataset(true_to_generated_labels, gt_file, sequence_file, k, transform=None, reduce=reduce)
+    dataloader = DataLoader(dataset, batch_size=batch_size)
+    return dataloader
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 1858c6d..fc09981 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,15 +2,19 @@
 name = "idelucs"
 dynamic = ["version"]
 dependencies = [
-    "numpy==1.23",
-    "torch==2.0",
+    "numpy==1.23.0",
+    "torch==1.13.1",
     "cython",
-    "matplotlib==3.7",
-    "pandas==2.0.1",
-    "scikit-learn==1.2",
-    "scipy==1.10",
+    "matplotlib==3.6.1",
+    "pandas==1.5.3",
+    "scikit-learn==1.1.2",
+    "scipy==1.10.1",
     "umap-learn==0.5.3",
-    "hdbscan==0.8.29"
+    "tensorflow==2.13.0rc0",
+    "protobuf==3.20.3",
+    "tfrecord==1.14.3",
+    "hdbscan==0.8.28",
+    "joblib==1.1.0"
 ]
 requires-python = "<3.12,>=3.9.0"
 scripts = {idelucs = "idelucs.__main__:main"}
@@ -32,14 +36,20 @@ version = {attr = "idelucs.__version__"}
 readme = {file = "README.md"}
 
 [build-system]
-requires = ["setuptools", "wheel", "Cython", "numpy", - "numpy==1.23", - "torch==2.0", - "cython", - "matplotlib==3.7", - "pandas==2.0.1", - "scikit-learn==1.2", - "scipy==1.10", +requires = ["setuptools", "wheel", "Cython", + "numpy==1.23.0", + "torch==1.13.1", + "cython==0.29.34", + "matplotlib==3.6.1", + "pandas==1.5.3", + "scikit-learn==1.1.2", + "scipy==1.10.1", "umap-learn==0.5.3", - "hdbscan==0.8.29"] + "tensorflow==2.13.0rc0", + "protobuf==3.20.3", + "tfrecord==1.14.3", + "hdbscan==0.8.28", + "joblib==1.1.0"] build-backend = "setuptools.build_meta" + +