Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 131 additions & 147 deletions src/scdori/_core/train_grn.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,40 +129,41 @@ def get_tf_expression(
torch.Tensor
A (num_topics x num_tfs) tensor of TF expression values for each topic.
"""
if tf_expression_mode == "True":
if tf_expression_mode == "True": # FIXME
latent_all_torch = get_latent_topics(
model, device, train_loader, rna_anndata, atac_anndata, num_cells, tf_indices, encoding_batch_onehot
)
top_k_indices = np.argsort(latent_all_torch, axis=0)[-config_file.cells_per_topic :]
rna_tf_vals = rna_anndata.X[:, tf_indices]

if sp.issparse(rna_tf_vals):
rna_tf_vals = rna_tf_vals.todense()
rna_tf_vals = np.array(rna_tf_vals)
rna_tf_vals = rna_tf_vals.toarray()
else:
rna_tf_vals = np.asarray(rna_tf_vals)

median_cell = np.median(rna_tf_vals.sum(axis=1))
rna_tf_vals = median_cell * (rna_tf_vals / rna_tf_vals.sum(axis=1, keepdims=True))
topic_tf = np.array([rna_tf_vals[top_k_indices[:, t], :].mean(axis=0) for t in range(model.num_topics)])
topic_tf = torch.from_numpy(topic_tf)

preds_tf_denoised_min, _ = torch.min(topic_tf, dim=1, keepdim=True)
preds_tf_denoised_max, _ = torch.max(topic_tf, dim=1, keepdim=True)
topic_tf = (topic_tf - preds_tf_denoised_min) / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9)
topic_tf[topic_tf < config_file.tf_expression_clamp] = 0
topic_tf = topic_tf.to(device)
return topic_tf
preds_tf_denoised_min = topic_tf.min(dim=1, keepdim=True)[0]
preds_tf_denoised_max = topic_tf.max(dim=1, keepdim=True)[0]
normalized_tf = (topic_tf - preds_tf_denoised_min) / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9)
topic_tf = normalized_tf.clamp(min=config_file.tf_expression_clamp)
return topic_tf.to(device)
else:
import torch.nn as nn # Ensure this import is available if using nn.Softmax
import torch.nn as nn # Ensure this import is available if using nn.Softmax # FIXME

topic_tf = nn.Softmax(dim=1)(model.decoder.topic_tf_decoder.detach().cpu())

preds_tf_denoised_min, _ = torch.min(topic_tf, dim=1, keepdim=True)
preds_tf_denoised_max, _ = torch.max(topic_tf, dim=1, keepdim=True)
preds_tf_denoised_min = topic_tf.min(dim=1, keepdim=True)[0]
preds_tf_denoised_max = topic_tf.max(dim=1, keepdim=True)[0]
tf_normalised = (topic_tf - preds_tf_denoised_min) / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9)
tf_normalised[tf_normalised < config_file.tf_expression_clamp] = 0
topic_tf = tf_normalised.to(device)
return topic_tf
tf_normalised = tf_normalised.clamp(min=config_file.tf_expression_clamp)
return tf_normalised.to(device)


@torch.no_grad()
def compute_eval_loss_grn(
model,
device,
Expand Down Expand Up @@ -211,12 +212,7 @@ def compute_eval_loss_grn(
(eval_loss, eval_loss_atac, eval_loss_tf, eval_loss_rna, eval_loss_rna_grn).
"""
model.eval()
running_loss = 0.0
running_loss_atac = 0.0
running_loss_tf = 0.0
running_loss_rna = 0.0
running_loss_rna_grn = 0.0
nbatch = 0
running_stats = {"loss": 0.0, "loss_atac": 0.0, "loss_tf": 0.0, "loss_rna": 0.0, "loss_rna_grn": 0.0, "count": 0}

topic_tf_input = get_tf_expression(
config_file.tf_expression_mode,
Expand All @@ -230,94 +226,95 @@ def compute_eval_loss_grn(
encoding_batch_onehot,
config_file,
)

criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction="sum")
alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1)
alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1)

with torch.no_grad():
for batch_data in eval_loader:
cell_indices = batch_data[0].to(device)
B = cell_indices.shape[0]

input_matrix, tf_exp, library_size_value, num_cells_value, input_batch = create_minibatch(
device, cell_indices, rna_anndata, atac_anndata, num_cells, tf_indices, encoding_batch_onehot
)
rna_input = input_matrix[:, : model.num_genes]
atac_input = input_matrix[:, model.num_genes :]
log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
log_lib_atac = library_size_value[:, 1].reshape(-1, 1)

out = model(
rna_input,
atac_input,
tf_exp,
topic_tf_input,
log_lib_rna,
log_lib_atac,
num_cells_value,
input_batch,
phase="grn",
)
preds_atac = out["preds_atac"]
mu_nb_tf = out["mu_nb_tf"]
mu_nb_rna = out["mu_nb_rna"]
mu_nb_rna_grn = out["mu_nb_rna_grn"]

criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction="sum")
library_factor_peak = torch.exp(log_lib_atac.view(B, 1))
preds_poisson = preds_atac * library_factor_peak
loss_atac = criterion_poisson(preds_poisson, atac_input)

alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1)
nb_tf_ll = log_nb_positive(tf_exp, mu_nb_tf, alpha_tf).sum(dim=1).mean()
loss_tf = -nb_tf_ll

alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1)
nb_rna_ll = log_nb_positive(rna_input, mu_nb_rna, alpha_rna).sum(dim=1).mean()
loss_rna = -nb_rna_ll

nb_rna_grn_ll = log_nb_positive(rna_input, mu_nb_rna_grn, alpha_rna).sum(dim=1).mean()
loss_rna_grn = -nb_rna_grn_ll

l1_norm_tf = torch.norm(model.topic_tf_decoder.data, p=1)
l2_norm_tf = torch.norm(model.topic_tf_decoder.data, p=2)
l1_norm_peak = torch.norm(model.topic_peak_decoder.data, p=1)
l2_norm_peak = torch.norm(model.topic_peak_decoder.data, p=2)
l1_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=1)
l2_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=2)
l1_norm_grn_activator = torch.norm(model.tf_gene_topic_activator_grn.data, p=1)
l1_norm_grn_repressor = torch.norm(model.tf_gene_topic_repressor_grn.data, p=1)

loss_norm = (
config_file.l1_penalty_topic_tf * l1_norm_tf
+ config_file.l2_penalty_topic_tf * l2_norm_tf
+ config_file.l1_penalty_topic_peak * l1_norm_peak
+ config_file.l2_penalty_topic_peak * l2_norm_peak
+ config_file.l1_penalty_gene_peak * l1_norm_gene_peak
+ config_file.l2_penalty_gene_peak * l2_norm_gene_peak
+ config_file.l1_penalty_grn_activator * l1_norm_grn_activator
+ config_file.l1_penalty_grn_repressor * l1_norm_grn_repressor
)
for batch_data in eval_loader:
cell_indices = batch_data[0].to(device)
B = cell_indices.shape[0]

total_loss = (
config_file.weight_atac_grn * loss_atac
+ config_file.weight_tf_grn * loss_tf
+ config_file.weight_rna_grn * loss_rna
+ config_file.weight_rna_from_grn * loss_rna_grn
+ loss_norm
)

running_loss += total_loss.item()
running_loss_atac += loss_atac.item()
running_loss_tf += loss_tf.item()
running_loss_rna += loss_rna.item()
running_loss_rna_grn += loss_rna_grn.item()
nbatch += 1
input_matrix, tf_exp, library_size_value, num_cells_value, input_batch = create_minibatch(
device, cell_indices, rna_anndata, atac_anndata, num_cells, tf_indices, encoding_batch_onehot
)
rna_input = input_matrix[:, : model.num_genes]
atac_input = input_matrix[:, model.num_genes :]
log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
log_lib_atac = library_size_value[:, 1].reshape(-1, 1)

out = model(
rna_input,
atac_input,
tf_exp,
topic_tf_input,
log_lib_rna,
log_lib_atac,
num_cells_value,
input_batch,
phase="grn",
)
preds_atac = out["preds_atac"]
mu_nb_tf = out["mu_nb_tf"]
mu_nb_rna = out["mu_nb_rna"]
mu_nb_rna_grn = out["mu_nb_rna_grn"]

library_factor_peak = torch.exp(log_lib_atac.view(B, 1))
preds_poisson = preds_atac * library_factor_peak
loss_atac = criterion_poisson(preds_poisson, atac_input)

nb_tf_ll = log_nb_positive(tf_exp, mu_nb_tf, alpha_tf).sum(dim=1).mean()
loss_tf = -nb_tf_ll

nb_rna_ll = log_nb_positive(rna_input, mu_nb_rna, alpha_rna).sum(dim=1).mean()
loss_rna = -nb_rna_ll

nb_rna_grn_ll = log_nb_positive(rna_input, mu_nb_rna_grn, alpha_rna).sum(dim=1).mean()
loss_rna_grn = -nb_rna_grn_ll

l1_norm_tf = torch.norm(model.topic_tf_decoder.data, p=1)
l2_norm_tf = torch.norm(model.topic_tf_decoder.data, p=2)
l1_norm_peak = torch.norm(model.topic_peak_decoder.data, p=1)
l2_norm_peak = torch.norm(model.topic_peak_decoder.data, p=2)
l1_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=1)
l2_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=2)
l1_norm_grn_activator = torch.norm(model.tf_gene_topic_activator_grn.data, p=1)
l1_norm_grn_repressor = torch.norm(model.tf_gene_topic_repressor_grn.data, p=1)

loss_norm = (
config_file.l1_penalty_topic_tf * l1_norm_tf
+ config_file.l2_penalty_topic_tf * l2_norm_tf
+ config_file.l1_penalty_topic_peak * l1_norm_peak
+ config_file.l2_penalty_topic_peak * l2_norm_peak
+ config_file.l1_penalty_gene_peak * l1_norm_gene_peak
+ config_file.l2_penalty_gene_peak * l2_norm_gene_peak
+ config_file.l1_penalty_grn_activator * l1_norm_grn_activator
+ config_file.l1_penalty_grn_repressor * l1_norm_grn_repressor
)

eval_loss = running_loss / max(1, nbatch)
eval_loss_atac = running_loss_atac / max(1, nbatch)
eval_loss_tf = running_loss_tf / max(1, nbatch)
eval_loss_rna = running_loss_rna / max(1, nbatch)
eval_loss_rna_grn = running_loss_rna_grn / max(1, nbatch)
total_loss = (
config_file.weight_atac_grn * loss_atac
+ config_file.weight_tf_grn * loss_tf
+ config_file.weight_rna_grn * loss_rna
+ config_file.weight_rna_from_grn * loss_rna_grn
+ loss_norm
)

return eval_loss, eval_loss_atac, eval_loss_tf, eval_loss_rna, eval_loss_rna_grn
running_stats["loss"] += total_loss.item()
running_stats["loss_atac"] += loss_atac.item()
running_stats["loss_tf"] += loss_tf.item()
running_stats["loss_rna"] += loss_rna.item()
running_stats["loss_rna_grn"] += loss_rna_grn.item()
running_stats["count"] += 1

nbatch = max(1, running_stats["count"])
return (
running_stats["loss"] / nbatch
running_stats["loss_atac"] / nbatch
running_stats["loss_tf"] / nbatch
running_stats["loss_rna"] / nbatch
running_stats["loss_rna_grn"] / nbatch
)


def train_model_grn(
Expand Down Expand Up @@ -366,25 +363,10 @@ def train_model_grn(
torch.nn.Module
The trained model after the GRN phase completes or early stopping occurs.
"""
if not config_file.update_encoder_in_grn:
set_encoder_frozen(model, freeze=True)
else:
set_encoder_frozen(model, freeze=False)

if not config_file.update_peak_gene_in_grn:
set_peak_gene_frozen(model, freeze=True)
else:
set_peak_gene_frozen(model, freeze=False)

if not config_file.update_topic_peak_in_grn:
set_topic_peak_frozen(model, freeze=True)
else:
set_topic_peak_frozen(model, freeze=False)

if not config_file.update_topic_tf_in_grn:
set_topic_tf_frozen(model, freeze=True)
else:
set_topic_tf_frozen(model, freeze=False)
set_encoder_frozen(model, freeze=not config_file.update_encoder_in_grn)
set_peak_gene_frozen(model, freeze=not config_file.update_peak_gene_in_grn)
set_topic_peak_frozen(model, freeze=not config_file.update_topic_peak_in_grn)
set_topic_tf_frozen(model, freeze=not config_file.update_topic_tf_in_grn)

optimizer_grn = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()), lr=config_file.learning_rate_grn
Expand Down Expand Up @@ -412,12 +394,15 @@ def train_model_grn(
logger.info("Starting GRN training")
for epoch in range(config_file.max_grn_epochs):
model.train()
running_loss = 0.0
running_loss_atac = 0.0
running_loss_tf = 0.0
running_loss_rna = 0.0
running_loss_rna_grn = 0.0
nbatch = 0
# Initialize running stats dictionary
running_stats = {
"loss": 0.0,
"loss_atac": 0.0,
"loss_tf": 0.0,
"loss_rna": 0.0,
"loss_rna_grn": 0.0,
"count": 0,
}

# If the encoder is being updated, recalc topic_tf_input each epoch:
if config_file.update_encoder_in_grn:
Expand All @@ -443,10 +428,8 @@ def train_model_grn(
)
rna_input = input_matrix[:, : model.num_genes]
atac_input = input_matrix[:, model.num_genes :]
tf_input = tf_exp
log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
log_lib_atac = library_size_value[:, 1].reshape(-1, 1)
batch_onehot = input_batch

if config_file.tf_expression_mode == "latent":
topic_tf_input = get_tf_expression(
Expand All @@ -465,18 +448,17 @@ def train_model_grn(
out = model(
rna_input,
atac_input,
tf_input,
tf_exp,
topic_tf_input,
log_lib_rna,
log_lib_atac,
num_cells_value,
batch_onehot,
input_batch,
phase="grn",
)
preds_atac = out["preds_atac"]
mu_nb_tf = out["mu_nb_tf"]
mu_nb_rna = out["mu_nb_rna"]
preds_rna_grn = out["preds_rna_from_grn"]
mu_nb_rna_grn = out["mu_nb_rna_grn"]

criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction="sum")
Expand All @@ -485,7 +467,7 @@ def train_model_grn(
loss_atac = criterion_poisson(preds_poisson, atac_input)

alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1)
nb_tf_ll = log_nb_positive(tf_input, mu_nb_tf, alpha_tf).sum(dim=1).mean()
nb_tf_ll = log_nb_positive(tf_exp, mu_nb_tf, alpha_tf).sum(dim=1).mean()
loss_tf = -nb_tf_ll

alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1)
Expand Down Expand Up @@ -527,21 +509,23 @@ def train_model_grn(
total_loss.backward()
optimizer_grn.step()

running_loss += total_loss.item()
running_loss_atac += loss_atac.item()
running_loss_tf += loss_tf.item()
running_loss_rna += loss_rna.item()
running_loss_rna_grn += loss_rna_grn.item()
nbatch += 1
# Update running stats
running_stats["loss"] += total_loss.item()
running_stats["loss_atac"] += loss_atac.item()
running_stats["loss_tf"] += loss_tf.item()
running_stats["loss_rna"] += loss_rna.item()
running_stats["loss_rna_grn"] += loss_rna_grn.item()
running_stats["count"] += 1

model.gene_peak_factor_learnt.data.clamp_(min=0)
model.gene_peak_factor_learnt.data.clamp_(max=1)

epoch_loss = running_loss / max(1, nbatch)
epoch_loss_atac = running_loss_atac / max(1, nbatch)
epoch_loss_tf = running_loss_tf / max(1, nbatch)
epoch_loss_rna = running_loss_rna / max(1, nbatch)
epoch_loss_rna_grn = running_loss_rna_grn / max(1, nbatch)
nbatch = max(1, running_stats["count"])
epoch_loss = running_stats["loss"] / nbatch
epoch_loss_atac = running_stats["loss_atac"] / nbatch
epoch_loss_tf = running_stats["loss_tf"] / nbatch
epoch_loss_rna = running_stats["loss_rna"] / nbatch
epoch_loss_rna_grn = running_stats["loss_rna_grn"] / nbatch

logger.info(
f"[GRN-Train] Epoch={epoch}, Loss={epoch_loss:.4f},"
Expand Down
Loading