@@ -215,13 +215,13 @@ def _check_inputs(Z, X, y, W=None, C=None, D=None, beta=None):
215215 return Z , X , y , W , C , D , beta
216216
217217
218- def _find_roots (f , a , b , tol , max_value , max_eval , n_points = 50 ):
218+ def _find_roots (f , a , b , tol , max_value , max_eval , n_points = 50 , max_depth = 5 ):
219219 """
220- Find the root of function ``f`` between ``a`` and ``b`` closest to ``b``.
220+ Find roots of function ``f`` between ``a`` and ``b``.
221221
222222 Assumes ``f(a) < 0`` and ``f(b) > 0``. Finds root by building a grid between ``a``
223- and ``b`` with ``n_points``, evaluating ``f`` at each point, and finding the last
224- point where ``f`` is negative . If ``b`` is infinite, uses a logarithmic grid between
223+ and ``b`` with ``n_points``, evaluating ``f`` at each point, and finding indices
224+ where ``f`` switches sign . If ``b`` is infinite, uses a logarithmic grid between
225225 ``a`` and ``a + sign(b - a) * max_value``. The function is then called recursively
226226 on the new interval until the size of the interval is less than ``tol`` or the
227227 maximum number of evaluations ``max_eval`` of ``f`` is reached.
@@ -230,40 +230,57 @@ def _find_roots(f, a, b, tol, max_value, max_eval, n_points=50):
230230 closest to ``b``. Note that this is also not strictly ensured by this function.
231231 """
232232 if np .abs (b - a ) < tol or max_eval < 0 :
233- return b # conservative
233+ return [b ] # conservative, resulting in a larger interval
234+
234235 if np .isinf (a ):
235- return a
236+ return [a ]
237+
238+ roots = []
239+
236240 sgn = np .sign (b - a )
237241 if np .isinf (b ):
238242 grid = np .ones (n_points ) * a
239- grid [1 :] += sgn * np .logspace (0 , np .log10 (max_value ), n_points - 1 )
243+ grid [1 :] += sgn * np .logspace (tol , np .log10 (max_value ), n_points - 1 )
240244 else :
241245 grid = np .linspace (a , b , n_points )
242246
243247 y = np .zeros (n_points )
244- y [- 1 ] = f (grid [- 1 ])
245- if y [- 1 ] < 0 :
246- return sgn * np .inf
247248
248249 y [0 ] = f (grid [0 ])
249250 if y [0 ] >= 0 :
250251 raise ValueError ("f(a) must be negative." )
251252
252- for i , x in enumerate (grid [: - 1 ]):
253- y [i ] = f (x )
253+ for i , x in enumerate (grid [1 : ]):
254+ y [i + 1 ] = f (x )
254255
255- last_positive = np .where (y < 0 )[0 ][- 1 ]
256+ if y [- 1 ] <= 0 :
257+ roots = [b ]
256258
257- # f(a_new) < 0 < f(b_new) -> repeat
258- return _find_roots (
259- f ,
260- grid [last_positive ],
261- grid [last_positive + 1 ],
262- tol = tol ,
263- n_points = n_points ,
264- max_value = None ,
265- max_eval = max_eval - n_points ,
266- )
259+ y [y == 0 ] = np .finfo (y .dtype ).eps
260+ where = np .where (np .sign (y [:- 1 ]) != np .sign (y [1 :]))[0 ]
261+
262+ # Conservative. Focus on change closest to b.
263+ if max_depth == 0 :
264+ where = where [- 1 :]
265+
266+ for idx , w in enumerate (where ):
267+ if idx % 2 == 0 :
268+ a , b = grid [w ], grid [w + 1 ]
269+ else :
270+ a , b = grid [w + 1 ], grid [w ]
271+
272+ roots += _find_roots (
273+ f ,
274+ a ,
275+ b ,
276+ tol = tol ,
277+ n_points = n_points ,
278+ max_value = max_value ,
279+ max_eval = max_eval - n_points ,
280+ max_depth = max_depth - len (where ) > 1 ,
281+ )
282+
283+ return roots
267284
268285
269286def _characteristic_roots (a , b , subset_by_index = None ):
0 commit comments