From bb2891dfdaf54f92395f7250d2246d8ebe68b923 Mon Sep 17 00:00:00 2001 From: Robrecht Cannoodt Date: Mon, 26 Aug 2024 23:31:10 +0200 Subject: [PATCH 1/3] add novel Co-authored-by: Kai Waldrant --- src/methods/novel/helper_functions.py | 246 ++++++++++++++++++++ src/methods/novel/predict/config.vsh.yaml | 26 +++ src/methods/novel/predict/script.py | 118 ++++++++++ src/methods/novel/run/config.vsh.yaml | 23 ++ src/methods/novel/run/main.nf | 25 ++ src/methods/novel/train/config.vsh.yaml | 32 +++ src/methods/novel/train/script.py | 147 ++++++++++++ src/workflows/run_benchmark/config.vsh.yaml | 1 + src/workflows/run_benchmark/main.nf | 3 +- 9 files changed, 620 insertions(+), 1 deletion(-) create mode 100644 src/methods/novel/helper_functions.py create mode 100644 src/methods/novel/predict/config.vsh.yaml create mode 100644 src/methods/novel/predict/script.py create mode 100644 src/methods/novel/run/config.vsh.yaml create mode 100644 src/methods/novel/run/main.nf create mode 100644 src/methods/novel/train/config.vsh.yaml create mode 100644 src/methods/novel/train/script.py diff --git a/src/methods/novel/helper_functions.py b/src/methods/novel/helper_functions.py new file mode 100644 index 0000000..2696c2f --- /dev/null +++ b/src/methods/novel/helper_functions.py @@ -0,0 +1,246 @@ +import torch + +from torch import nn +import torch.nn.functional as F + +from torch.utils.data import Dataset + +from typing import Optional + +import anndata +import numpy as np +import pandas as pd +import scipy.sparse +import sklearn.decomposition +import sklearn.feature_extraction.text +import sklearn.preprocessing +import sklearn.neighbors +import sklearn.utils.extmath + +class tfidfTransformer(): + def __init__(self): + self.idf = None + self.fitted = False + + def fit(self, X): + self.idf = X.shape[0] / X.sum(axis=0) + self.fitted = True + + def transform(self, X): + if not self.fitted: + raise RuntimeError('Transformer was not fitted on any data') + if scipy.sparse.issparse(X): + tf = X.multiply(1 / X.sum(axis=1)) + return tf.multiply(self.idf) + else: + tf = X / X.sum(axis=1, keepdims=True) + return tf * self.idf + + def fit_transform(self, X): + self.fit(X) + return self.transform(X) + +class lsiTransformer(): + def __init__(self, + n_components: int = 20, + use_highly_variable = None + ): + self.n_components = n_components + self.use_highly_variable = use_highly_variable + self.tfidfTransformer = tfidfTransformer() + self.normalizer = sklearn.preprocessing.Normalizer(norm="l1") + self.pcaTransformer = sklearn.decomposition.TruncatedSVD(n_components = self.n_components, random_state=777) + # self.lsi_mean = None + # self.lsi_std = None + self.fitted = None + + def fit(self, adata: anndata.AnnData): + if self.use_highly_variable is None: + self.use_highly_variable = "hvg" in adata.var + adata_use = adata[:, adata.var["hvg"]] if self.use_highly_variable else adata + X = self.tfidfTransformer.fit_transform(adata_use.X) + X_norm = self.normalizer.fit_transform(X) + X_norm = np.log1p(X_norm * 1e4) + X_lsi = self.pcaTransformer.fit_transform(X_norm) + # self.lsi_mean = X_lsi.mean(axis=1, keepdims=True) + # self.lsi_std = X_lsi.std(axis=1, ddof=1, keepdims=True) + self.fitted = True + + def transform(self, adata): + if not self.fitted: + raise RuntimeError('Transformer was not fitted on any data') + adata_use = adata[:, adata.var["hvg"]] if self.use_highly_variable else adata + X = self.tfidfTransformer.transform(adata_use.X) + X_norm = self.normalizer.transform(X) + X_norm = np.log1p(X_norm * 1e4) + X_lsi = self.pcaTransformer.transform(X_norm) + X_lsi -= X_lsi.mean(axis=1, keepdims=True) + X_lsi /= X_lsi.std(axis=1, ddof=1, keepdims=True) + lsi_df = pd.DataFrame(X_lsi, index = adata_use.obs_names) + return lsi_df + + def fit_transform(self, adata): + self.fit(adata) + return self.transform(adata) + +class ModalityMatchingDataset(Dataset): + def __init__( + self, df_modality1, df_modality2, is_train=True + ): + super().__init__() + self.df_modality1 = df_modality1 + self.df_modality2 = df_modality2 + self.is_train = is_train + def __len__(self): + return self.df_modality1.shape[0] + + def __getitem__(self, index: int): + if self.is_train == True: + x = self.df_modality1.iloc[index].values + y = self.df_modality2.iloc[index].values + return x, y + else: + x = self.df_modality1.iloc[index].values + return x + +class Swish(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * sigmoid(i) + ctx.save_for_backward(i) + return result + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + +class Swish_module(nn.Module): + def forward(self, x): + return Swish.apply(x) + +sigmoid = torch.nn.Sigmoid() + +class ModelRegressionGex2Atac(nn.Module): + def __init__(self, dim_mod1, dim_mod2): + super(ModelRegressionGex2Atac, self).__init__() + #self.bn = torch.nn.BatchNorm1d(1024) + self.input_ = nn.Linear(dim_mod1, 1024) + self.fc = nn.Linear(1024, 256) + self.fc1 = nn.Linear(256, 2048) + self.dropout1 = nn.Dropout(p=0.298885630228993) + self.dropout2 = nn.Dropout(p=0.11289717442776658) + self.dropout3 = nn.Dropout(p=0.13523634924414762) + self.output = nn.Linear(2048, dim_mod2) + def forward(self, x): + x = F.gelu(self.input_(x)) + x = self.dropout1(x) + x = F.gelu(self.fc(x)) + x = self.dropout2(x) + x = F.gelu(self.fc1(x)) + x = self.dropout3(x) + x = F.gelu(self.output(x)) + return x + +class ModelRegressionAtac2Gex(nn.Module): # + def __init__(self, dim_mod1, dim_mod2): + super(ModelRegressionAtac2Gex, self).__init__() + self.input_ = nn.Linear(dim_mod1, 2048) + self.fc = nn.Linear(2048, 2048) + self.fc1 = nn.Linear(2048, 512) + self.dropout1 = nn.Dropout(p=0.2649138776004753) + self.dropout2 = nn.Dropout(p=0.1769628308148758) + self.dropout3 = nn.Dropout(p=0.2516791883012817) + self.output = nn.Linear(512, dim_mod2) + def forward(self, x): + x = F.gelu(self.input_(x)) + x = self.dropout1(x) + x = F.gelu(self.fc(x)) + x = self.dropout2(x) + x = F.gelu(self.fc1(x)) + x = self.dropout3(x) + x = F.gelu(self.output(x)) + return x + +class ModelRegressionAdt2Gex(nn.Module): + def __init__(self, dim_mod1, dim_mod2): + super(ModelRegressionAdt2Gex, self).__init__() + self.input_ = nn.Linear(dim_mod1, 512) + self.dropout1 = nn.Dropout(p=0.0) + self.swish = Swish_module() + self.fc = nn.Linear(512, 512) + self.fc1 = nn.Linear(512, 512) + self.fc2 = nn.Linear(512, 512) + self.output = nn.Linear(512, dim_mod2) + def forward(self, x): + x = F.gelu(self.input_(x)) + x = F.gelu(self.fc(x)) + x = F.gelu(self.fc1(x)) + x = F.gelu(self.fc2(x)) + x = F.gelu(self.output(x)) + return x + +class ModelRegressionGex2Adt(nn.Module): + def __init__(self, dim_mod1, dim_mod2): + super(ModelRegressionGex2Adt, self).__init__() + self.input_ = nn.Linear(dim_mod1, 512) + self.dropout1 = nn.Dropout(p=0.20335661386636347) + self.dropout2 = nn.Dropout(p=0.15395289261127876) + self.dropout3 = nn.Dropout(p=0.16902655078832815) + self.fc = nn.Linear(512, 512) + self.fc1 = nn.Linear(512, 2048) + self.output = nn.Linear(2048, dim_mod2) + def forward(self, x): + # x = self.batchswap_noise(x) + x = F.gelu(self.input_(x)) + x = self.dropout1(x) + x = F.gelu(self.fc(x)) + x = self.dropout2(x) + x = F.gelu(self.fc1(x)) + x = self.dropout3(x) + x = F.gelu(self.output(x)) + return x + +def rmse(y, y_pred): + return np.sqrt(np.mean(np.square(y - y_pred))) + +def train_and_valid(model, optimizer, loss_fn, dataloader_train, dataloader_test, name_model, device): + best_score = 100000 + for i in range(100): + train_losses = [] + test_losses = [] + model.train() + + for x, y in dataloader_train: + optimizer.zero_grad() + output = model(x.float().to(device)) + loss = torch.sqrt(loss_fn(output, y.float().to(device))) + loss.backward() + train_losses.append(loss.item()) + optimizer.step() + + model.eval() + with torch.no_grad(): + for x, y in dataloader_test: + output = model(x.float().to(device)) + output[output<0] = 0.0 + loss = torch.sqrt(loss_fn(output, y.float().to(device))) + test_losses.append(loss.item()) + + outputs = [] + targets = [] + model.eval() + with torch.no_grad(): + for x, y in dataloader_test: + output = model(x.float().to(device)) + + outputs.append(output.detach().cpu().numpy()) + targets.append(y.float().detach().cpu().numpy()) + cat_outputs = np.concatenate(outputs) + cat_targets = np.concatenate(targets) + cat_outputs[cat_outputs<0.0] = 0 + + if best_score > rmse(cat_targets,cat_outputs): + torch.save(model.state_dict(), name_model) + best_score = rmse(cat_targets,cat_outputs) + print("best rmse: ", best_score) diff --git a/src/methods/novel/predict/config.vsh.yaml b/src/methods/novel/predict/config.vsh.yaml new file mode 100644 index 0000000..1074a81 --- /dev/null +++ b/src/methods/novel/predict/config.vsh.yaml @@ -0,0 +1,26 @@ +__merge__: ../../../api/comp_method_predict.yaml +name: novel_predict +arguments: + - name: "--input_transform" + type: file + direction: input + required: false + example: "lsi_transformer.pickle" +resources: + - type: python_script + path: script.py + - path: ../helper_functions.py +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1.0.0 + setup: + - type: python + packages: + - scikit-learn + - networkx +runners: + - type: executable + - type: nextflow + directives: + label: [highmem, hightime, midcpu, highsharedmem, gpu] + diff --git a/src/methods/novel/predict/script.py b/src/methods/novel/predict/script.py new file mode 100644 index 0000000..240620c --- /dev/null +++ b/src/methods/novel/predict/script.py @@ -0,0 +1,118 @@ +import sys +import torch +from torch.utils.data import DataLoader + +import anndata as ad +import pickle +import numpy as np +from scipy.sparse import csc_matrix + +#check gpu available +if (torch.cuda.is_available()): + device = 'cuda:0' #switch to current device + print('current device: gpu', flush=True) +else: + device = 'cpu' + print('current device: cpu', flush=True) + + +## VIASH START + +par = { + 'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/train_mod2.h5ad', + 'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/test_mod1.h5ad', + 'input_model': 'resources_test/predict_modality/neurips2021_bmmc_cite/model.pt', + 'input_transform': 'transformer.pickle' +} +meta = { + 'resources_dir': 'src/tasks/predict_modality/methods/novel', + 'functionality_name': '171129' +} +## VIASH END + +sys.path.append(meta['resources_dir']) +from helper_functions import ModelRegressionAtac2Gex, ModelRegressionAdt2Gex, ModelRegressionGex2Adt, ModelRegressionGex2Atac, ModalityMatchingDataset + +print("Load data", flush=True) + +input_test_mod1 = ad.read_h5ad(par['input_test_mod1']) +input_train_mod2 = ad.read_h5ad(par['input_train_mod2']) + +mod1 = input_test_mod1.uns['modality'] +mod2 = input_train_mod2.uns['modality'] + +n_vars_mod1 = input_train_mod2.uns["model_dim"]["mod1"] +n_vars_mod2 = input_train_mod2.uns["model_dim"]["mod2"] + +input_test_mod1.X = input_test_mod1.layers['normalized'].tocsr() + +# Remove vars that were removed from training set. Mostlyy only applicable for testing. +if input_train_mod2.uns.get("removed_vars"): + rem_var = input_train_mod2.uns["removed_vars"] + input_test_mod1 = input_test_mod1[:, ~input_test_mod1.var_names.isin(rem_var)] + +del input_train_mod2 + + +model_fp = par['input_model'] + +print("Start predict", flush=True) + +if mod1 == 'GEX' and mod2 == 'ADT': + model = ModelRegressionGex2Adt(n_vars_mod1,n_vars_mod2) + weight = torch.load(model_fp, map_location='cpu') + with open(par['input_transform'], 'rb') as f: + lsi_transformer_gex = pickle.load(f) + + model.load_state_dict(weight) + input_test_mod1_ = lsi_transformer_gex.transform(input_test_mod1) + +elif mod1 == 'GEX' and mod2 == 'ATAC': + model = ModelRegressionGex2Atac(n_vars_mod1,n_vars_mod2) + weight = torch.load(model_fp, map_location='cpu') + with open(par['input_transform'], 'rb') as f: + lsi_transformer_gex = pickle.load(f) + + model.load_state_dict(weight) + input_test_mod1_ = lsi_transformer_gex.transform(input_test_mod1) + +elif mod1 == 'ATAC' and mod2 == 'GEX': + model = ModelRegressionAtac2Gex(n_vars_mod1,n_vars_mod2) + weight = torch.load(model_fp, map_location='cpu') + with open(par['input_transform'], 'rb') as f: + lsi_transformer_gex = pickle.load(f) + + model.load_state_dict(weight) + input_test_mod1_ = lsi_transformer_gex.transform(input_test_mod1) + +elif mod1 == 'ADT' and mod2 == 'GEX': + model = ModelRegressionAdt2Gex(n_vars_mod1,n_vars_mod2) + weight = torch.load(model_fp, map_location='cpu') + + model.load_state_dict(weight) + input_test_mod1_ = input_test_mod1.to_df() + +dataset_test = ModalityMatchingDataset(input_test_mod1_, None, is_train=False) +dataloader_test = DataLoader(dataset_test, 32, shuffle = False, num_workers = 4) + +outputs = [] +model.eval() +with torch.no_grad(): + for x in dataloader_test: + output = model(x.float()) + outputs.append(output.detach().cpu().numpy()) + +outputs = np.concatenate(outputs) +outputs[outputs<0] = 0 +outputs = csc_matrix(outputs) + +adata = ad.AnnData( + layers={"normalized": outputs}, + shape=outputs.shape, + uns={ + 'dataset_id': input_test_mod1.uns['dataset_id'], + 'method_id': meta['functionality_name'], + }, +) +adata.write_h5ad(par['output'], compression = "gzip") + diff --git a/src/methods/novel/run/config.vsh.yaml b/src/methods/novel/run/config.vsh.yaml new file mode 100644 index 0000000..605fb18 --- /dev/null +++ b/src/methods/novel/run/config.vsh.yaml @@ -0,0 +1,23 @@ +__merge__: ../../../api/comp_method.yaml +name: novel +label: Novel +summary: A method using encoder-decoder MLP model +description: This method trains an encoder-decoder MLP model with one output neuron per component in the target. As an input, the encoders use representations obtained from ATAC and GEX data via LSI transform and raw ADT data. The hyperparameters of the models were found via broad hyperparameter search using the Optuna framework. +references: + doi: + - 10.1101/2022.04.11.487796 +links: + documentation: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/novel#readme + repository: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/novel +info: + submission_id: "169769" + preferred_normalization: log_cp10k +resources: + - path: main.nf + type: nextflow_script + entrypoint: run_wf +dependencies: + - name: predict_modality/methods/novel_train + - name: predict_modality/methods/novel_predict +runners: + - type: nextflow \ No newline at end of file diff --git a/src/methods/novel/run/main.nf b/src/methods/novel/run/main.nf new file mode 100644 index 0000000..f3f879d --- /dev/null +++ b/src/methods/novel/run/main.nf @@ -0,0 +1,25 @@ +workflow run_wf { + take: input_ch + main: + output_ch = input_ch + | novel_train.run( + fromState: ["input_train_mod1", "input_train_mod2"], + toState: ["input_model": "output", "input_transform": "output_transform", "output_train_mod2": "output_train_mod2"] + ) + | novel_predict.run( + fromState: { id, state -> + [ + "input_train_mod2": state.output_train_mod2, + "input_test_mod1": state.input_test_mod1, + "input_model": state.input_model, + "input_transform": state.input_transform, + "output": state.output]}, + toState: ["output": "output"] + ) + + | map { tup -> + [tup[0], [output: tup[1].output]] + } + + emit: output_ch +} diff --git a/src/methods/novel/train/config.vsh.yaml b/src/methods/novel/train/config.vsh.yaml new file mode 100644 index 0000000..12717b6 --- /dev/null +++ b/src/methods/novel/train/config.vsh.yaml @@ -0,0 +1,32 @@ +__merge__: ../../../api/comp_method_train.yaml +name: novel_train +arguments: + - name: --output_transform + type: file + description: "The output transform file" + required: false + default: "lsi_transformer.pickle" + direction: output + - name: --output_train_mod2 + type: file + description: copy of the input with model dim in `.uns` + direction: output + default: "train_mod2.h5ad" + required: false +resources: + - path: script.py + type: python_script + - path: ../helper_functions.py +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1.0.0 + setup: + - type: python + packages: + - scikit-learn + - networkx +runners: + - type: executable + - type: nextflow + directives: + label: [highmem, hightime, midcpu, highsharedmem, gpu] diff --git a/src/methods/novel/train/script.py b/src/methods/novel/train/script.py new file mode 100644 index 0000000..65edfe6 --- /dev/null +++ b/src/methods/novel/train/script.py @@ -0,0 +1,147 @@ +import sys + +import torch +from torch.utils.data import DataLoader +# from sklearn.model_selection import train_test_split + +import anndata as ad +import pickle + +#check gpu available +if (torch.cuda.is_available()): + device = 'cuda:0' #switch to current device + print('current device: gpu', flush=True) +else: + device = 'cpu' + print('current device: cpu', flush=True) + + +## VIASH START + +par = { + 'input_train_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/train_mod1.h5ad', + 'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/train_mod2.h5ad', + 'output_train_mod2': 'train_mod2.h5ad', + 'output': 'model.pt' +} + +meta = { + 'resources_dir': 'src/tasks/predict_modality/methods/novel', +} +## VIASH END + + +sys.path.append(meta['resources_dir']) +from helper_functions import train_and_valid, lsiTransformer, ModalityMatchingDataset +from helper_functions import ModelRegressionAtac2Gex, ModelRegressionAdt2Gex, ModelRegressionGex2Adt, ModelRegressionGex2Atac + +print('Load data', flush=True) + +input_train_mod1 = ad.read_h5ad(par['input_train_mod1']) +input_train_mod2 = ad.read_h5ad(par['input_train_mod2']) + +adata = input_train_mod2.copy() + +mod1 = input_train_mod1.uns['modality'] +mod2 = input_train_mod2.uns['modality'] + +input_train_mod1.X = input_train_mod1.layers['normalized'] +input_train_mod2.X = input_train_mod2.layers['normalized'] + +input_train_mod2_df = input_train_mod2.to_df() + +del input_train_mod2 + +print('Start train', flush=True) + + +# Check for zero divide +zero_row = input_train_mod1.X.sum(axis=0) == 0 + +rem_var = None +if True in zero_row: + rem_var = input_train_mod1[:, zero_row].var_names + input_train_mod1 = input_train_mod1[:, ~zero_row] + + +# select number of variables for LSI +n_comp = input_train_mod1.n_vars -1 if input_train_mod1.n_vars < 256 else 256 + +if mod1 != 'ADT': + lsi_transformer_gex = lsiTransformer(n_components=n_comp) + input_train_mod1_df = lsi_transformer_gex.fit_transform(input_train_mod1) +else: + input_train_mod1_df = input_train_mod1.to_df() + +# reproduce train/test split from phase 1 +batch = input_train_mod1.obs["batch"] +train_ix = [ k for k,v in enumerate(batch) if v not in {'s1d2', 's3d7'} ] +test_ix = [ k for k,v in enumerate(batch) if v in {'s1d2', 's3d7'} ] + +train_mod1 = input_train_mod1_df.iloc[train_ix, :] +train_mod2 = input_train_mod2_df.iloc[train_ix, :] +test_mod1 = input_train_mod1_df.iloc[test_ix, :] +test_mod2 = input_train_mod2_df.iloc[test_ix, :] + +n_vars_train_mod1 = train_mod1.shape[1] +n_vars_train_mod2 = train_mod2.shape[1] +n_vars_test_mod1 = test_mod1.shape[1] +n_vars_test_mod2 = test_mod2.shape[1] + +n_vars_mod1 = input_train_mod1_df.shape[1] +n_vars_mod2 = input_train_mod2_df.shape[1] + +if mod1 == 'ATAC' and mod2 == 'GEX': + dataset_train = ModalityMatchingDataset(train_mod1, train_mod2) + dataloader_train = DataLoader(dataset_train, 256, shuffle = True, num_workers = 8) + + dataset_test = ModalityMatchingDataset(test_mod1, test_mod2) + dataloader_test = DataLoader(dataset_test, 64, shuffle = False, num_workers = 8) + + model = ModelRegressionAtac2Gex(n_vars_mod1,n_vars_mod2).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=0.00008386597445284492,weight_decay=0.000684887347727808) + +elif mod1 == 'ADT' and mod2 == 'GEX': + dataset_train = ModalityMatchingDataset(train_mod1, train_mod2) + dataloader_train = DataLoader(dataset_train, 64, shuffle = True, num_workers = 4) + + dataset_test = ModalityMatchingDataset(test_mod1, test_mod2) + dataloader_test = DataLoader(dataset_test, 32, shuffle = False, num_workers = 4) + + model = ModelRegressionAdt2Gex(n_vars_mod1,n_vars_mod2).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=0.00041, weight_decay=0.0000139) + + +elif mod1 == 'GEX' and mod2 == 'ADT': + dataset_train = ModalityMatchingDataset(train_mod1, train_mod2) + dataloader_train = DataLoader(dataset_train, 32, shuffle = True, num_workers = 8) + + dataset_test = ModalityMatchingDataset(test_mod1, test_mod2) + dataloader_test = DataLoader(dataset_test, 64, shuffle = False, num_workers = 8) + + model = ModelRegressionGex2Adt(n_vars_mod1,n_vars_mod2).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=0.000034609210829678734, weight_decay=0.0009965881574697426) + + +elif mod1 == 'GEX' and mod2 == 'ATAC': + dataset_train = ModalityMatchingDataset(train_mod1, train_mod2) + dataloader_train = DataLoader(dataset_train, 64, shuffle = True, num_workers = 8) + + dataset_test = ModalityMatchingDataset(test_mod1, test_mod2) + dataloader_test = DataLoader(dataset_test, 64, shuffle = False, num_workers = 8) + + model = ModelRegressionGex2Atac(n_vars_mod1,n_vars_mod2).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001806762345275399, weight_decay=0.0004084171379280058) + +loss_fn = torch.nn.MSELoss() +train_and_valid(model, optimizer, loss_fn, dataloader_train, dataloader_test, par['output'], device) + +# Add model dim for use in predict part +adata.uns["model_dim"] = {"mod1": n_vars_mod1, "mod2": n_vars_mod2} +if rem_var: + adata.uns["removed_vars"] = [rem_var[0]] +adata.write_h5ad(par['output_train_mod2'], compression="gzip") + +if mod1 != 'ADT': + with open(par['output_transform'], 'wb') as f: + pickle.dump(lsi_transformer_gex, f) diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml index 1a1715e..acef848 100644 --- a/src/workflows/run_benchmark/config.vsh.yaml +++ b/src/workflows/run_benchmark/config.vsh.yaml @@ -73,6 +73,7 @@ dependencies: - name: methods/lm - name: methods/lmds_irlba_rf - name: methods/guanlab_dengkw_pm + - name: methods/novel - name: metrics/correlation - name: metrics/mse runners: diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf index 6b92dae..ad27341 100644 --- a/src/workflows/run_benchmark/main.nf +++ b/src/workflows/run_benchmark/main.nf @@ -21,7 +21,8 @@ workflow run_wf { knnr_r, lm, lmds_irlba_rf, - guanlab_dengkw_pm + guanlab_dengkw_pm, + novel ] // construct list of metrics From f16f2b412f182dfe2f0f3ab7a594a4318836ca81 Mon Sep 17 00:00:00 2001 From: Robrecht Cannoodt Date: Mon, 26 Aug 2024 23:52:46 +0200 Subject: [PATCH 2/3] fix config --- src/methods/novel/run/config.vsh.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/methods/novel/run/config.vsh.yaml b/src/methods/novel/run/config.vsh.yaml index 605fb18..a30a31e 100644 --- a/src/methods/novel/run/config.vsh.yaml +++ b/src/methods/novel/run/config.vsh.yaml @@ -17,7 +17,7 @@ resources: type: nextflow_script entrypoint: run_wf dependencies: - - name: predict_modality/methods/novel_train - - name: predict_modality/methods/novel_predict + - name: methods/novel_train + - name: methods/novel_predict runners: - type: nextflow \ No newline at end of file From 1baca754092486bd893c84b5c9f858909cd08284 Mon Sep 17 00:00:00 2001 From: Robrecht Cannoodt Date: Wed, 8 Jan 2025 17:27:13 +0100 Subject: [PATCH 3/3] update novel --- scripts/create_datasets/test_resources.sh | 15 +++++++- src/methods/novel/predict/config.vsh.yaml | 12 +++--- src/methods/novel/predict/script.py | 29 +++++++-------- src/methods/novel/run/main.nf | 10 +---- src/methods/novel/train/config.vsh.yaml | 13 ------- src/methods/novel/train/script.py | 45 ++++++++++++++--------- 6 files changed, 63 insertions(+), 61 deletions(-) diff --git a/scripts/create_datasets/test_resources.sh b/scripts/create_datasets/test_resources.sh index d869d00..7c00b7f 100755 --- a/scripts/create_datasets/test_resources.sh +++ b/scripts/create_datasets/test_resources.sh @@ -30,20 +30,31 @@ nextflow run . \ echo "Run one method" for name in bmmc_cite/normal bmmc_cite/swap bmmc_multiome/normal bmmc_multiome/swap; do + echo "Run KNN on $name" viash run src/methods/knnr_py/config.vsh.yaml -- \ --input_train_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod1.h5ad \ --input_train_mod2 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod2.h5ad \ --input_test_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/test_mod1.h5ad \ --output $OUTPUT_DIR/openproblems_neurips2021/$name/prediction.h5ad - # pre-train simple_mlp - rm -r $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/ + echo "pre-train simple_mlp on $name" + [ -d $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/ ] && rm -r $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/ mkdir -p $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/ viash run src/methods/simple_mlp/train/config.vsh.yaml -- \ --input_train_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod1.h5ad \ --input_train_mod2 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod2.h5ad \ --input_test_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/test_mod1.h5ad \ --output $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/ + + echo "pre-train novel on $name" + [ -d $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel/ ] && rm -r $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel/ + mkdir -p $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel/ + viash run src/methods/novel/train/config.vsh.yaml -- \ + --input_train_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod1.h5ad \ + --input_train_mod2 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod2.h5ad \ + --input_test_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/test_mod1.h5ad \ + --output $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel + done # only run this if you have access to the openproblems-data bucket diff --git a/src/methods/novel/predict/config.vsh.yaml b/src/methods/novel/predict/config.vsh.yaml index 1074a81..d93f42f 100644 --- a/src/methods/novel/predict/config.vsh.yaml +++ b/src/methods/novel/predict/config.vsh.yaml @@ -1,11 +1,11 @@ __merge__: ../../../api/comp_method_predict.yaml name: novel_predict -arguments: - - name: "--input_transform" - type: file - direction: input - required: false - example: "lsi_transformer.pickle" + +info: + test_setup: + with_model: + input_model: resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/models/novel + resources: - type: python_script path: script.py diff --git a/src/methods/novel/predict/script.py b/src/methods/novel/predict/script.py index 240620c..7ace272 100644 --- a/src/methods/novel/predict/script.py +++ b/src/methods/novel/predict/script.py @@ -17,26 +17,27 @@ ## VIASH START - par = { 'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/train_mod2.h5ad', 'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/test_mod1.h5ad', 'input_model': 'resources_test/predict_modality/neurips2021_bmmc_cite/model.pt', - 'input_transform': 'transformer.pickle' } meta = { 'resources_dir': 'src/tasks/predict_modality/methods/novel', - 'functionality_name': '171129' } ## VIASH END sys.path.append(meta['resources_dir']) from helper_functions import ModelRegressionAtac2Gex, ModelRegressionAdt2Gex, ModelRegressionGex2Adt, ModelRegressionGex2Atac, ModalityMatchingDataset +input_model = f"{par['input_model']}/tensor.pt" +input_transform = f"{par['input_model']}/transform.pkl" +input_h5ad = f"{par['input_model']}/train_mod2.h5ad" + print("Load data", flush=True) input_test_mod1 = ad.read_h5ad(par['input_test_mod1']) -input_train_mod2 = ad.read_h5ad(par['input_train_mod2']) +input_train_mod2 = ad.read_h5ad(input_h5ad) mod1 = input_test_mod1.uns['modality'] mod2 = input_train_mod2.uns['modality'] @@ -46,7 +47,7 @@ input_test_mod1.X = input_test_mod1.layers['normalized'].tocsr() -# Remove vars that were removed from training set. Mostlyy only applicable for testing. +# Remove vars that were removed from training set. Mostly only applicable for testing. if input_train_mod2.uns.get("removed_vars"): rem_var = input_train_mod2.uns["removed_vars"] input_test_mod1 = input_test_mod1[:, ~input_test_mod1.var_names.isin(rem_var)] @@ -54,14 +55,12 @@ del input_train_mod2 -model_fp = par['input_model'] - print("Start predict", flush=True) if mod1 == 'GEX' and mod2 == 'ADT': model = ModelRegressionGex2Adt(n_vars_mod1,n_vars_mod2) - weight = torch.load(model_fp, map_location='cpu') - with open(par['input_transform'], 'rb') as f: + weight = torch.load(input_model, map_location='cpu') + with open(input_transform, 'rb') as f: lsi_transformer_gex = pickle.load(f) model.load_state_dict(weight) @@ -69,8 +68,8 @@ elif mod1 == 'GEX' and mod2 == 'ATAC': model = ModelRegressionGex2Atac(n_vars_mod1,n_vars_mod2) - weight = torch.load(model_fp, map_location='cpu') - with open(par['input_transform'], 'rb') as f: + weight = torch.load(input_model, map_location='cpu') + with open(input_transform, 'rb') as f: lsi_transformer_gex = pickle.load(f) model.load_state_dict(weight) @@ -78,8 +77,8 @@ elif mod1 == 'ATAC' and mod2 == 'GEX': model = ModelRegressionAtac2Gex(n_vars_mod1,n_vars_mod2) - weight = torch.load(model_fp, map_location='cpu') - with open(par['input_transform'], 'rb') as f: + weight = torch.load(input_model, map_location='cpu') + with open(input_transform, 'rb') as f: lsi_transformer_gex = pickle.load(f) model.load_state_dict(weight) @@ -87,7 +86,7 @@ elif mod1 == 'ADT' and mod2 == 'GEX': model = ModelRegressionAdt2Gex(n_vars_mod1,n_vars_mod2) - weight = torch.load(model_fp, map_location='cpu') + weight = torch.load(input_model, map_location='cpu') model.load_state_dict(weight) input_test_mod1_ = input_test_mod1.to_df() @@ -111,7 +110,7 @@ shape=outputs.shape, uns={ 'dataset_id': input_test_mod1.uns['dataset_id'], - 'method_id': meta['functionality_name'], + 'method_id': meta['name'], }, ) adata.write_h5ad(par['output'], compression = "gzip") diff --git a/src/methods/novel/run/main.nf b/src/methods/novel/run/main.nf index f3f879d..0d973af 100644 --- a/src/methods/novel/run/main.nf +++ b/src/methods/novel/run/main.nf @@ -4,16 +4,10 @@ workflow run_wf { output_ch = input_ch | novel_train.run( fromState: ["input_train_mod1", "input_train_mod2"], - toState: ["input_model": "output", "input_transform": "output_transform", "output_train_mod2": "output_train_mod2"] + toState: ["input_model": "output"] ) | novel_predict.run( - fromState: { id, state -> - [ - "input_train_mod2": state.output_train_mod2, - "input_test_mod1": state.input_test_mod1, - "input_model": state.input_model, - "input_transform": state.input_transform, - "output": state.output]}, + fromState: ["input_test_mod1", "input_train_mod2", "input_model"], toState: ["output": "output"] ) diff --git a/src/methods/novel/train/config.vsh.yaml b/src/methods/novel/train/config.vsh.yaml index 12717b6..794524a 100644 --- a/src/methods/novel/train/config.vsh.yaml +++ b/src/methods/novel/train/config.vsh.yaml @@ -1,18 +1,5 @@ __merge__: ../../../api/comp_method_train.yaml name: novel_train -arguments: - - name: --output_transform - type: file - description: "The output transform file" - required: false - default: "lsi_transformer.pickle" - direction: output - - name: --output_train_mod2 - type: file - description: copy of the input with model dim in `.uns` - direction: output - default: "train_mod2.h5ad" - required: false resources: - path: script.py type: python_script diff --git a/src/methods/novel/train/script.py b/src/methods/novel/train/script.py index 65edfe6..0cd9412 100644 --- a/src/methods/novel/train/script.py +++ b/src/methods/novel/train/script.py @@ -1,4 +1,7 @@ import sys +import os +import math +import numpy as np import torch from torch.utils.data import DataLoader @@ -17,26 +20,21 @@ ## VIASH START - par = { - 'input_train_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/train_mod1.h5ad', - 'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/train_mod2.h5ad', - 'output_train_mod2': 'train_mod2.h5ad', - 'output': 'model.pt' + 'input_train_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_multiome/normal/train_mod1.h5ad', + 'input_train_mod2': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_multiome/normal/train_mod2.h5ad', + 'output': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_multiome/normal/models/novel' } - meta = { - 'resources_dir': 'src/tasks/predict_modality/methods/novel', + 'resources_dir': 'src/methods/novel', } ## VIASH END - sys.path.append(meta['resources_dir']) from helper_functions import train_and_valid, lsiTransformer, ModalityMatchingDataset from helper_functions import ModelRegressionAtac2Gex, ModelRegressionAdt2Gex, ModelRegressionGex2Adt, ModelRegressionGex2Atac print('Load data', flush=True) - input_train_mod1 = ad.read_h5ad(par['input_train_mod1']) input_train_mod2 = ad.read_h5ad(par['input_train_mod2']) @@ -53,8 +51,6 @@ del input_train_mod2 print('Start train', flush=True) - - # Check for zero divide zero_row = input_train_mod1.X.sum(axis=0) == 0 @@ -75,8 +71,13 @@ # reproduce train/test split from phase 1 batch = input_train_mod1.obs["batch"] -train_ix = [ k for k,v in enumerate(batch) if v not in {'s1d2', 's3d7'} ] -test_ix = [ k for k,v in enumerate(batch) if v in {'s1d2', 's3d7'} ] +test_batches = {'s1d2', 's3d7'} +# if none of phase1_batch is in batch, sample 25% of batch categories rounded up +if len(test_batches.intersection(set(batch))) == 0: + all_batches = batch.cat.categories.tolist() + test_batches = set(np.random.choice(all_batches, math.ceil(len(all_batches) * 0.25), replace=False)) +train_ix = [ k for k,v in enumerate(batch) if v not in test_batches ] +test_ix = [ k for k,v in enumerate(batch) if v in test_batches ] train_mod1 = input_train_mod1_df.iloc[train_ix, :] train_mod2 = input_train_mod2_df.iloc[train_ix, :] @@ -134,14 +135,24 @@ optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001806762345275399, weight_decay=0.0004084171379280058) loss_fn = torch.nn.MSELoss() -train_and_valid(model, optimizer, loss_fn, dataloader_train, dataloader_test, par['output'], device) + +# create dir for par['output'] +os.makedirs(par['output'], exist_ok=True) + +# determine filenames +output_model = f"{par['output']}/tensor.pt" +output_h5ad = f"{par['output']}/train_mod2.h5ad" +output_transform = f"{par['output']}/transform.pkl" + +# train model +train_and_valid(model, optimizer, loss_fn, dataloader_train, dataloader_test, output_model, device) # Add model dim for use in predict part adata.uns["model_dim"] = {"mod1": n_vars_mod1, "mod2": n_vars_mod2} -if rem_var: +if rem_var is not None: adata.uns["removed_vars"] = [rem_var[0]] -adata.write_h5ad(par['output_train_mod2'], compression="gzip") +adata.write_h5ad(output_h5ad, compression="gzip") if mod1 != 'ADT': - with open(par['output_transform'], 'wb') as f: + with open(output_transform, 'wb') as f: pickle.dump(lsi_transformer_gex, f)