diff --git a/pf2rnaseq/factorization.py b/pf2rnaseq/factorization.py index 5c44f5c..d7fc8e0 100644 --- a/pf2rnaseq/factorization.py +++ b/pf2rnaseq/factorization.py @@ -5,6 +5,7 @@ import scanpy as sc import scipy.sparse as sps from pacmap import PaCMAP +from parafac2.normalize import prepare_dataset from parafac2.parafac2 import parafac2_nd, store_pf2 from scipy.stats import gmean from sklearn.decomposition import PCA @@ -190,3 +191,188 @@ def fms_diff_ranks( ) return df + + +def downsample_counts_multinomial( + X: anndata.AnnData, + percent_drop: float, + random_state: int = 0, +) -> anndata.AnnData: + """ + Create a downsampled counts copy of AnnData using multinomial sampling. + + Parameters: + ----------- + X : anndata.AnnData + Input dataset + percent_drop : float + Percentage of counts to drop (0-100) + random_state : int + Random seed for reproducibility + + Returns: + -------- + anndata.AnnData + Downsampled copy of the input data + """ + import scipy.sparse as sp + + # Handle 0% drop case + if percent_drop == 0: + return X.copy() + + # Set random seed + np.random.seed(random_state) + + # Convert to CSR and extract structure + original_csr = X.X.tocsr() + data = original_csr.data.copy() + indices = original_csr.indices + indptr = original_csr.indptr + + # Process each cell + for cell_idx in range(X.n_obs): + start_idx = indptr[cell_idx] + end_idx = indptr[cell_idx + 1] + + if start_idx == end_idx: + continue + + cell_data = data[start_idx:end_idx] + total_counts = int(np.sum(cell_data)) + + if total_counts == 0: + continue + + new_total = int(total_counts * (1 - percent_drop / 100)) + if new_total == 0: + data[start_idx:end_idx] = 0 + continue + + # Convert to probabilities and normalize + probs = cell_data / total_counts + probs = probs / np.sum(probs) # Ensure sum = 1.0 + + # Multinomial sampling + new_counts = np.random.multinomial(new_total, probs) + data[start_idx:end_idx] = new_counts.astype(cell_data.dtype) + + # Create new sparse matrix + sampled_csr = sp.csr_matrix((data, indices, indptr), shape=original_csr.shape) + + # Create new AnnData object + sampled_data = X.copy() + sampled_data.X = sampled_csr + + return sampled_data + + +def calculate_fms_downsample( + X: anndata.AnnData, + X_pf2: anndata.AnnData, + percent_drop: float, + rank: int = 30, + deviance: bool = False, + condition: str = "Condition", + random_state: int = 0, +) -> float: + """ + Calculate FMS for a single downsampling scenario. + + Parameters: + ----------- + X : anndata.AnnData + Original dataset for reference + X_pf2 : anndata.AnnData + Full factorized dataset + percent_drop : float + Percentage of counts to drop (0-100) + rank : int + Factorization rank + deviance : bool + Whether to use deviance normalization + condition : str + Condition column name + random_state : int + Random seed + + Returns: + -------- + float + FMS score + """ + + # Handle 0% drop case + if percent_drop == 0: + return 1.0 + + # Create downsampled data + sampled_data = downsample_counts_multinomial( + X, percent_drop, random_state=random_state + ) + + # Apply same processing as reference + sampled_data = prepare_dataset( + sampled_data, condition, geneThreshold=0.0, deviance=deviance + ) + + # Factorization + sampledX = pf2(sampled_data, rank, random_state=random_state + 2, doEmbedding=False) + + return calculateFMS(X_pf2, sampledX) + + +def fms_percent_drop_counts( + X: anndata.AnnData, + percentList: np.ndarray, + rank: int = 30, + deviance: bool = False, + condition: str = "Condition", + geneThreshold: float = 0.0, + random_state: int = 0, +) -> pd.DataFrame: + """ + Calculate FMS for multiple downsampling percentages (single run). + + Parameters: + ----------- + X : anndata.AnnData + Input dataset + percentList : np.ndarray + Array of dropout percentages to test + rank : int + Factorization rank + deviance : bool + Whether to use deviance normalization + condition : str + Condition column name + geneThreshold : float + Gene threshold for preparation + random_state : int + Random seed + + Returns: + -------- + pd.DataFrame + DataFrame with columns: Percentage of Counts Dropped, FMS + """ + results = [] + X_prepared = prepare_dataset( + X, condition, geneThreshold=geneThreshold, deviance=deviance + ) + X_pf2 = pf2(X_prepared, rank, doEmbedding=False) + + for percent_drop in percentList: + fms_score = calculate_fms_downsample( + X=X, + X_pf2=X_pf2, + percent_drop=percent_drop, + rank=rank, + deviance=deviance, + condition=condition, + random_state=random_state, + ) + + results.append({"Percentage of Counts Dropped": percent_drop, "FMS": fms_score}) + + return pd.DataFrame(results) diff --git a/pf2rnaseq/figures/commonFuncs/plotGeneral.py b/pf2rnaseq/figures/commonFuncs/plotGeneral.py index db0f833..028328e 100644 --- a/pf2rnaseq/figures/commonFuncs/plotGeneral.py +++ b/pf2rnaseq/figures/commonFuncs/plotGeneral.py @@ -6,7 +6,12 @@ import seaborn as sns from matplotlib.axes import Axes -from ...factorization import fms_percent_drop, pf2_pca_r2x, fms_diff_ranks +from ...factorization import ( + fms_diff_ranks, + fms_percent_drop, + fms_percent_drop_counts, + pf2_pca_r2x, +) def plot_r2x(data, rank_vec, ax: Axes): @@ -460,3 +465,18 @@ def plot_fms_percent_drop( df = fms_percent_drop(X, percentList, runs, rank) sns.lineplot(data=df, x="Percentage of Data Dropped", y="FMS", ax=ax) ax.set_ylim(0, 1) + + +def plot_fms_percent_drop_counts( + X: anndata.AnnData, + ax: Axes, + percentList: np.ndarray, + rank: int = 30, + deviance: bool = False, + label: str = None, +): + """Plots FMS when dropping different percentages of data""" + df = fms_percent_drop_counts(X, percentList, rank, deviance=deviance) + sns.lineplot(data=df, x="Percentage of Counts Dropped", y="FMS", ax=ax, label=label) + ax.set_ylim(0, 1) + diff --git a/pf2rnaseq/figures/figureCountFMS.py b/pf2rnaseq/figures/figureCountFMS.py new file mode 100644 index 0000000..e69d415 --- /dev/null +++ b/pf2rnaseq/figures/figureCountFMS.py @@ -0,0 +1,31 @@ +""" +factorization score + +""" + +from anndata import read_h5ad + +from .common import getSetup, subplotLabel +from .commonFuncs.plotGeneral import plot_fms_percent_drop_counts + + +def makeFigure(): + ax, f = getSetup((6, 3), (1, 1)) + subplotLabel(ax) + # Using our cytokine dataset + X = read_h5ad("/opt/extra-storage/Treg_h5ads/Treg_raw.h5ad") + + # Remove multiplexing identifiers + X = X[:, ~X.var_names.str.match("^CMO3[0-9]{2}$")] # type: ignore + # Remove genes with too few reads now + X = X[X.X.sum(axis=1) > 10, X.X.mean(axis=0) > 0.1] + X = X.copy() + percentList = [0.0, 30.0, 50.0] + plot_fms_percent_drop_counts( + X, ax[0], percentList, rank=15, deviance=True, label="Deviance" + ) + plot_fms_percent_drop_counts( + X, ax[0], percentList, rank=15, deviance=False, label="CPM" + ) + + return f