diff --git a/scripts/create_datasets/test_resources.sh b/scripts/create_datasets/test_resources.sh
index d869d00..7c00b7f 100755
--- a/scripts/create_datasets/test_resources.sh
+++ b/scripts/create_datasets/test_resources.sh
@@ -30,20 +30,31 @@ nextflow run . \
 echo "Run one method"
 for name in bmmc_cite/normal bmmc_cite/swap bmmc_multiome/normal bmmc_multiome/swap; do
+  echo "Run KNN on $name"
   viash run src/methods/knnr_py/config.vsh.yaml -- \
     --input_train_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod1.h5ad \
     --input_train_mod2 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod2.h5ad \
     --input_test_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/test_mod1.h5ad \
     --output $OUTPUT_DIR/openproblems_neurips2021/$name/prediction.h5ad
-  # pre-train simple_mlp
-  rm -r $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/
+  echo "pre-train simple_mlp on $name"
+  [ -d $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/ ] && rm -r $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/
   mkdir -p $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/
   viash run src/methods/simple_mlp/train/config.vsh.yaml -- \
     --input_train_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod1.h5ad \
     --input_train_mod2 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod2.h5ad \
     --input_test_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/test_mod1.h5ad \
     --output $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/
+
+  echo "pre-train novel on $name"
+  [ -d $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel/ ] && rm -r $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel/
+  mkdir -p $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel/
+  viash run src/methods/novel/train/config.vsh.yaml -- \
+    --input_train_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod1.h5ad \
+    --input_train_mod2 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod2.h5ad \
+    --input_test_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/test_mod1.h5ad \
+    --output $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel
+
 done
 
 # only run this if you have access to the openproblems-data bucket
diff --git a/src/methods/novel/helper_functions.py b/src/methods/novel/helper_functions.py
new file mode 100644
index 0000000..2696c2f
--- /dev/null
+++ b/src/methods/novel/helper_functions.py
@@ -0,0 +1,246 @@
+import torch
+
+from torch import nn
+import torch.nn.functional as F
+
+from torch.utils.data import Dataset
+
+import anndata
+import numpy as np
+import pandas as pd
+import scipy.sparse
+import sklearn.decomposition
+import sklearn.preprocessing
+
+
+class tfidfTransformer():
+    def __init__(self):
+        self.idf = None
+        self.fitted = False
+
+    def fit(self, X):
+        self.idf = X.shape[0] / X.sum(axis=0)
+        self.fitted = True
+
+    def transform(self, X):
+        if not self.fitted:
+            raise RuntimeError('Transformer was not fitted on any data')
+        if scipy.sparse.issparse(X):
+            tf = X.multiply(1 / X.sum(axis=1))
+            return tf.multiply(self.idf)
+        else:
+            tf = X / X.sum(axis=1, keepdims=True)
+            return tf * self.idf
+
+    def fit_transform(self, X):
+        self.fit(X)
+        return self.transform(X)
+
+
+class lsiTransformer():
+    def __init__(self, n_components: int = 20, use_highly_variable=None):
+        self.n_components = n_components
+        self.use_highly_variable = use_highly_variable
+        self.tfidfTransformer = tfidfTransformer()
+        self.normalizer = sklearn.preprocessing.Normalizer(norm="l1")
+        self.pcaTransformer = sklearn.decomposition.TruncatedSVD(n_components=self.n_components, random_state=777)
+        self.fitted = False
+
+    def fit(self, adata: anndata.AnnData):
+        if self.use_highly_variable is None:
+            self.use_highly_variable = "hvg" in adata.var
+        adata_use = adata[:, adata.var["hvg"]] if self.use_highly_variable else adata
+        X = self.tfidfTransformer.fit_transform(adata_use.X)
+        X_norm = self.normalizer.fit_transform(X)
+        X_norm = np.log1p(X_norm * 1e4)
+        self.pcaTransformer.fit(X_norm)
+        self.fitted = True
+
+    def transform(self, adata):
+        if not self.fitted:
+            raise RuntimeError('Transformer was not fitted on any data')
+        adata_use = adata[:, adata.var["hvg"]] if self.use_highly_variable else adata
+        X = self.tfidfTransformer.transform(adata_use.X)
+        X_norm = self.normalizer.transform(X)
+        X_norm = np.log1p(X_norm * 1e4)
+        X_lsi = self.pcaTransformer.transform(X_norm)
+        # standardize each cell across its LSI components
+        X_lsi -= X_lsi.mean(axis=1, keepdims=True)
+        X_lsi /= X_lsi.std(axis=1, ddof=1, keepdims=True)
+        return pd.DataFrame(X_lsi, index=adata_use.obs_names)
+
+    def fit_transform(self, adata):
+        self.fit(adata)
+        return self.transform(adata)
+
+
+class ModalityMatchingDataset(Dataset):
+    def __init__(self, df_modality1, df_modality2, is_train=True):
+        super().__init__()
+        self.df_modality1 = df_modality1
+        self.df_modality2 = df_modality2
+        self.is_train = is_train
+
+    def __len__(self):
+        return self.df_modality1.shape[0]
+
+    def __getitem__(self, index: int):
+        if self.is_train:
+            x = self.df_modality1.iloc[index].values
+            y = self.df_modality2.iloc[index].values
+            return x, y
+        else:
+            return self.df_modality1.iloc[index].values
+
+
+sigmoid = torch.nn.Sigmoid()
+
+class Swish(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        result = i * sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        i = ctx.saved_tensors[0]
+        sigmoid_i = sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+class Swish_module(nn.Module):
+    def forward(self, x):
+        return Swish.apply(x)
+
+
+class ModelRegressionGex2Atac(nn.Module):
+    def __init__(self, dim_mod1, dim_mod2):
+        super(ModelRegressionGex2Atac, self).__init__()
+        self.input_ = nn.Linear(dim_mod1, 1024)
+        self.fc = nn.Linear(1024, 256)
+        self.fc1 = nn.Linear(256, 2048)
+        self.dropout1 = nn.Dropout(p=0.298885630228993)
+        self.dropout2 = nn.Dropout(p=0.11289717442776658)
+        self.dropout3 = nn.Dropout(p=0.13523634924414762)
+        self.output = nn.Linear(2048, dim_mod2)
+
+    def forward(self, x):
+        x = F.gelu(self.input_(x))
+        x = self.dropout1(x)
+        x = F.gelu(self.fc(x))
+        x = self.dropout2(x)
+        x = F.gelu(self.fc1(x))
+        x = self.dropout3(x)
+        x = F.gelu(self.output(x))
+        return x
+
+
+class ModelRegressionAtac2Gex(nn.Module):
+    def __init__(self, dim_mod1, dim_mod2):
+        super(ModelRegressionAtac2Gex, self).__init__()
+        self.input_ = nn.Linear(dim_mod1, 2048)
+        self.fc = nn.Linear(2048, 2048)
+        self.fc1 = nn.Linear(2048, 512)
+        self.dropout1 = nn.Dropout(p=0.2649138776004753)
+        self.dropout2 = nn.Dropout(p=0.1769628308148758)
+        self.dropout3 = nn.Dropout(p=0.2516791883012817)
+        self.output = nn.Linear(512, dim_mod2)
+
+    def forward(self, x):
+        x = F.gelu(self.input_(x))
+        x = self.dropout1(x)
+        x = F.gelu(self.fc(x))
+        x = self.dropout2(x)
+        x = F.gelu(self.fc1(x))
+        x = self.dropout3(x)
+        x = F.gelu(self.output(x))
+        return x
+
+
+class ModelRegressionAdt2Gex(nn.Module):
+    def __init__(self, dim_mod1, dim_mod2):
+        super(ModelRegressionAdt2Gex, self).__init__()
+        self.input_ = nn.Linear(dim_mod1, 512)
+        # kept from the original architecture search, but not used in forward()
+        self.dropout1 = nn.Dropout(p=0.0)
+        self.swish = Swish_module()
+        self.fc = nn.Linear(512, 512)
+        self.fc1 = nn.Linear(512, 512)
+        self.fc2 = nn.Linear(512, 512)
+        self.output = nn.Linear(512, dim_mod2)
+
+    def forward(self, x):
+        x = F.gelu(self.input_(x))
+        x = F.gelu(self.fc(x))
+        x = F.gelu(self.fc1(x))
+        x = F.gelu(self.fc2(x))
+        x = F.gelu(self.output(x))
+        return x
+
+
+class ModelRegressionGex2Adt(nn.Module):
+    def __init__(self, dim_mod1, dim_mod2):
+        super(ModelRegressionGex2Adt, self).__init__()
+        self.input_ = nn.Linear(dim_mod1, 512)
+        self.dropout1 = nn.Dropout(p=0.20335661386636347)
+        self.dropout2 = nn.Dropout(p=0.15395289261127876)
+        self.dropout3 = nn.Dropout(p=0.16902655078832815)
+        self.fc = nn.Linear(512, 512)
+        self.fc1 = nn.Linear(512, 2048)
+        self.output = nn.Linear(2048, dim_mod2)
+
+    def forward(self, x):
+        x = F.gelu(self.input_(x))
+        x = self.dropout1(x)
+        x = F.gelu(self.fc(x))
+        x = self.dropout2(x)
+        x = F.gelu(self.fc1(x))
+        x = self.dropout3(x)
+        x = F.gelu(self.output(x))
+        return x
+
+
+def rmse(y, y_pred):
+    return np.sqrt(np.mean(np.square(y - y_pred)))
+
+
+def train_and_valid(model, optimizer, loss_fn, dataloader_train, dataloader_test, name_model, device):
+    best_score = float("inf")
+    for epoch in range(100):
+        model.train()
+        for x, y in dataloader_train:
+            optimizer.zero_grad()
+            output = model(x.float().to(device))
+            loss = torch.sqrt(loss_fn(output, y.float().to(device)))
+            loss.backward()
+            optimizer.step()
+
+        # validate: collect clipped predictions and targets for this epoch
+        outputs = []
+        targets = []
+        model.eval()
+        with torch.no_grad():
+            for x, y in dataloader_test:
+                output = model(x.float().to(device))
+                outputs.append(output.detach().cpu().numpy())
+                targets.append(y.float().detach().cpu().numpy())
+        cat_outputs = np.concatenate(outputs)
+        cat_targets = np.concatenate(targets)
+        cat_outputs[cat_outputs < 0.0] = 0
+
+        # keep the checkpoint with the best validation RMSE
+        score = rmse(cat_targets, cat_outputs)
+        if score < best_score:
+            torch.save(model.state_dict(), name_model)
+            best_score = score
+            print("best rmse: ", best_score, flush=True)
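For context, the TF-IDF + LSI preprocessing above is the standard scATAC-seq recipe: term-frequency scaling, IDF weighting, L1 normalization with log1p, then truncated SVD. A minimal sketch of how `lsiTransformer` is meant to be driven, on fabricated toy data (this assumes `helper_functions.py` is on the Python path; the data and dimensions are made up for illustration):

```python
import anndata
import numpy as np
import scipy.sparse

from helper_functions import lsiTransformer  # the class added above

# toy ATAC-like counts: 100 cells x 500 peaks, sparse and non-negative
rng = np.random.default_rng(0)
X = scipy.sparse.random(
    100, 500, density=0.05, format="csr", random_state=0,
    data_rvs=lambda n: rng.integers(1, 5, n).astype(float),
)
adata = anndata.AnnData(X=X)  # every cell should have at least one count

transformer = lsiTransformer(n_components=20)  # fit on training cells only
train_lsi = transformer.fit_transform(adata)   # pandas DataFrame, cells x 20
# at predict time, reuse the *same* fitted transformer on the test cells
test_lsi = transformer.transform(adata)
```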
diff --git a/src/methods/novel/predict/config.vsh.yaml b/src/methods/novel/predict/config.vsh.yaml
new file mode 100644
index 0000000..d93f42f
--- /dev/null
+++ b/src/methods/novel/predict/config.vsh.yaml
@@ -0,0 +1,26 @@
+__merge__: ../../../api/comp_method_predict.yaml
+name: novel_predict
+
+info:
+  test_setup:
+    with_model:
+      input_model: resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/models/novel
+
+resources:
+  - type: python_script
+    path: script.py
+  - path: ../helper_functions.py
+engines:
+  - type: docker
+    image: openproblems/base_pytorch_nvidia:1.0.0
+    setup:
+      - type: python
+        packages:
+          - scikit-learn
+          - networkx
+runners:
+  - type: executable
+  - type: nextflow
+    directives:
+      label: [highmem, hightime, midcpu, highsharedmem, gpu]
diff --git a/src/methods/novel/predict/script.py b/src/methods/novel/predict/script.py
new file mode 100644
index 0000000..7ace272
--- /dev/null
+++ b/src/methods/novel/predict/script.py
@@ -0,0 +1,117 @@
+import sys
+import torch
+from torch.utils.data import DataLoader
+
+import anndata as ad
+import pickle
+import numpy as np
+from scipy.sparse import csc_matrix
+
+# check whether a GPU is available
+if torch.cuda.is_available():
+    device = 'cuda:0'
+    print('current device: gpu', flush=True)
+else:
+    device = 'cpu'
+    print('current device: cpu', flush=True)
+
+## VIASH START
+par = {
+    'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/train_mod2.h5ad',
+    'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/test_mod1.h5ad',
+    'input_model': 'resources_test/predict_modality/neurips2021_bmmc_cite/model.pt',
+}
+meta = {
+    'resources_dir': 'src/tasks/predict_modality/methods/novel',
+}
+## VIASH END
+
+sys.path.append(meta['resources_dir'])
+from helper_functions import ModelRegressionAtac2Gex, ModelRegressionAdt2Gex, ModelRegressionGex2Adt, ModelRegressionGex2Atac, ModalityMatchingDataset
+
+input_model = f"{par['input_model']}/tensor.pt"
+input_transform = f"{par['input_model']}/transform.pkl"
+input_h5ad = f"{par['input_model']}/train_mod2.h5ad"
+
+print("Load data", flush=True)
+
+input_test_mod1 = ad.read_h5ad(par['input_test_mod1'])
+input_train_mod2 = ad.read_h5ad(input_h5ad)
+
+mod1 = input_test_mod1.uns['modality']
+mod2 = input_train_mod2.uns['modality']
+
+n_vars_mod1 = input_train_mod2.uns["model_dim"]["mod1"]
+n_vars_mod2 = input_train_mod2.uns["model_dim"]["mod2"]
+
+input_test_mod1.X = input_test_mod1.layers['normalized'].tocsr()
+
+# Remove vars that were removed from the training set. Mostly only applicable for testing.
+if input_train_mod2.uns.get("removed_vars"):
+    rem_var = input_train_mod2.uns["removed_vars"]
+    input_test_mod1 = input_test_mod1[:, ~input_test_mod1.var_names.isin(rem_var)]
+
+del input_train_mod2
+
+print("Start predict", flush=True)
+
+if mod1 == 'GEX' and mod2 == 'ADT':
+    model = ModelRegressionGex2Adt(n_vars_mod1, n_vars_mod2)
+elif mod1 == 'GEX' and mod2 == 'ATAC':
+    model = ModelRegressionGex2Atac(n_vars_mod1, n_vars_mod2)
+elif mod1 == 'ATAC' and mod2 == 'GEX':
+    model = ModelRegressionAtac2Gex(n_vars_mod1, n_vars_mod2)
+elif mod1 == 'ADT' and mod2 == 'GEX':
+    model = ModelRegressionAdt2Gex(n_vars_mod1, n_vars_mod2)
+else:
+    raise ValueError(f"Unsupported modality pair: {mod1} -> {mod2}")
+
+weight = torch.load(input_model, map_location='cpu')
+model.load_state_dict(weight)
+
+# ADT input is used as-is; the other modalities go through the fitted LSI transform
+if mod1 == 'ADT':
+    input_test_mod1_ = input_test_mod1.to_df()
+else:
+    with open(input_transform, 'rb') as f:
+        lsi_transformer_gex = pickle.load(f)
+    input_test_mod1_ = lsi_transformer_gex.transform(input_test_mod1)
+
+dataset_test = ModalityMatchingDataset(input_test_mod1_, None, is_train=False)
+dataloader_test = DataLoader(dataset_test, 32, shuffle=False, num_workers=4)
+
+outputs = []
+model.to(device)
+model.eval()
+with torch.no_grad():
+    for x in dataloader_test:
+        output = model(x.float().to(device))
+        outputs.append(output.detach().cpu().numpy())
+
+outputs = np.concatenate(outputs)
+outputs[outputs < 0] = 0
+outputs = csc_matrix(outputs)
+
+adata = ad.AnnData(
+    layers={"normalized": outputs},
+    shape=outputs.shape,
+    uns={
+        'dataset_id': input_test_mod1.uns['dataset_id'],
+        'method_id': meta['name'],
+    },
+)
+adata.write_h5ad(par['output'], compression="gzip")
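The per-direction dispatch above could equally be table-driven, which keeps the model registry in one place if more modality pairs are added later. A hedged sketch of that alternative (same class names as in `helper_functions.py`; `build_model` is a hypothetical helper, not part of the committed script):

```python
from helper_functions import (
    ModelRegressionAtac2Gex, ModelRegressionAdt2Gex,
    ModelRegressionGex2Adt, ModelRegressionGex2Atac,
)

# (mod1, mod2) -> (model class, whether the input needs the fitted LSI transform)
MODEL_TABLE = {
    ('GEX', 'ADT'):  (ModelRegressionGex2Adt,  True),
    ('GEX', 'ATAC'): (ModelRegressionGex2Atac, True),
    ('ATAC', 'GEX'): (ModelRegressionAtac2Gex, True),
    ('ADT', 'GEX'):  (ModelRegressionAdt2Gex,  False),
}

def build_model(mod1, mod2, n_vars_mod1, n_vars_mod2):
    try:
        model_cls, needs_lsi = MODEL_TABLE[(mod1, mod2)]
    except KeyError:
        raise ValueError(f"Unsupported modality pair: {mod1} -> {mod2}")
    return model_cls(n_vars_mod1, n_vars_mod2), needs_lsi
```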
diff --git a/src/methods/novel/run/config.vsh.yaml b/src/methods/novel/run/config.vsh.yaml
new file mode 100644
index 0000000..a30a31e
--- /dev/null
+++ b/src/methods/novel/run/config.vsh.yaml
@@ -0,0 +1,23 @@
+__merge__: ../../../api/comp_method.yaml
+name: novel
+label: Novel
+summary: A method using an encoder-decoder MLP model
+description: This method trains an encoder-decoder MLP with one output neuron per component of the target. As input, the encoders use representations obtained from ATAC and GEX data via an LSI transform, and raw ADT data. The hyperparameters of the models were found via a broad hyperparameter search using the Optuna framework.
+references:
+  doi:
+    - 10.1101/2022.04.11.487796
+links:
+  documentation: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/novel#readme
+  repository: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/novel
+info:
+  submission_id: "169769"
+  preferred_normalization: log_cp10k
+resources:
+  - path: main.nf
+    type: nextflow_script
+    entrypoint: run_wf
+dependencies:
+  - name: methods/novel_train
+  - name: methods/novel_predict
+runners:
+  - type: nextflow
\ No newline at end of file
diff --git a/src/methods/novel/run/main.nf b/src/methods/novel/run/main.nf
new file mode 100644
index 0000000..0d973af
--- /dev/null
+++ b/src/methods/novel/run/main.nf
@@ -0,0 +1,19 @@
+workflow run_wf {
+  take: input_ch
+  main:
+  output_ch = input_ch
+    | novel_train.run(
+      fromState: ["input_train_mod1", "input_train_mod2"],
+      toState: ["input_model": "output"]
+    )
+    | novel_predict.run(
+      fromState: ["input_test_mod1", "input_train_mod2", "input_model"],
+      toState: ["output": "output"]
+    )
+    | map { tup ->
+      [tup[0], [output: tup[1].output]]
+    }
+
+  emit: output_ch
+}
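The `run_wf` wiring above just threads a per-sample state map through train and then predict. A rough Python analogue of those state transformations, purely for illustration (`novel_train` and `novel_predict` here are hypothetical stand-ins for the Viash components; the real plumbing is the `fromState`/`toState` mappings):

```python
def run_wf(state: dict) -> dict:
    """Illustrative stand-in for the Nextflow workflow above."""
    # novel_train: reads the training inputs, returns a model directory,
    # which toState stores back into the state as 'input_model'
    train_out = novel_train(
        input_train_mod1=state["input_train_mod1"],
        input_train_mod2=state["input_train_mod2"],
    )
    state["input_model"] = train_out["output"]

    # novel_predict: consumes the trained model directory plus the test data
    predict_out = novel_predict(
        input_test_mod1=state["input_test_mod1"],
        input_train_mod2=state["input_train_mod2"],
        input_model=state["input_model"],
    )
    # the final map step keeps only the prediction in the emitted state
    return {"output": predict_out["output"]}
```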
diff --git a/src/methods/novel/train/config.vsh.yaml b/src/methods/novel/train/config.vsh.yaml
new file mode 100644
index 0000000..794524a
--- /dev/null
+++ b/src/methods/novel/train/config.vsh.yaml
@@ -0,0 +1,19 @@
+__merge__: ../../../api/comp_method_train.yaml
+name: novel_train
+resources:
+  - path: script.py
+    type: python_script
+  - path: ../helper_functions.py
+engines:
+  - type: docker
+    image: openproblems/base_pytorch_nvidia:1.0.0
+    setup:
+      - type: python
+        packages:
+          - scikit-learn
+          - networkx
+runners:
+  - type: executable
+  - type: nextflow
+    directives:
+      label: [highmem, hightime, midcpu, highsharedmem, gpu]
diff --git a/src/methods/novel/train/script.py b/src/methods/novel/train/script.py
new file mode 100644
index 0000000..0cd9412
--- /dev/null
+++ b/src/methods/novel/train/script.py
@@ -0,0 +1,158 @@
+import sys
+import os
+import math
+import numpy as np
+
+import torch
+from torch.utils.data import DataLoader
+
+import anndata as ad
+import pickle
+
+# check whether a GPU is available
+if torch.cuda.is_available():
+    device = 'cuda:0'
+    print('current device: gpu', flush=True)
+else:
+    device = 'cpu'
+    print('current device: cpu', flush=True)
+
+## VIASH START
+par = {
+    'input_train_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_multiome/normal/train_mod1.h5ad',
+    'input_train_mod2': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_multiome/normal/train_mod2.h5ad',
+    'output': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_multiome/normal/models/novel'
+}
+meta = {
+    'resources_dir': 'src/methods/novel',
+}
+## VIASH END
+
+sys.path.append(meta['resources_dir'])
+from helper_functions import train_and_valid, lsiTransformer, ModalityMatchingDataset
+from helper_functions import ModelRegressionAtac2Gex, ModelRegressionAdt2Gex, ModelRegressionGex2Adt, ModelRegressionGex2Atac
+
+print('Load data', flush=True)
+input_train_mod1 = ad.read_h5ad(par['input_train_mod1'])
+input_train_mod2 = ad.read_h5ad(par['input_train_mod2'])
+
+adata = input_train_mod2.copy()
+
+mod1 = input_train_mod1.uns['modality']
+mod2 = input_train_mod2.uns['modality']
+
+input_train_mod1.X = input_train_mod1.layers['normalized']
+input_train_mod2.X = input_train_mod2.layers['normalized']
+
+input_train_mod2_df = input_train_mod2.to_df()
+
+del input_train_mod2
+
+print('Start train', flush=True)
+
+# Guard against zero division in the LSI transform: drop all-zero variables.
+zero_row = np.asarray(input_train_mod1.X.sum(axis=0)).ravel() == 0
+
+rem_var = None
+if zero_row.any():
+    rem_var = input_train_mod1[:, zero_row].var_names
+    input_train_mod1 = input_train_mod1[:, ~zero_row]
+
+# select the number of components for LSI
+n_comp = input_train_mod1.n_vars - 1 if input_train_mod1.n_vars < 256 else 256
+
+if mod1 != 'ADT':
+    lsi_transformer_gex = lsiTransformer(n_components=n_comp)
+    input_train_mod1_df = lsi_transformer_gex.fit_transform(input_train_mod1)
+else:
+    input_train_mod1_df = input_train_mod1.to_df()
+
+# reproduce the train/test split from phase 1
+batch = input_train_mod1.obs["batch"]
+test_batches = {'s1d2', 's3d7'}
+# if none of the phase-1 batches are present, sample 25% of the batch categories (rounded up)
+if len(test_batches.intersection(set(batch))) == 0:
+    all_batches = batch.cat.categories.tolist()
+    test_batches = set(np.random.choice(all_batches, math.ceil(len(all_batches) * 0.25), replace=False))
+train_ix = [k for k, v in enumerate(batch) if v not in test_batches]
+test_ix = [k for k, v in enumerate(batch) if v in test_batches]
+
+train_mod1 = input_train_mod1_df.iloc[train_ix, :]
+train_mod2 = input_train_mod2_df.iloc[train_ix, :]
+test_mod1 = input_train_mod1_df.iloc[test_ix, :]
+test_mod2 = input_train_mod2_df.iloc[test_ix, :]
+
+n_vars_mod1 = input_train_mod1_df.shape[1]
+n_vars_mod2 = input_train_mod2_df.shape[1]
+
+if mod1 == 'ATAC' and mod2 == 'GEX':
+    dataset_train = ModalityMatchingDataset(train_mod1, train_mod2)
+    dataloader_train = DataLoader(dataset_train, 256, shuffle=True, num_workers=8)
+
+    dataset_test = ModalityMatchingDataset(test_mod1, test_mod2)
+    dataloader_test = DataLoader(dataset_test, 64, shuffle=False, num_workers=8)
+
+    model = ModelRegressionAtac2Gex(n_vars_mod1, n_vars_mod2).to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00008386597445284492, weight_decay=0.000684887347727808)
+
+elif mod1 == 'ADT' and mod2 == 'GEX':
+    dataset_train = ModalityMatchingDataset(train_mod1, train_mod2)
+    dataloader_train = DataLoader(dataset_train, 64, shuffle=True, num_workers=4)
+
+    dataset_test = ModalityMatchingDataset(test_mod1, test_mod2)
+    dataloader_test = DataLoader(dataset_test, 32, shuffle=False, num_workers=4)
+
+    model = ModelRegressionAdt2Gex(n_vars_mod1, n_vars_mod2).to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.00041, weight_decay=0.0000139)
+
+elif mod1 == 'GEX' and mod2 == 'ADT':
+    dataset_train = ModalityMatchingDataset(train_mod1, train_mod2)
+    dataloader_train = DataLoader(dataset_train, 32, shuffle=True, num_workers=8)
+
+    dataset_test = ModalityMatchingDataset(test_mod1, test_mod2)
+    dataloader_test = DataLoader(dataset_test, 64, shuffle=False, num_workers=8)
+
+    model = ModelRegressionGex2Adt(n_vars_mod1, n_vars_mod2).to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.000034609210829678734, weight_decay=0.0009965881574697426)
+
+elif mod1 == 'GEX' and mod2 == 'ATAC':
+    dataset_train = ModalityMatchingDataset(train_mod1, train_mod2)
+    dataloader_train = DataLoader(dataset_train, 64, shuffle=True, num_workers=8)
+
+    dataset_test = ModalityMatchingDataset(test_mod1, test_mod2)
+    dataloader_test = DataLoader(dataset_test, 64, shuffle=False, num_workers=8)
+
+    model = ModelRegressionGex2Atac(n_vars_mod1, n_vars_mod2).to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001806762345275399, weight_decay=0.0004084171379280058)
+
+else:
+    raise ValueError(f"Unsupported modality pair: {mod1} -> {mod2}")
+
+loss_fn = torch.nn.MSELoss()
+
+# create dir for par['output']
+os.makedirs(par['output'], exist_ok=True)
+
+# determine filenames
+output_model = f"{par['output']}/tensor.pt"
+output_h5ad = f"{par['output']}/train_mod2.h5ad"
+output_transform = f"{par['output']}/transform.pkl"
+
+# train model
+train_and_valid(model, optimizer, loss_fn, dataloader_train, dataloader_test, output_model, device)
+
+# Store the model dims for use in the predict component
+adata.uns["model_dim"] = {"mod1": n_vars_mod1, "mod2": n_vars_mod2}
+if rem_var is not None:
+    adata.uns["removed_vars"] = list(rem_var)
+adata.write_h5ad(output_h5ad, compression="gzip")
+
+if mod1 != 'ADT':
+    with open(output_transform, 'wb') as f:
+        pickle.dump(lsi_transformer_gex, f)
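To make the split logic above concrete: when neither phase-1 batch (`s1d2`, `s3d7`) is present, the fallback holds out `ceil(n_categories * 0.25)` whole batches for validation, so cells from the same batch never straddle the split. A minimal, self-contained sketch on fabricated batch labels (not from the benchmark data):

```python
import math
import numpy as np
import pandas as pd

batch = pd.Series(['s1d1', 's1d1', 's2d1', 's2d4', 's3d6', 's3d6'], dtype='category')

test_batches = {'s1d2', 's3d7'}  # phase-1 held-out batches
if len(test_batches.intersection(set(batch))) == 0:
    all_batches = batch.cat.categories.tolist()
    # 4 categories here, so ceil(4 * 0.25) = 1 batch is held out
    test_batches = set(np.random.choice(all_batches, math.ceil(len(all_batches) * 0.25), replace=False))

train_ix = [k for k, v in enumerate(batch) if v not in test_batches]
test_ix = [k for k, v in enumerate(batch) if v in test_batches]
```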
diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml
index 5d01dd4..75a7429 100644
--- a/src/workflows/run_benchmark/config.vsh.yaml
+++ b/src/workflows/run_benchmark/config.vsh.yaml
@@ -70,6 +70,7 @@ dependencies:
   - name: methods/knnr_r
   - name: methods/lm
   - name: methods/guanlab_dengkw_pm
+  - name: methods/novel
   - name: methods/simple_mlp
   - name: metrics/correlation
   - name: metrics/mse
diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf
index c279382..a032696 100644
--- a/src/workflows/run_benchmark/main.nf
+++ b/src/workflows/run_benchmark/main.nf
@@ -15,6 +15,7 @@ methods = [
   knnr_r,
   lm,
   guanlab_dengkw_pm,
+  novel,
   simple_mlp
 ]
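Finally, the contract between `novel_train` and `novel_predict` is a plain model directory. A hedged sanity-check sketch of that layout (paths as written by the train script above; `check_model_dir` itself is illustrative and not part of the PR — unpickling the transform also requires `helper_functions.py` on the path):

```python
import os
import pickle
import anndata as ad
import torch

def check_model_dir(model_dir: str) -> None:
    """Verify the artifacts novel_train writes and novel_predict expects."""
    # model weights, saved as a state_dict by train_and_valid
    weights = torch.load(os.path.join(model_dir, "tensor.pt"), map_location="cpu")
    assert isinstance(weights, dict)

    # train_mod2.h5ad carries the model dims (and optionally removed_vars) in uns
    train_mod2 = ad.read_h5ad(os.path.join(model_dir, "train_mod2.h5ad"))
    assert "model_dim" in train_mod2.uns  # {'mod1': ..., 'mod2': ...}

    # the fitted lsiTransformer is absent for the ADT -> GEX direction
    transform_path = os.path.join(model_dir, "transform.pkl")
    if os.path.exists(transform_path):
        with open(transform_path, "rb") as f:
            pickle.load(f)

check_model_dir("resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/models/novel")
```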