Skip to content

Commit 1988607

Browse files
committed
Adjust fitting
1 parent ab1c69d commit 1988607

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

pf2rnaseq/factorization.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ def objective(x):
264264
mse = np.sum((A - reconstruction) ** 2)
265265

266266
# Regularization: L1 penalty on both W and H
267-
l1_W = alpha * np.sum(np.abs(W))
267+
# Exclude diagonal of W from L1 penalty
268+
l1_W = alpha * np.sum(np.abs(W)) - alpha * np.diag(np.abs(W)).sum()
268269
l1_H = alpha * np.sum(np.abs(H))
269270

270271
total_loss = mse + l1_W + l1_H
@@ -273,7 +274,7 @@ def objective(x):
273274
if total_loss < best_loss[0]:
274275
best_loss[0] = total_loss
275276

276-
if iteration_counter[0] % 100 == 0:
277+
if iteration_counter[0] % 10 == 0:
277278
print(
278279
f" Iter {iteration_counter[0]}: Loss={total_loss:.4f} "
279280
f"(MSE={mse:.4f}, L1_W={l1_W:.4f}, L1_H={l1_H:.4f})"
@@ -287,24 +288,20 @@ def gradient(x):
287288

288289
# ===== Gradient w.r.t. W =====
289290
# 1. Reconstruction term: ∂/∂W [||A - WH||²] = 2(error @ H^T), L1 penalty: ∂/∂W [α||W||₁] = α * sign(W)
290-
grad_W = 2 * ((W @ H - A) @ H.T) + alpha
291+
grad_W = 2 * ((W @ H - A) @ H.T) + alpha * np.sign(W) - np.diag(alpha * np.sign(np.diag(W)))
291292

292293
# ===== Gradient w.r.t. H =====
293294
# 1. Reconstruction term: ∂/∂H [||A - WH||²] = 2(W^T @ error), L1 penalty: ∂/∂H [α||H||₁] = α * sign(H)
294-
grad_H = 2 * (W.T @ (W @ H - A)) + alpha
295+
grad_H = 2 * (W.T @ (W @ H - A)) + alpha * np.sign(H)
295296

296297
return np.concatenate([grad_W.ravel(), grad_H.ravel()])
297298

298-
# Enforce non-negativity
299-
bounds = [(0, None)] * len(x0)
300-
301299
print("\nStarting optimization...")
302300

303301
result = minimize(
304302
fun=objective,
305303
x0=x0,
306304
method="L-BFGS-B",
307-
bounds=bounds,
308305
jac=gradient,
309306
options={"maxiter": max_iter, "disp": True},
310307
)

0 commit comments

Comments
 (0)