Skip to content

Commit 9fd1a6e

Browse files
committed
Fix the pairwise gradient terms in the negative log-likelihood gradient, and add numerical-stability safeguards (overflow-safe softplus, expm1-based inverse softplus, strict lower triangle in tril).
1 parent 2cf0c67 commit 9fd1a6e

File tree

1 file changed

+22
-24
lines changed

1 file changed

+22
-24
lines changed

graphconstructor/operators/enhanced_configuration_model.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ def _log(x):
2424

2525
@njit(cache=True)
2626
def _softplus(x):
27-
return np.log1p(np.exp(x))
27+
return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0) # avoids overflows for large x
2828

2929
@njit(cache=True)
3030
def _softplus_inv(x):
31-
return np.log(np.exp(x) - 1.0)
31+
return np.log(np.expm1(x)) # avoids unstability near 0
3232

3333
R_to_zero_to_inf = [
3434
(_exp, _log),
@@ -99,35 +99,33 @@ def _neg_log_likelihood_grad(x, y, k, s):
9999
parameters (x, y). Derived by differentiating formula (13).
100100
"""
101101
N = len(x)
102+
grad_x = np.empty(N, dtype=np.float64)
103+
grad_y = np.empty(N, dtype=np.float64)
102104

103-
grad_x = np.zeros(N)
104-
grad_y = np.zeros(N)
105-
106-
# Terms from k·log(x) and s·log(y)
105+
# Base terms: -k_i/x_i and -s_i/y_i
107106
for i in range(N):
108107
grad_x[i] = -k[i] / x[i]
109108
grad_y[i] = -s[i] / y[i]
110109

111-
# Terms from the double sum over unique pairs (i > j)
110+
# Pair contributions for i>j
112111
for i in range(N):
113112
for j in range(i):
114-
xx = x[i] * x[j]
115113
yy = y[i] * y[j]
116-
denom = 1.0 - yy + xx * yy # = 1 - yy(1 - xx)
117-
118-
# d/dx_i of log((1-yy)/denom) = -xx*yy / (denom * x[i])
119-
# (same structure for x_j by symmetry)
120-
factor_x = xx * yy / (denom * (1.0 - yy + xx * yy))
121-
grad_x[i] += factor_x / x[i]
122-
grad_x[j] += factor_x / x[j]
123-
124-
# d/dy_i of log((1-yy)/denom)
125-
# numerator deriv: -y[j] / (1-yy) — but sign flips with -log
126-
# denom deriv: (-y[j] + xx*y[j]) / denom
127-
# combined (negated for NLL):
128-
factor_y = (y[j] / (1.0 - yy)) - ((-1.0 + xx) * y[j] / denom)
129-
grad_y[i] += factor_y
130-
grad_y[j] += factor_y * y[i] / y[j] # symmetric
114+
xx = x[i] * x[j]
115+
D = 1.0 - yy + xx * yy
116+
117+
# x-grad from +log(D)
118+
# d/dx_i log(D) = (yy * x_j)/D
119+
grad_x[i] += (yy * x[j]) / D
120+
grad_x[j] += (yy * x[i]) / D
121+
122+
# y-grad from -log(1-yy) + log(D)
123+
inv1m = 1.0 / (1.0 - yy) # for -log(1-yy) term
124+
invD = 1.0 / D # for log(D) term
125+
common = (xx - 1.0) * invD # multiplier for y-derivative of log(D)
126+
127+
grad_y[i] += y[j] * inv1m + y[j] * common
128+
grad_y[j] += y[i] * inv1m + y[i] * common
131129

132130
return grad_x, grad_y
133131

@@ -287,7 +285,7 @@ def _undirected(self, G: Graph) -> Graph:
287285
y_opt = y_transform(res.x[num_nodes:])
288286

289287
# Work only on the lower triangle to avoid double computation
290-
W_lower = sp.tril(W).tocoo()
288+
W_lower = sp.tril(W, k=-1).tocoo()
291289
row = W_lower.row.astype(np.int64)
292290
col = W_lower.col.astype(np.int64)
293291
weights = W_lower.data

0 commit comments

Comments
 (0)