Skip to content

Commit 9840031

Browse files
authored
Have _find_roots return multiple roots -> allow inverse tests to return union of intervals. (#107)
* Have _find_roots return multiple roots -> allow inverse tests to return union of intervals. * Changelog.
1 parent 2409e54 commit 9840031

File tree

6 files changed

+93
-48
lines changed

6 files changed

+93
-48
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ Changelog
88

99
- New test :func:`~ivmodels.tests.j.j_test` of the overidentifying restrictions.
1010

11+
- The tests :func:`~ivmodels.tests.lagrange_multiplier.inverse_lagrange_multiplier_test`
12+
and
13+
:func:`~ivmodels.tests.conditional_likelihood_ratio.inverse_conditional_likelihood_ratio_test`
14+
now possibly return unions of intervals, instead of one conservative large interval.
15+
1116
**Bug fixes:**
1217

1318
- Fixed bug in :func:`~ivmodels.models.kclass.KClass.fit` when `C` is not `None` and

benchmarks/benchmarks/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class Tests:
2727
def setup(self, n, data):
2828

2929
if data == "guggenberger12 (k=10)":
30-
Z, X, y, C, W, beta = simulate_guggenberger12(
30+
Z, X, y, C, W, _, beta = simulate_guggenberger12(
3131
n=n, k=10, seed=0, return_beta=True
3232
)
3333
else:

ivmodels/tests/conditional_likelihood_ratio.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def f(x):
533533
)
534534
)
535535

536-
boundaries = []
536+
roots = []
537537
for left_upper, right_upper in cs_upper.boundaries:
538538
left_lower_, right_lower_ = None, None
539539
for left_lower, right_lower in cs_lower.boundaries:
@@ -550,24 +550,26 @@ def f(x):
550550
else:
551551
continue
552552

553-
boundaries.append(
554-
(
555-
_find_roots(
556-
f,
557-
left_lower_,
558-
left_upper,
559-
tol=tol,
560-
max_value=max_value,
561-
max_eval=max_eval,
562-
),
563-
_find_roots(
564-
f,
565-
right_lower_,
566-
right_upper,
567-
tol=tol,
568-
max_value=max_value,
569-
max_eval=max_eval,
570-
),
571-
)
553+
roots += _find_roots(
554+
f,
555+
left_lower_,
556+
left_upper,
557+
tol=tol,
558+
max_value=max_value,
559+
max_eval=max_eval,
560+
max_depth=5,
561+
)
562+
roots += _find_roots(
563+
f,
564+
right_lower_,
565+
right_upper,
566+
tol=tol,
567+
max_value=max_value,
568+
max_eval=max_eval,
572569
)
570+
571+
roots = sorted(roots)
572+
573+
assert len(roots) % 2 == 0
574+
boundaries = [(left, right) for left, right in zip(roots[::2], roots[1::2])]
573575
return ConfidenceSet(boundaries=boundaries)

ivmodels/tests/lagrange_multiplier.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -490,20 +490,24 @@ def inverse_lagrange_multiplier_test(
490490
else:
491491
liml = KClass(kappa="liml", fit_intercept=False).fit(W, y, Z=Z, C=D).coef_[-1]
492492

493-
left = _find_roots(
493+
roots = _find_roots(
494494
lambda x: lm.lm(x) - critical_value,
495495
a=liml,
496496
b=-np.inf,
497497
tol=tol,
498498
max_value=max_value,
499499
max_eval=max_eval,
500500
)
501-
right = _find_roots(
501+
roots += _find_roots(
502502
lambda x: lm.lm(x) - critical_value,
503503
a=liml,
504504
b=np.inf,
505505
tol=tol,
506506
max_value=max_value,
507507
max_eval=max_eval,
508508
)
509-
return ConfidenceSet(boundaries=[(left, right)])
509+
510+
roots = sorted(roots)
511+
assert len(roots) % 2 == 0
512+
boundaries = [(left, right) for left, right in zip(roots[::2], roots[1::2])]
513+
return ConfidenceSet(boundaries=boundaries)

ivmodels/utils.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,13 @@ def _check_inputs(Z, X, y, W=None, C=None, D=None, beta=None):
215215
return Z, X, y, W, C, D, beta
216216

217217

218-
def _find_roots(f, a, b, tol, max_value, max_eval, n_points=50):
218+
def _find_roots(f, a, b, tol, max_value, max_eval, n_points=50, max_depth=5):
219219
"""
220-
Find the root of function ``f`` between ``a`` and ``b`` closest to ``b``.
220+
Find roots of function ``f`` between ``a`` and ``b``.
221221
222222
Assumes ``f(a) < 0`` and ``f(b) > 0``. Finds root by building a grid between ``a``
223-
and ``b`` with ``n_points``, evaluating ``f`` at each point, and finding the last
224-
point where ``f`` is negative. If ``b`` is infinite, uses a logarithmic grid between
223+
and ``b`` with ``n_points``, evaluating ``f`` at each point, and finding indices
224+
where ``f`` switches sign. If ``b`` is infinite, uses a logarithmic grid between
225225
``a`` and ``a + sign(b - a) * max_value``. The function is then called recursively
226226
on the new interval until the size of the interval is less than ``tol`` or the
227227
maximum number of evaluations ``max_eval`` of ``f`` is reached.
@@ -230,40 +230,57 @@ def _find_roots(f, a, b, tol, max_value, max_eval, n_points=50):
230230
closest to ``b``. Note that this is also not strictly ensured by this function.
231231
"""
232232
if np.abs(b - a) < tol or max_eval < 0:
233-
return b # conservative
233+
return [b] # conservative, resulting in a larger interval
234+
234235
if np.isinf(a):
235-
return a
236+
return [a]
237+
238+
roots = []
239+
236240
sgn = np.sign(b - a)
237241
if np.isinf(b):
238242
grid = np.ones(n_points) * a
239-
grid[1:] += sgn * np.logspace(0, np.log10(max_value), n_points - 1)
243+
grid[1:] += sgn * np.logspace(tol, np.log10(max_value), n_points - 1)
240244
else:
241245
grid = np.linspace(a, b, n_points)
242246

243247
y = np.zeros(n_points)
244-
y[-1] = f(grid[-1])
245-
if y[-1] < 0:
246-
return sgn * np.inf
247248

248249
y[0] = f(grid[0])
249250
if y[0] >= 0:
250251
raise ValueError("f(a) must be negative.")
251252

252-
for i, x in enumerate(grid[:-1]):
253-
y[i] = f(x)
253+
for i, x in enumerate(grid[1:]):
254+
y[i + 1] = f(x)
254255

255-
last_positive = np.where(y < 0)[0][-1]
256+
if y[-1] <= 0:
257+
roots = [b]
256258

257-
# f(a_new) < 0 < f(b_new) -> repeat
258-
return _find_roots(
259-
f,
260-
grid[last_positive],
261-
grid[last_positive + 1],
262-
tol=tol,
263-
n_points=n_points,
264-
max_value=None,
265-
max_eval=max_eval - n_points,
266-
)
259+
y[y == 0] = np.finfo(y.dtype).eps
260+
where = np.where(np.sign(y[:-1]) != np.sign(y[1:]))[0]
261+
262+
# Conservative. Focus on change closest to b.
263+
if max_depth == 0:
264+
where = where[-1:]
265+
266+
for idx, w in enumerate(where):
267+
if idx % 2 == 0:
268+
a, b = grid[w], grid[w + 1]
269+
else:
270+
a, b = grid[w + 1], grid[w]
271+
272+
roots += _find_roots(
273+
f,
274+
a,
275+
b,
276+
tol=tol,
277+
n_points=n_points,
278+
max_value=max_value,
279+
max_eval=max_eval - n_points,
280+
max_depth=max_depth - len(where) > 1,
281+
)
282+
283+
return roots
267284

268285

269286
def _characteristic_roots(a, b, subset_by_index=None):

tests/test_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import scipy
55

6-
from ivmodels.utils import _characteristic_roots, oproj, proj, to_numpy
6+
from ivmodels.utils import _characteristic_roots, _find_roots, oproj, proj, to_numpy
77

88

99
def test_proj():
@@ -132,3 +132,20 @@ def test_characteristic_roots(dim, rank):
132132
assert np.allclose(
133133
_characteristic_roots(A, B, subset_by_index=[0, 0]), np.min(finite_roots)
134134
)
135+
136+
137+
@pytest.mark.parametrize(
138+
"f, a, b, expected",
139+
[
140+
(np.sin, -1, 8, [0, np.pi, 2 * np.pi]),
141+
(lambda x: -np.sin(x), 8, -1, [0, np.pi, 2 * np.pi]),
142+
(lambda x: x**2 - 1, 0, 2, [1]),
143+
(lambda x: x**2 - 1, 0, -np.inf, [-1]),
144+
(lambda x: x**3 - x, -2, 2, [-1, 0, 1]),
145+
(lambda x: x**3 - x, 0.5, -np.inf, [-np.inf, -1, 0]),
146+
],
147+
)
148+
@pytest.mark.parametrize("tol", [1e-3, 1e-6])
149+
def test_find_roots(f, a, b, expected, tol):
150+
roots = _find_roots(f, a, b, max_value=1e6, max_eval=1e4, tol=tol)
151+
assert np.allclose(sorted(roots), expected, atol=tol)

0 commit comments

Comments
 (0)