diff --git a/src/scdori/_core/train_grn.py b/src/scdori/_core/train_grn.py
index 79101f8..511e23e 100644
--- a/src/scdori/_core/train_grn.py
+++ b/src/scdori/_core/train_grn.py
@@ -129,40 +129,41 @@ def get_tf_expression(
     torch.Tensor
         A (num_topics x num_tfs) tensor of TF expression values for each topic.
     """
-    if tf_expression_mode == "True":
+    if tf_expression_mode == "True":  # FIXME: mode is compared against the string "True"; prefer a boolean flag
        latent_all_torch = get_latent_topics(
            model, device, train_loader, rna_anndata, atac_anndata, num_cells, tf_indices, encoding_batch_onehot
        )
        top_k_indices = np.argsort(latent_all_torch, axis=0)[-config_file.cells_per_topic :]
        rna_tf_vals = rna_anndata.X[:, tf_indices]
-        rna_tf_vals = rna_tf_vals.todense()
-        rna_tf_vals = np.array(rna_tf_vals)
+        if sp.issparse(rna_tf_vals):
+            rna_tf_vals = rna_tf_vals.toarray()
+        else:
+            rna_tf_vals = np.asarray(rna_tf_vals)
        median_cell = np.median(rna_tf_vals.sum(axis=1))
        rna_tf_vals = median_cell * (rna_tf_vals / rna_tf_vals.sum(axis=1, keepdims=True))
        topic_tf = np.array([rna_tf_vals[top_k_indices[:, t], :].mean(axis=0) for t in range(model.num_topics)])
        topic_tf = torch.from_numpy(topic_tf)
-        preds_tf_denoised_min, _ = torch.min(topic_tf, dim=1, keepdim=True)
-        preds_tf_denoised_max, _ = torch.max(topic_tf, dim=1, keepdim=True)
-        topic_tf = (topic_tf - preds_tf_denoised_min) / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9)
-        topic_tf[topic_tf < config_file.tf_expression_clamp] = 0
-        topic_tf = topic_tf.to(device)
-        return topic_tf
+        preds_tf_denoised_min = topic_tf.min(dim=1, keepdim=True)[0]
+        preds_tf_denoised_max = topic_tf.max(dim=1, keepdim=True)[0]
+        normalized_tf = (topic_tf - preds_tf_denoised_min) / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9)
+        # Zero out sub-threshold entries; clamp(min=...) would lift them to the threshold instead.
+        topic_tf = normalized_tf.masked_fill(normalized_tf < config_file.tf_expression_clamp, 0.0)
+        return topic_tf.to(device)
    else:
-        import torch.nn as nn  # Ensure this import is available if using nn.Softmax
+        import torch.nn as nn  # Needed for nn.Softmax; FIXME: move to module-level imports
        topic_tf = nn.Softmax(dim=1)(model.decoder.topic_tf_decoder.detach().cpu())
-        preds_tf_denoised_min, _ = torch.min(topic_tf, dim=1, keepdim=True)
-        preds_tf_denoised_max, _ = torch.max(topic_tf, dim=1, keepdim=True)
+        preds_tf_denoised_min = topic_tf.min(dim=1, keepdim=True)[0]
+        preds_tf_denoised_max = topic_tf.max(dim=1, keepdim=True)[0]
        tf_normalised = (topic_tf - preds_tf_denoised_min) / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9)
-        tf_normalised[tf_normalised < config_file.tf_expression_clamp] = 0
-        topic_tf = tf_normalised.to(device)
-        return topic_tf
+        # Zero out sub-threshold entries, matching the original in-place masking.
+        tf_normalised = tf_normalised.masked_fill(tf_normalised < config_file.tf_expression_clamp, 0.0)
+        return tf_normalised.to(device)


+@torch.no_grad()
 def compute_eval_loss_grn(
     model,
     device,
@@ -211,12 +212,7 @@ def compute_eval_loss_grn(
         (eval_loss, eval_loss_atac, eval_loss_tf, eval_loss_rna, eval_loss_rna_grn).
""" model.eval() - running_loss = 0.0 - running_loss_atac = 0.0 - running_loss_tf = 0.0 - running_loss_rna = 0.0 - running_loss_rna_grn = 0.0 - nbatch = 0 + running_stats = {"loss": 0.0, "loss_atac": 0.0, "loss_tf": 0.0, "loss_rna": 0.0, "loss_rna_grn": 0.0, "count": 0} topic_tf_input = get_tf_expression( config_file.tf_expression_mode, @@ -230,94 +226,95 @@ def compute_eval_loss_grn( encoding_batch_onehot, config_file, ) + + criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction="sum") + alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1) + alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1) - with torch.no_grad(): - for batch_data in eval_loader: - cell_indices = batch_data[0].to(device) - B = cell_indices.shape[0] - - input_matrix, tf_exp, library_size_value, num_cells_value, input_batch = create_minibatch( - device, cell_indices, rna_anndata, atac_anndata, num_cells, tf_indices, encoding_batch_onehot - ) - rna_input = input_matrix[:, : model.num_genes] - atac_input = input_matrix[:, model.num_genes :] - log_lib_rna = library_size_value[:, 0].reshape(-1, 1) - log_lib_atac = library_size_value[:, 1].reshape(-1, 1) - - out = model( - rna_input, - atac_input, - tf_exp, - topic_tf_input, - log_lib_rna, - log_lib_atac, - num_cells_value, - input_batch, - phase="grn", - ) - preds_atac = out["preds_atac"] - mu_nb_tf = out["mu_nb_tf"] - mu_nb_rna = out["mu_nb_rna"] - mu_nb_rna_grn = out["mu_nb_rna_grn"] - - criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction="sum") - library_factor_peak = torch.exp(log_lib_atac.view(B, 1)) - preds_poisson = preds_atac * library_factor_peak - loss_atac = criterion_poisson(preds_poisson, atac_input) - - alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1) - nb_tf_ll = log_nb_positive(tf_exp, mu_nb_tf, alpha_tf).sum(dim=1).mean() - loss_tf = -nb_tf_ll - - alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1) - nb_rna_ll = log_nb_positive(rna_input, mu_nb_rna, alpha_rna).sum(dim=1).mean() - loss_rna = -nb_rna_ll - - nb_rna_grn_ll = log_nb_positive(rna_input, mu_nb_rna_grn, alpha_rna).sum(dim=1).mean() - loss_rna_grn = -nb_rna_grn_ll - - l1_norm_tf = torch.norm(model.topic_tf_decoder.data, p=1) - l2_norm_tf = torch.norm(model.topic_tf_decoder.data, p=2) - l1_norm_peak = torch.norm(model.topic_peak_decoder.data, p=1) - l2_norm_peak = torch.norm(model.topic_peak_decoder.data, p=2) - l1_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=1) - l2_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=2) - l1_norm_grn_activator = torch.norm(model.tf_gene_topic_activator_grn.data, p=1) - l1_norm_grn_repressor = torch.norm(model.tf_gene_topic_repressor_grn.data, p=1) - - loss_norm = ( - config_file.l1_penalty_topic_tf * l1_norm_tf - + config_file.l2_penalty_topic_tf * l2_norm_tf - + config_file.l1_penalty_topic_peak * l1_norm_peak - + config_file.l2_penalty_topic_peak * l2_norm_peak - + config_file.l1_penalty_gene_peak * l1_norm_gene_peak - + config_file.l2_penalty_gene_peak * l2_norm_gene_peak - + config_file.l1_penalty_grn_activator * l1_norm_grn_activator - + config_file.l1_penalty_grn_repressor * l1_norm_grn_repressor - ) + for batch_data in eval_loader: + cell_indices = batch_data[0].to(device) + B = cell_indices.shape[0] - total_loss = ( - config_file.weight_atac_grn * loss_atac - + config_file.weight_tf_grn * loss_tf - + config_file.weight_rna_grn * loss_rna - + config_file.weight_rna_from_grn * loss_rna_grn - + loss_norm - ) - 
-            running_loss += total_loss.item()
-            running_loss_atac += loss_atac.item()
-            running_loss_tf += loss_tf.item()
-            running_loss_rna += loss_rna.item()
-            running_loss_rna_grn += loss_rna_grn.item()
-            nbatch += 1
+        input_matrix, tf_exp, library_size_value, num_cells_value, input_batch = create_minibatch(
+            device, cell_indices, rna_anndata, atac_anndata, num_cells, tf_indices, encoding_batch_onehot
+        )
+        rna_input = input_matrix[:, : model.num_genes]
+        atac_input = input_matrix[:, model.num_genes :]
+        log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
+        log_lib_atac = library_size_value[:, 1].reshape(-1, 1)
+
+        out = model(
+            rna_input,
+            atac_input,
+            tf_exp,
+            topic_tf_input,
+            log_lib_rna,
+            log_lib_atac,
+            num_cells_value,
+            input_batch,
+            phase="grn",
+        )
+        preds_atac = out["preds_atac"]
+        mu_nb_tf = out["mu_nb_tf"]
+        mu_nb_rna = out["mu_nb_rna"]
+        mu_nb_rna_grn = out["mu_nb_rna_grn"]
+
+        library_factor_peak = torch.exp(log_lib_atac.view(B, 1))
+        preds_poisson = preds_atac * library_factor_peak
+        loss_atac = criterion_poisson(preds_poisson, atac_input)
+
+        nb_tf_ll = log_nb_positive(tf_exp, mu_nb_tf, alpha_tf).sum(dim=1).mean()
+        loss_tf = -nb_tf_ll
+
+        nb_rna_ll = log_nb_positive(rna_input, mu_nb_rna, alpha_rna).sum(dim=1).mean()
+        loss_rna = -nb_rna_ll
+
+        nb_rna_grn_ll = log_nb_positive(rna_input, mu_nb_rna_grn, alpha_rna).sum(dim=1).mean()
+        loss_rna_grn = -nb_rna_grn_ll
+
+        l1_norm_tf = torch.norm(model.topic_tf_decoder.data, p=1)
+        l2_norm_tf = torch.norm(model.topic_tf_decoder.data, p=2)
+        l1_norm_peak = torch.norm(model.topic_peak_decoder.data, p=1)
+        l2_norm_peak = torch.norm(model.topic_peak_decoder.data, p=2)
+        l1_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=1)
+        l2_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=2)
+        l1_norm_grn_activator = torch.norm(model.tf_gene_topic_activator_grn.data, p=1)
+        l1_norm_grn_repressor = torch.norm(model.tf_gene_topic_repressor_grn.data, p=1)
+
+        loss_norm = (
+            config_file.l1_penalty_topic_tf * l1_norm_tf
+            + config_file.l2_penalty_topic_tf * l2_norm_tf
+            + config_file.l1_penalty_topic_peak * l1_norm_peak
+            + config_file.l2_penalty_topic_peak * l2_norm_peak
+            + config_file.l1_penalty_gene_peak * l1_norm_gene_peak
+            + config_file.l2_penalty_gene_peak * l2_norm_gene_peak
+            + config_file.l1_penalty_grn_activator * l1_norm_grn_activator
+            + config_file.l1_penalty_grn_repressor * l1_norm_grn_repressor
+        )

-    eval_loss = running_loss / max(1, nbatch)
-    eval_loss_atac = running_loss_atac / max(1, nbatch)
-    eval_loss_tf = running_loss_tf / max(1, nbatch)
-    eval_loss_rna = running_loss_rna / max(1, nbatch)
-    eval_loss_rna_grn = running_loss_rna_grn / max(1, nbatch)
+        total_loss = (
+            config_file.weight_atac_grn * loss_atac
+            + config_file.weight_tf_grn * loss_tf
+            + config_file.weight_rna_grn * loss_rna
+            + config_file.weight_rna_from_grn * loss_rna_grn
+            + loss_norm
+        )

-    return eval_loss, eval_loss_atac, eval_loss_tf, eval_loss_rna, eval_loss_rna_grn
+        running_stats["loss"] += total_loss.item()
+        running_stats["loss_atac"] += loss_atac.item()
+        running_stats["loss_tf"] += loss_tf.item()
+        running_stats["loss_rna"] += loss_rna.item()
+        running_stats["loss_rna_grn"] += loss_rna_grn.item()
+        running_stats["count"] += 1
+
+    nbatch = max(1, running_stats["count"])
+    return (
+        running_stats["loss"] / nbatch,
+        running_stats["loss_atac"] / nbatch,
+        running_stats["loss_tf"] / nbatch,
+        running_stats["loss_rna"] / nbatch,
+        running_stats["loss_rna_grn"] / nbatch,
+    )


 def train_model_grn(
@@ -366,25 +363,10 @@ def train_model_grn(
    torch.nn.Module
        The trained model after the GRN phase completes or early stopping occurs.
    """
-    if not config_file.update_encoder_in_grn:
-        set_encoder_frozen(model, freeze=True)
-    else:
-        set_encoder_frozen(model, freeze=False)
-
-    if not config_file.update_peak_gene_in_grn:
-        set_peak_gene_frozen(model, freeze=True)
-    else:
-        set_peak_gene_frozen(model, freeze=False)
-
-    if not config_file.update_topic_peak_in_grn:
-        set_topic_peak_frozen(model, freeze=True)
-    else:
-        set_topic_peak_frozen(model, freeze=False)
-
-    if not config_file.update_topic_tf_in_grn:
-        set_topic_tf_frozen(model, freeze=True)
-    else:
-        set_topic_tf_frozen(model, freeze=False)
+    set_encoder_frozen(model, freeze=not config_file.update_encoder_in_grn)
+    set_peak_gene_frozen(model, freeze=not config_file.update_peak_gene_in_grn)
+    set_topic_peak_frozen(model, freeze=not config_file.update_topic_peak_in_grn)
+    set_topic_tf_frozen(model, freeze=not config_file.update_topic_tf_in_grn)

    optimizer_grn = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=config_file.learning_rate_grn
    )
@@ -412,12 +394,15 @@ def train_model_grn(
    logger.info("Starting GRN training")
    for epoch in range(config_file.max_grn_epochs):
        model.train()
-        running_loss = 0.0
-        running_loss_atac = 0.0
-        running_loss_tf = 0.0
-        running_loss_rna = 0.0
-        running_loss_rna_grn = 0.0
-        nbatch = 0
+        # Initialize the running stats for this epoch
+        running_stats = {
+            "loss": 0.0,
+            "loss_atac": 0.0,
+            "loss_tf": 0.0,
+            "loss_rna": 0.0,
+            "loss_rna_grn": 0.0,
+            "count": 0,
+        }

        # If the encoder is being updated, recalc topic_tf_input each epoch:
        if config_file.update_encoder_in_grn:
@@ -443,10 +428,8 @@ def train_model_grn(
            )
            rna_input = input_matrix[:, : model.num_genes]
            atac_input = input_matrix[:, model.num_genes :]
-            tf_input = tf_exp
            log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
            log_lib_atac = library_size_value[:, 1].reshape(-1, 1)
-            batch_onehot = input_batch

            if config_file.tf_expression_mode == "latent":
                topic_tf_input = get_tf_expression(
@@ -465,18 +448,17 @@ def train_model_grn(
            out = model(
                rna_input,
                atac_input,
-                tf_input,
+                tf_exp,
                topic_tf_input,
                log_lib_rna,
                log_lib_atac,
                num_cells_value,
-                batch_onehot,
+                input_batch,
                phase="grn",
            )
            preds_atac = out["preds_atac"]
            mu_nb_tf = out["mu_nb_tf"]
            mu_nb_rna = out["mu_nb_rna"]
-            preds_rna_grn = out["preds_rna_from_grn"]
            mu_nb_rna_grn = out["mu_nb_rna_grn"]

            criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction="sum")
@@ -485,7 +467,7 @@ def train_model_grn(
            loss_atac = criterion_poisson(preds_poisson, atac_input)

            alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1)
-            nb_tf_ll = log_nb_positive(tf_input, mu_nb_tf, alpha_tf).sum(dim=1).mean()
+            nb_tf_ll = log_nb_positive(tf_exp, mu_nb_tf, alpha_tf).sum(dim=1).mean()
            loss_tf = -nb_tf_ll

            alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1)
@@ -527,21 +509,23 @@ def train_model_grn(
            total_loss.backward()
            optimizer_grn.step()

-            running_loss += total_loss.item()
-            running_loss_atac += loss_atac.item()
-            running_loss_tf += loss_tf.item()
-            running_loss_rna += loss_rna.item()
-            running_loss_rna_grn += loss_rna_grn.item()
-            nbatch += 1
+            # Update the running stats
+            running_stats["loss"] += total_loss.item()
+            running_stats["loss_atac"] += loss_atac.item()
+            running_stats["loss_tf"] += loss_tf.item()
+            running_stats["loss_rna"] += loss_rna.item()
+            running_stats["loss_rna_grn"] += loss_rna_grn.item()
+            running_stats["count"] += 1

            model.gene_peak_factor_learnt.data.clamp_(min=0)
            model.gene_peak_factor_learnt.data.clamp_(max=1)

-        epoch_loss = running_loss / max(1, nbatch)
-        epoch_loss_atac = running_loss_atac / max(1, nbatch)
-        epoch_loss_tf = running_loss_tf / max(1, nbatch)
-        epoch_loss_rna = running_loss_rna / max(1, nbatch)
-        epoch_loss_rna_grn = running_loss_rna_grn / max(1, nbatch)
+        nbatch = max(1, running_stats["count"])
+        epoch_loss = running_stats["loss"] / nbatch
+        epoch_loss_atac = running_stats["loss_atac"] / nbatch
+        epoch_loss_tf = running_stats["loss_tf"] / nbatch
+        epoch_loss_rna = running_stats["loss_rna"] / nbatch
+        epoch_loss_rna_grn = running_stats["loss_rna_grn"] / nbatch

        logger.info(
            f"[GRN-Train] Epoch={epoch}, Loss={epoch_loss:.4f},"
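
Note on the thresholding in `get_tf_expression`: sub-threshold normalized TF values are zeroed with `masked_fill` rather than `clamp(min=...)`, because `clamp` would raise values below `tf_expression_clamp` up to the threshold instead of removing them, leaving weakly expressed TFs with spurious nonzero activity downstream. A minimal standalone sketch with toy values (the `0.2` threshold is a stand-in for `config_file.tf_expression_clamp`):

```python
import torch

# Toy topic-by-TF matrix after min-max normalization (values in [0, 1]).
normalized = torch.tensor([[0.05, 0.30, 0.90],
                           [0.10, 0.00, 0.55]])
threshold = 0.2  # stand-in for config_file.tf_expression_clamp

# Masking semantics: entries below the threshold are zeroed out.
zeroed = normalized.masked_fill(normalized < threshold, 0.0)
# tensor([[0.0000, 0.3000, 0.9000],
#         [0.0000, 0.0000, 0.5500]])

# clamp(min=threshold) instead raises those entries to the threshold.
clamped = normalized.clamp(min=threshold)
# tensor([[0.2000, 0.3000, 0.9000],
#         [0.2000, 0.2000, 0.5500]])

assert not torch.equal(zeroed, clamped)
```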