blnm/fit.py: 299 changes (296 additions & 3 deletions)
@@ -4,9 +4,18 @@
import numpy.typing as npt
import typing

import math

from scipy.stats import binom
from scipy import special as scisp
from scipy.special import logsumexp, expit

import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture

import numpy as np

from . import utils
from . import dist

@@ -116,8 +125,136 @@ def _m_step(zeroth_order,
return coefs, means, variance


def _find_k_from_bins(alt_allele_counts, n_counts):
    """Estimate the number of mixture components by counting peaks in a
    10-bin histogram of alt-allele ratios."""
    ratios = [a / n for a, n in zip(alt_allele_counts, n_counts)]
    ratios.sort()
    bin_counts = [0] * 10
    for r in ratios:
        # A ratio of exactly 1.0 belongs in the last bin.
        idx = 9 if r == 1.0 else int(r * 10)
        bin_counts[idx] += 1
    # Count interior local maxima as peaks.
    peak_indices = []
    for i in range(1, 9):
        if bin_counts[i] > bin_counts[i - 1] and bin_counts[i] > bin_counts[i + 1]:
            peak_indices.append(i)
    # Fall back to the edge bins, then to the single global maximum.
    if not peak_indices:
        if bin_counts[0] > bin_counts[1]:
            peak_indices.append(0)
        if bin_counts[9] > bin_counts[8]:
            peak_indices.append(9)
    if not peak_indices:
        peak_indices = [bin_counts.index(max(bin_counts))]
    return len(peak_indices)
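
# Toy check of the binning heuristic (hypothetical counts): ratios
# 0.1, 0.1, 0.5, 0.5, 0.5, 0.9 fill bins 1, 5, and 9; bins 1 and 5 are
# interior peaks, so the edge bins are never consulted and k = 2 is returned.
#
#   _find_k_from_bins([1, 1, 5, 5, 5, 9], [10] * 6)  # -> 2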

def _find_k_gmm(d_reshaped, criterion="AIC"):
    """Pick the number of components in 1..4 that minimises the chosen
    information criterion ("AIC" or "BIC") for a one-dimensional GMM."""
    best_k = None
    best_score = np.inf
    for candidate_k in range(1, 5):
        gmm = GaussianMixture(n_components=candidate_k, covariance_type="full", random_state=42)
        gmm.fit(d_reshaped)
        score = gmm.aic(d_reshaped) if criterion.upper() == "AIC" else gmm.bic(d_reshaped)
        if score < best_score:
            best_score = score
            best_k = candidate_k
    return best_k
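
# Sketch of the intended usage (the data here is made up): the helper scores
# k = 1..4 and returns the candidate with the lowest criterion value.
#
#   d = np.array([-1.0, -1.1, -0.9, 1.0, 1.1, 0.9]).reshape(-1, 1)
#   _find_k_gmm(d, criterion="BIC")  # expected to favour k = 2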

def _plot_clusters(log_ref_clean, log_alt_clean, labels, sorted_centers):
    """Scatter log(ref) vs log(alt) coloured by cluster label, with one
    unit-slope line y = x + c drawn per cluster centre c."""
    plt.figure(figsize=(8, 6))
    plt.scatter(log_ref_clean, log_alt_clean, c=labels, cmap='viridis', alpha=0.5)
    plt.xlabel("log(ref)")
    plt.ylabel("log(alt)")
    plt.title("Scatter Plot of log(ref) vs. log(alt) with Parallel Lines")
    x_vals = np.linspace(log_ref_clean.min(), log_ref_clean.max(), 200)
    for c in sorted_centers:
        y_vals = x_vals + c
        plt.plot(x_vals, y_vals, '--', label=f'y = x + {c:.2f}')
    plt.legend()
    plt.show()

def init_pars(
        alt_allele_counts=None,
        n_counts=None,
        k=None,
        method=None,
        do_plot=False
):
    """Initialise BLN mixture parameters.

    With method="partition", return k evenly spaced logit-scale means and
    uniform coefficients without looking at the data. With method="bins",
    "AIC", or "BIC", estimate k from the data when it is not given, then
    fit a one-dimensional GMM to log(alt) - log(ref) to obtain the means,
    coefficients, and variance.
    """
    if method == "partition":
        if k is None:
            raise ValueError("When using method='partition', the number of components k must be specified.")
        coefs = [1.0 / k] * k
        # Evenly spaced probabilities (i + 1) / (k + 1) mapped through the logit.
        means = [math.log(p / (1 - p)) for p in [(i + 1) / (k + 1) for i in range(k)]]
        return {
            "k": k,
            "means": means,
            "coefs": coefs,
            "variance": 1.0
        }

if alt_allele_counts is None or n_counts is None:
raise ValueError("alt_allele_counts and n_counts must be provided.")

alt_allele_counts = np.asarray(alt_allele_counts, dtype=float)
n_counts = np.asarray(n_counts, dtype=float)

if len(alt_allele_counts) != len(n_counts):
raise ValueError("alt_allele_counts and n_counts must have the same length.")
if len(alt_allele_counts) == 0:
raise ValueError("Input arrays must not be empty.")

    if np.any(alt_allele_counts < 0) or np.any(n_counts <= 0):
        raise ValueError("All alt_allele_counts must be >= 0 and all n_counts > 0.")
    if np.any(alt_allele_counts > n_counts):
        raise ValueError("Each alt_allele_count must be <= the corresponding n_count.")

    ref = (n_counts - alt_allele_counts).astype(float)
    alt = alt_allele_counts.astype(float)
    # Zero counts give log(0) = -inf here; such points are masked out below.
    with np.errstate(divide="ignore"):
        log_ref = np.log(ref)
        log_alt = np.log(alt)

mask = np.isfinite(log_ref) & np.isfinite(log_alt)
log_ref_clean = log_ref[mask]
log_alt_clean = log_alt[mask]

if len(log_ref_clean) == 0:
raise ValueError("No valid (finite) data points remain after log transform.")

d = log_alt_clean - log_ref_clean
d_reshaped = d.reshape(-1, 1)

if k is None:
if method is None:
raise ValueError("If k is not specified, you must provide a method ('bins', 'AIC', or 'BIC').")
if method == "bins":
k = _find_k_from_bins(alt_allele_counts, n_counts)
elif method.upper() in ["AIC", "BIC"]:
k = _find_k_gmm(d_reshaped, criterion=method.upper())
else:
raise ValueError("method must be either 'bins', 'AIC', 'BIC', or 'partition'.")

gmm = GaussianMixture(n_components=k, covariance_type="full", random_state=42)
gmm.fit(d_reshaped)
labels = gmm.predict(d_reshaped)
centers = gmm.means_.flatten()
weights = gmm.weights_.flatten()

order = np.argsort(centers)
sorted_centers = centers[order]
sorted_weights = weights[order]
coefs = sorted_weights.tolist()

    # Combine per-cluster variances into a single overall variance,
    # weighting each cluster's variance by its squared mixture weight.
    cluster_variances = np.array([gmm.covariances_[i][0][0] for i in range(k)])
    sorted_var = cluster_variances[order]
    overall_variance = float(np.sum((sorted_weights ** 2) * sorted_var))

if do_plot:
_plot_clusters(log_ref_clean, log_alt_clean, labels, sorted_centers)

return {
"k": k,
"means": sorted_centers.tolist(),
"coefs": coefs,
"variance": overall_variance
}
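
# Illustrative usage of init_pars (alts and totals are hypothetical count
# arrays from upstream variant calls):
#
#   pars = init_pars(alt_allele_counts=alts, n_counts=totals, method="BIC")
#   pars["k"], pars["means"], pars["coefs"], pars["variance"]
#
# or, without any data, an even logit-space partition:
#
#   init_pars(k=3, method="partition")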


def blnm(x_counts: npt.NDArray,
@@ -314,3 +451,159 @@ def blnm(x_counts: npt.NDArray,
return out


##################################### Updates #####################################
def log_bln_component(x, n, mu_k, sigma, *, n_mc=1000, rng=None):
    """Monte Carlo estimate of log E[Binom(x | n, expit(S))] for
    S ~ N(mu_k, sigma^2), i.e. log (1/n_mc) * sum_j Binom(x | n, expit(s_j)),
    evaluated in log space with logsumexp."""
    rng = np.random.default_rng(rng)
    s = rng.normal(mu_k, sigma, size=n_mc)
    p = expit(s)
    return logsumexp(binom.logpmf(x, n, p)) - np.log(n_mc)
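
# Illustrative call with made-up parameters; passing an integer rng seed
# makes the Monte Carlo estimate reproducible:
#
#   log_bln_component(3, 10, mu_k=0.0, sigma=1.0, n_mc=500, rng=0)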

class BLNMUniformStableFixedPi:
    """K-component binomial-logit-normal (BLN) mixture plus a uniform
    background component with fixed mixture weight 1 - pi_bln."""

    def __init__(self, mu, sigma, weights, pi_bln=0.96):
self.mu = np.asarray(mu, dtype=float)
self.sigma = float(sigma)
self.weights = np.asarray(weights, dtype=float)
self.K = self.mu.size
self.pi_bln = float(pi_bln)
if not np.isclose(self.weights.sum(), 1.0):
raise ValueError("BLN weights must sum to 1")

def _log_mixture_pmf(self, x, n, *, n_mc=1000, rng=None):
log_uniform = np.log1p(-self.pi_bln) - np.log(n + 1)
log_components = np.array([
np.log(self.weights[k]) +
log_bln_component(x, n, self.mu[k], self.sigma,
n_mc=n_mc, rng=rng)
for k in range(self.K)
])
log_bln = np.log(self.pi_bln) + logsumexp(log_components)
return logsumexp([log_uniform, log_bln])

def sample(self, n, *, size=1, rng=None):
rng = np.random.default_rng(rng)
samples = np.empty(size, dtype=int)

        for i in range(size):
            if rng.random() < self.pi_bln:
                # BLN draw: pick a component, sample its logit, then the count.
                k = rng.choice(self.K, p=self.weights)
                s = rng.normal(self.mu[k], self.sigma)
                samples[i] = rng.binomial(n, expit(s))
            else:
                # Uniform background over the n + 1 possible counts.
                samples[i] = rng.integers(0, n + 1)
return samples
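
# Illustrative sampling sketch (all parameter values are made up):
#
#   model = BLNMUniformStableFixedPi(mu=[-1.0, 1.0], sigma=0.5,
#                                    weights=[0.5, 0.5], pi_bln=0.96)
#   counts = model.sample(30, size=1000, rng=0)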

def fit_blnm_uniform_stable(X, N, K, *,
                            alpha=9, beta=1,
                            n_init=5, max_iter=300,
                            tol=1e-6, n_mc_samples=1000,
                            rng=None, verbose=False):
    """Numerically stable EM for a K-component BLN mixture plus a uniform
    background. alpha and beta parameterise a Beta prior on pi_bln, which
    enters the M-step as a MAP update."""
rng = np.random.default_rng(rng)
X, N = np.asarray(X), np.asarray(N)
n_obs = X.size

best_ll, best_params = -np.inf, None

for run in range(n_init):
if verbose:
print(f"Initialisation {run+1}/{n_init}")

        # Random restart: draw fresh starting parameters for this run.
        mu = rng.normal(0, 1, K)
        sigma = abs(rng.normal(1, 0.3))
        weights = rng.dirichlet(np.ones(K))
        pi_bln = alpha / (alpha + beta)

prev_ll = -np.inf
for it in range(max_iter):

            # E-step: per-observation responsibilities, computed in log space.
            log_resp_uni = np.empty(n_obs)
            log_resp_bln = np.empty((n_obs, K))

for i in range(n_obs):
log_u = np.log1p(-pi_bln) - np.log(N[i] + 1)
log_bln_k = np.array([
np.log(weights[k]) +
log_bln_component(X[i], N[i], mu[k], sigma,
n_mc=n_mc_samples, rng=rng)
for k in range(K)
])
log_bln = np.log(pi_bln) + logsumexp(log_bln_k)
log_den = logsumexp([log_u, log_bln])

log_resp_uni[i] = log_u - log_den
log_resp_bln[i] = log_bln_k + np.log(pi_bln) - log_den

resp_uni = np.exp(log_resp_uni)
resp_bln = np.exp(log_resp_bln)

            # M-step: mixture weights, then the MAP update of pi_bln under
            # a Beta(alpha, beta) prior.
            weights = resp_bln.sum(0)
            weights /= weights.sum()

            pi_bln = (alpha - 1 + resp_bln.sum()) / (alpha + beta - 2 + n_obs)

mu_new = np.empty_like(mu)
sigma_sq = 0.0

            for k in range(K):
                # Self-normalised importance sampling: draw s from the current
                # component, weight each draw by its binomial likelihood, and
                # estimate E[s | x] and E[s^2 | x] for every observation.
                s_k = rng.normal(mu[k], sigma, size=n_mc_samples)
                p_k = expit(s_k)

                log_bin = binom.logpmf(X[:, None], N[:, None], p_k)

                # Stabilise the exponentiation by subtracting the row maximum.
                max_log = log_bin.max(axis=1, keepdims=True)
                w_lin = np.exp(log_bin - max_log)
                denom = w_lin.sum(axis=1) + 1e-300
                E_s = (w_lin @ s_k) / denom
                E_s2 = (w_lin @ (s_k**2)) / denom

                r_k = resp_bln[:, k]
                mu_new[k] = (r_k * E_s).sum() / r_k.sum()
                sigma_sq += (r_k * (E_s2 - 2*mu_new[k]*E_s + mu_new[k]**2)).sum()

mu = mu_new
sigma = np.sqrt(sigma_sq / resp_bln.sum())

            # Observed-data log-likelihood under the current parameters.
            model_temp = BLNMUniformStableFixedPi(mu, sigma, weights, pi_bln)
            ll = 0.0
            for i in range(n_obs):
                ll += model_temp._log_mixture_pmf(X[i], N[i], n_mc=n_mc_samples, rng=rng)

if verbose and it % 10 == 0:
print(f" iter {it:3d} | log-lik = {ll: .3f}")

if np.isnan(ll):
if verbose:
print(" ‼ NaN encountered, re-start")
ll = -np.inf
break

if ll - prev_ll < tol:
break
prev_ll = ll

        if ll > best_ll:
            best_ll = ll
            best_params = (mu.copy(), sigma, weights.copy(), pi_bln)

    if best_params is None:
        raise RuntimeError("All initialisations diverged (NaN log-likelihood); no parameters to return.")

    mu, sigma, weights, pi_bln = best_params
    model = BLNMUniformStableFixedPi(mu, sigma, weights, pi_bln)

return {
"model": model,
"bln_proportion": pi_bln,
"uniform_proportion": 1 - pi_bln,
"likelihood": best_ll,
"parameters": {
"means": mu,
"sigma": sigma,
"weights": weights,
"pi_bln": pi_bln
}
}
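
# Illustrative fit (X and N are hypothetical arrays of alt and total counts):
#
#   result = fit_blnm_uniform_stable(X, N, K=2, n_init=3, rng=0, verbose=True)
#   result["parameters"]["means"], result["bln_proportion"]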

def calculate_aic_bic(log_likelihood, n_params, n_obs):
    """Return (AIC, BIC) for a model with the given log-likelihood, number
    of free parameters, and number of observations."""
aic = -2 * log_likelihood + 2 * n_params
bic = -2 * log_likelihood + n_params * np.log(n_obs)
return aic, bic
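
# Example model comparison with the fit above. One reasonable count of free
# parameters is K means + (K - 1) free weights + sigma + pi_bln = 2K + 1:
#
#   n_params = 2 * result["model"].K + 1
#   aic, bic = calculate_aic_bic(result["likelihood"], n_params, len(X))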