@@ -24,11 +24,11 @@ def _log(x):
 
 @njit(cache=True)
 def _softplus(x):
-    return np.log1p(np.exp(x))
+    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)  # avoids overflow for large x
 
 @njit(cache=True)
 def _softplus_inv(x):
-    return np.log(np.exp(x) - 1.0)
+    return np.log(np.expm1(x))  # avoids instability near 0
 
 R_to_zero_to_inf = [
     (_exp, _log),
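Both one-liners are standard numerical-stability rewrites. A quick sanity check (a plain-NumPy sketch without `@njit`, not part of the patch) shows where the naive forms break:

```python
import numpy as np

# Stable softplus: log1p(exp(x)) == max(x, 0) + log1p(exp(-|x|))
naive_sp = lambda x: np.log1p(np.exp(x))
stable_sp = lambda x: np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)

print(naive_sp(1000.0))   # inf: np.exp(1000.0) overflows
print(stable_sp(1000.0))  # 1000.0

# Stable inverse near 0: exp(x) - 1 cancels catastrophically for tiny x
x = 1e-18
print(np.log(np.exp(x) - 1.0))  # -inf: exp(x) rounds to exactly 1.0
print(np.log(np.expm1(x)))      # ≈ -41.45, i.e. log(1e-18)
```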
@@ -99,35 +99,33 @@ def _neg_log_likelihood_grad(x, y, k, s):
     parameters (x, y). Derived by differentiating formula (13).
     """
     N = len(x)
+    grad_x = np.empty(N, dtype=np.float64)
+    grad_y = np.empty(N, dtype=np.float64)
 
-    grad_x = np.zeros(N)
-    grad_y = np.zeros(N)
-
-    # Terms from k·log(x) and s·log(y)
+    # Base terms: -k_i/x_i and -s_i/y_i
     for i in range(N):
         grad_x[i] = -k[i] / x[i]
         grad_y[i] = -s[i] / y[i]
 
-    # Terms from the double sum over unique pairs (i > j)
+    # Pair contributions for i > j
     for i in range(N):
         for j in range(i):
-            xx = x[i] * x[j]
             yy = y[i] * y[j]
-            denom = 1.0 - yy + xx * yy  # = 1 - yy(1 - xx)
-
-            # d/dx_i of log((1-yy)/denom) = -xx*yy / (denom * x[i])
-            # (same structure for x_j by symmetry)
-            factor_x = xx * yy / (denom * (1.0 - yy + xx * yy))
-            grad_x[i] += factor_x / x[i]
-            grad_x[j] += factor_x / x[j]
-
-            # d/dy_i of log((1-yy)/denom)
-            # numerator deriv: -y[j] / (1 - yy); sign flips with -log
-            # denom deriv: (-y[j] + xx*y[j]) / denom
-            # combined (negated for NLL):
-            factor_y = (y[j] / (1.0 - yy)) - ((-1.0 + xx) * y[j] / denom)
-            grad_y[i] += factor_y
-            grad_y[j] += factor_y * y[i] / y[j]  # symmetric
+            xx = x[i] * x[j]
+            D = 1.0 - yy + xx * yy
+
+            # x-grad from +log(D):
+            # d/dx_i log(D) = (yy * x_j) / D
+            grad_x[i] += (yy * x[j]) / D
+            grad_x[j] += (yy * x[i]) / D
+
+            # y-grad from -log(1 - yy) + log(D)
+            inv1m = 1.0 / (1.0 - yy)    # for the -log(1 - yy) term
+            invD = 1.0 / D              # for the log(D) term
+            common = (xx - 1.0) * invD  # multiplier for the y-derivative of log(D)
+
+            grad_y[i] += y[j] * inv1m + y[j] * common
+            grad_y[j] += y[i] * inv1m + y[i] * common
 
     return grad_x, grad_y
 
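The rewritten pair terms imply a per-pair objective contribution of log(D) - log(1 - y_i y_j) with D = 1 - y_i y_j + x_i x_j y_i y_j. A central-difference check confirms the analytic gradient; note that `_nll` below is a hypothetical reconstruction inferred from those gradient terms, since formula (13) itself is not shown in this hunk:

```python
import numpy as np

def _nll(x, y, k, s):
    # Hypothetical objective inferred from the gradient above (formula (13)
    # is not in this diff): -sum k_i log x_i - sum s_i log y_i + pair terms.
    f = -np.sum(k * np.log(x)) - np.sum(s * np.log(y))
    for i in range(len(x)):
        for j in range(i):
            yy = y[i] * y[j]
            xx = x[i] * x[j]
            f += np.log(1.0 - yy + xx * yy) - np.log(1.0 - yy)
    return f

# _neg_log_likelihood_grad is the patched function above (@njit optional here).
rng = np.random.default_rng(0)
N = 5
x = rng.uniform(0.5, 2.0, N)
y = rng.uniform(0.1, 0.9, N)  # keeps y_i * y_j < 1 so both logs are defined
k = rng.integers(1, 5, N).astype(np.float64)
s = rng.integers(1, 5, N).astype(np.float64)

gx, gy = _neg_log_likelihood_grad(x, y, k, s)
eps = 1e-6
for i in range(N):
    e = np.zeros(N); e[i] = eps
    assert np.isclose(gx[i], (_nll(x + e, y, k, s) - _nll(x - e, y, k, s)) / (2 * eps))
    assert np.isclose(gy[i], (_nll(x, y + e, k, s) - _nll(x, y - e, k, s)) / (2 * eps))
```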
@@ -287,7 +285,7 @@ def _undirected(self, G: Graph) -> Graph:
         y_opt = y_transform(res.x[num_nodes:])
 
         # Work only on the lower triangle to avoid double computation
-        W_lower = sp.tril(W).tocoo()
+        W_lower = sp.tril(W, k=-1).tocoo()
         row = W_lower.row.astype(np.int64)
         col = W_lower.col.astype(np.int64)
         weights = W_lower.data
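`sp.tril` keeps the main diagonal by default (`k=0`); `k=-1` restricts the result to the strict lower triangle, so diagonal entries (self-loops, if W has any) no longer enter the edge list. A minimal illustration:

```python
import numpy as np
import scipy.sparse as sp

W = sp.csr_matrix(np.array([[2.0, 1.0, 0.0],
                            [1.0, 0.0, 3.0],
                            [0.0, 3.0, 0.0]]))

print(sp.tril(W).toarray())        # default k=0: the diagonal entry W[0, 0] survives
print(sp.tril(W, k=-1).toarray())  # strict lower triangle: diagonal dropped
```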