diff --git a/blnm/fit.py b/blnm/fit.py
index 199e14c..2cacf2c 100644
--- a/blnm/fit.py
+++ b/blnm/fit.py
@@ -4,9 +4,18 @@
 import numpy.typing as npt
 import typing

 from scipy.stats import binom
 from scipy import special as scisp

+import matplotlib.pyplot as plt
+from sklearn.mixture import GaussianMixture
+
+import numpy as np
+from scipy.special import logsumexp, expit
+
 from . import utils
 from . import dist

@@ -116,8 +125,136 @@ def _m_step(zeroth_order,

     return coefs, means, variance

-def init_pars():
-    raise NotImplementedError
+def _find_k_from_bins(alt_allele_counts, n_counts):
+    """Estimate the number of components by counting peaks in a 10-bin
+    histogram of alt / n ratios."""
+    ratios = sorted(a / n for a, n in zip(alt_allele_counts, n_counts))
+    bin_counts = [0] * 10
+    for r in ratios:
+        idx = 9 if r == 1.0 else int(r * 10)
+        bin_counts[idx] += 1
+    # Interior local maxima count as peaks.
+    peak_indices = [
+        i for i in range(1, 9)
+        if bin_counts[i] > bin_counts[i - 1] and bin_counts[i] > bin_counts[i + 1]
+    ]
+    # Fall back to the boundary bins, then to the single tallest bin.
+    if not peak_indices:
+        if bin_counts[0] > bin_counts[1]:
+            peak_indices.append(0)
+        if bin_counts[9] > bin_counts[8]:
+            peak_indices.append(9)
+    if not peak_indices:
+        peak_indices = [bin_counts.index(max(bin_counts))]
+    return len(peak_indices)
+
+
+def _find_k_gmm(d_reshaped, criterion="AIC"):
+    """Choose k in 1..4 by minimising the AIC or BIC of a 1-D Gaussian mixture."""
+    best_k = None
+    best_score = np.inf
+    for candidate_k in range(1, 5):
+        gmm = GaussianMixture(n_components=candidate_k,
+                              covariance_type="full",
+                              random_state=42)
+        gmm.fit(d_reshaped)
+        score = gmm.aic(d_reshaped) if criterion.upper() == "AIC" else gmm.bic(d_reshaped)
+        if score < best_score:
+            best_score = score
+            best_k = candidate_k
+    return best_k
+
+
+def _plot_clusters(log_ref_clean, log_alt_clean, labels, sorted_centers):
+    """Scatter log(ref) vs. log(alt) and draw one parallel line y = x + c
+    per sorted component centre c."""
+    plt.figure(figsize=(8, 6))
+    plt.scatter(log_ref_clean, log_alt_clean, c=labels, cmap='viridis', alpha=0.5)
+    plt.xlabel("log(ref)")
+    plt.ylabel("log(alt)")
+    plt.title("Scatter plot of log(ref) vs. log(alt) with parallel lines")
+    x_vals = np.linspace(log_ref_clean.min(), log_ref_clean.max(), 200)
+    for c in sorted_centers:
+        plt.plot(x_vals, x_vals + c, '--', label=f'y = x + {c:.2f}')
+    plt.legend()
+    plt.show()
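+
+# A minimal usage sketch for the two k-selection helpers (the counts below are
+# made-up illustrative values, not project data):
+#
+#     alt = [28, 31, 30, 70, 72, 69]
+#     n = [100] * 6
+#     k_hist = _find_k_from_bins(alt, n)   # two interior histogram peaks -> 2
+#     d = np.log(np.array(alt) / (np.array(n) - np.array(alt))).reshape(-1, 1)
+#     k_gmm = _find_k_gmm(d, criterion="BIC")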
+
+
+def init_pars(
+        alt_allele_counts=None,
+        n_counts=None,
+        k=None,
+        method=None,
+        do_plot=False):
+    """Initialise mixture parameters either by an even partition of (0, 1)
+    or from the data via a Gaussian mixture on log(alt) - log(ref)."""
+    if method == "partition":
+        if k is None:
+            raise ValueError("When using method='partition', the number of "
+                             "components k must be specified.")
+        coefs = [1.0 / k] * k
+        # Logits of k evenly spaced probabilities in (0, 1).
+        probs = [(i + 1) / (k + 1) for i in range(k)]
+        means = [float(np.log(p / (1.0 - p))) for p in probs]
+        return {"k": k, "means": means, "coefs": coefs, "variance": 1.0}
+
+    if alt_allele_counts is None or n_counts is None:
+        raise ValueError("alt_allele_counts and n_counts must be provided.")
+
+    alt_allele_counts = np.asarray(alt_allele_counts, dtype=float)
+    n_counts = np.asarray(n_counts, dtype=float)
+
+    if alt_allele_counts.shape != n_counts.shape:
+        raise ValueError("alt_allele_counts and n_counts must have the same length.")
+    if alt_allele_counts.size == 0:
+        raise ValueError("Input arrays must not be empty.")
+    if np.any(alt_allele_counts < 0) or np.any(n_counts <= 0):
+        raise ValueError("All alt_allele_counts must be >= 0 and all n_counts > 0.")
+    if np.any(alt_allele_counts > n_counts):
+        raise ValueError("Each alt_allele_count must be <= the corresponding n_count.")
+
+    ref = n_counts - alt_allele_counts
+    with np.errstate(divide="ignore"):
+        log_ref = np.log(ref)
+        log_alt = np.log(alt_allele_counts)
+
+    mask = np.isfinite(log_ref) & np.isfinite(log_alt)
+    log_ref_clean = log_ref[mask]
+    log_alt_clean = log_alt[mask]
+
+    if log_ref_clean.size == 0:
+        raise ValueError("No valid (finite) data points remain after log transform.")
+
+    d_reshaped = (log_alt_clean - log_ref_clean).reshape(-1, 1)
+
+    if k is None:
+        if method is None:
+            raise ValueError("If k is not specified, you must provide a method "
+                             "('bins', 'AIC', or 'BIC').")
+        if method == "bins":
+            k = _find_k_from_bins(alt_allele_counts, n_counts)
+        elif method.upper() in ("AIC", "BIC"):
+            k = _find_k_gmm(d_reshaped, criterion=method.upper())
+        else:
+            raise ValueError("method must be one of 'bins', 'AIC', 'BIC', or 'partition'.")
+
+    gmm = GaussianMixture(n_components=k, covariance_type="full", random_state=42)
+    gmm.fit(d_reshaped)
+    labels = gmm.predict(d_reshaped)
+    centers = gmm.means_.flatten()
+    weights = gmm.weights_.flatten()
+
+    order = np.argsort(centers)
+    sorted_centers = centers[order]
+    sorted_weights = weights[order]
+
+    # Pooled within-component variance: a weighted average of the cluster
+    # variances (weights enter linearly, not squared).
+    cluster_variances = np.array([gmm.covariances_[i][0][0] for i in range(k)])
+    overall_variance = float(np.sum(sorted_weights * cluster_variances[order]))
+
+    if do_plot:
+        _plot_clusters(log_ref_clean, log_alt_clean, labels, sorted_centers)
+
+    return {
+        "k": k,
+        "means": sorted_centers.tolist(),
+        "coefs": sorted_weights.tolist(),
+        "variance": overall_variance,
+    }


 def blnm(x_counts: npt.NDArray,
@@ -314,3 +451,159 @@ def blnm(x_counts: npt.NDArray,

     return out

+
+# ---------------------------------- Updates ----------------------------------
+def log_bln_component(x, n, mu_k, sigma, *, n_mc=1000, rng=None):
+    """Monte Carlo estimate of log (1/n_mc) * sum_m Binom(x | n, sigmoid(s_m)),
+    with s_m ~ N(mu_k, sigma^2)."""
+    rng = np.random.default_rng(rng)
+    s = rng.normal(mu_k, sigma, size=n_mc)
+    p = expit(s)
+    return logsumexp(binom.logpmf(x, n, p)) - np.log(n_mc)
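+
+# Illustrative sanity check for log_bln_component (assumed tolerance, not a
+# guarantee): two independent Monte Carlo estimates should roughly agree, with
+# noise shrinking as 1/sqrt(n_mc).
+#
+#     lp1 = log_bln_component(7, 20, mu_k=0.0, sigma=0.5, n_mc=5000, rng=1)
+#     lp2 = log_bln_component(7, 20, mu_k=0.0, sigma=0.5, n_mc=5000, rng=2)
+#     assert abs(lp1 - lp2) < 0.05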
+
+
+class BLNMUniformStableFixedPi:
+    """K-component binomial logit-normal mixture contaminated by a uniform
+    distribution over {0, ..., n}, with mixing proportion pi_bln."""
+
+    def __init__(self, mu, sigma, weights, pi_bln=0.96):
+        self.mu = np.asarray(mu, dtype=float)
+        self.sigma = float(sigma)
+        self.weights = np.asarray(weights, dtype=float)
+        self.K = self.mu.size
+        self.pi_bln = float(pi_bln)
+        if not np.isclose(self.weights.sum(), 1.0):
+            raise ValueError("BLN weights must sum to 1")
+
+    def _log_mixture_pmf(self, x, n, *, n_mc=1000, rng=None):
+        log_uniform = np.log1p(-self.pi_bln) - np.log(n + 1)
+        log_components = np.array([
+            np.log(self.weights[k])
+            + log_bln_component(x, n, self.mu[k], self.sigma, n_mc=n_mc, rng=rng)
+            for k in range(self.K)
+        ])
+        log_bln = np.log(self.pi_bln) + logsumexp(log_components)
+        return logsumexp([log_uniform, log_bln])
+
+    def sample(self, n, *, size=1, rng=None):
+        rng = np.random.default_rng(rng)
+        samples = np.empty(size, dtype=int)
+        for i in range(size):
+            if rng.random() < self.pi_bln:
+                k = rng.choice(self.K, p=self.weights)
+                s = rng.normal(self.mu[k], self.sigma)
+                samples[i] = rng.binomial(n, expit(s))
+            else:
+                samples[i] = rng.integers(0, n + 1)
+        return samples
+
+
+def fit_blnm_uniform_stable(X, N, K, *,
+                            alpha=9, beta=1,
+                            n_init=5, max_iter=300,
+                            tol=1e-6, n_mc_samples=1000,
+                            rng=None, verbose=False):
+    """EM for a BLN mixture plus uniform contamination, computed in log space
+    for numerical stability."""
+    rng = np.random.default_rng(rng)
+    X, N = np.asarray(X), np.asarray(N)
+    n_obs = X.size
+
+    best_ll, best_params = -np.inf, None
+
+    for run in range(n_init):
+        if verbose:
+            print(f"Initialisation {run + 1}/{n_init}")
+
+        mu = rng.normal(0, 1, K)
+        sigma = abs(rng.normal(1, 0.3))
+        weights = rng.dirichlet(np.ones(K))
+        pi_bln = alpha / (alpha + beta)  # prior mean of Beta(alpha, beta)
+
+        prev_ll = -np.inf
+        for it in range(max_iter):
+            # E-step: log responsibilities of the K BLN components; the uniform
+            # responsibility is implicitly 1 - resp_bln.sum(axis=1).
+            log_resp_bln = np.empty((n_obs, K))
+            for i in range(n_obs):
+                log_u = np.log1p(-pi_bln) - np.log(N[i] + 1)
+                log_bln_k = np.array([
+                    np.log(weights[k])
+                    + log_bln_component(X[i], N[i], mu[k], sigma,
+                                        n_mc=n_mc_samples, rng=rng)
+                    for k in range(K)
+                ])
+                log_bln = np.log(pi_bln) + logsumexp(log_bln_k)
+                log_den = logsumexp([log_u, log_bln])
+                log_resp_bln[i] = log_bln_k + np.log(pi_bln) - log_den
+
+            resp_bln = np.exp(log_resp_bln)
+
+            # M-step: mixture weights and a MAP update of pi_bln under a
+            # Beta(alpha, beta) prior.
+            weights = resp_bln.sum(0)
+            weights /= weights.sum()
+            pi_bln = (alpha - 1 + resp_bln.sum()) / (alpha + beta - 2 + n_obs)
+
+            mu_new = np.empty_like(mu)
+            sigma_sq = 0.0
+            for k in range(K):
+                # Self-normalised importance sampling of E[s | x] and E[s^2 | x].
+                s_k = rng.normal(mu[k], sigma, size=n_mc_samples)
+                p_k = expit(s_k)
+                log_bin = binom.logpmf(X[:, None], N[:, None], p_k)
+
+                max_log = log_bin.max(axis=1, keepdims=True)
+                w_lin = np.exp(log_bin - max_log)
+                denom = w_lin.sum(axis=1) + 1e-300
+                E_s = (w_lin @ s_k) / denom
+                E_s2 = (w_lin @ (s_k ** 2)) / denom
+
+                r_k = resp_bln[:, k]
+                mu_new[k] = (r_k * E_s).sum() / r_k.sum()
+                sigma_sq += (r_k * (E_s2 - 2 * mu_new[k] * E_s + mu_new[k] ** 2)).sum()
+
+            mu = mu_new
+            sigma = np.sqrt(sigma_sq / resp_bln.sum())
+
+            # Log-likelihood under the updated parameters.
+            model_temp = BLNMUniformStableFixedPi(mu, sigma, weights, pi_bln)
+            ll = 0.0
+            for i in range(n_obs):
+                ll += model_temp._log_mixture_pmf(X[i], N[i],
+                                                  n_mc=n_mc_samples, rng=rng)
+
+            if verbose and it % 10 == 0:
+                print(f"  iter {it:3d} | log-lik = {ll: .3f}")
+
+            if np.isnan(ll):
+                if verbose:
+                    print("  NaN encountered, restarting")
+                ll = -np.inf
+                break
+
+            if ll - prev_ll < tol:
+                break
+            prev_ll = ll
+
+        if ll > best_ll:
+            best_ll = ll
+            best_params = (mu.copy(), sigma, weights.copy(), pi_bln)
+
+    if best_params is None:
+        raise RuntimeError("All EM restarts diverged (NaN log-likelihood).")
+
+    mu, sigma, weights, pi_bln = best_params
+    model = BLNMUniformStableFixedPi(mu, sigma, weights, pi_bln)
+
+    return {
+        "model": model,
+        "bln_proportion": pi_bln,
+        "uniform_proportion": 1 - pi_bln,
+        "likelihood": best_ll,
+        "parameters": {
+            "means": mu,
+            "sigma": sigma,
+            "weights": weights,
+            "pi_bln": pi_bln,
+        },
+    }
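+
+# Minimal end-to-end sketch (hypothetical sizes and seeds, not project data):
+# simulate from the generative model, refit, and inspect the recovered
+# parameters.
+#
+#     truth = BLNMUniformStableFixedPi(mu=[-1.0, 1.0], sigma=0.4,
+#                                      weights=[0.5, 0.5], pi_bln=0.95)
+#     N = np.full(200, 30)
+#     X = np.array([truth.sample(n, size=1, rng=i)[0] for i, n in enumerate(N)])
+#     fit = fit_blnm_uniform_stable(X, N, K=2, n_init=2, max_iter=50, rng=0)
+#     print(fit["parameters"]["means"], fit["bln_proportion"])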
+
+
+def calculate_aic_bic(log_likelihood, n_params, n_obs):
+    """Return (AIC, BIC) for a model with the given log-likelihood."""
+    aic = -2 * log_likelihood + 2 * n_params
+    bic = -2 * log_likelihood + n_params * np.log(n_obs)
+    return aic, bic
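+
+# For fit_blnm_uniform_stable, one reasonable free-parameter count (an
+# assumption of this sketch, not enforced anywhere in the code) is
+# K means + 1 shared sigma + (K - 1) free weights + 1 pi_bln = 2K + 1:
+#
+#     n_params = 2 * K + 1
+#     aic, bic = calculate_aic_bic(fit["likelihood"], n_params, n_obs=X.size)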