|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 | 5 | __all__ = [
|
6 |
| - "sigmoid", |
| 6 | + "binary_search", |
7 | 7 | "choose",
|
| 8 | + "clip", |
8 | 9 | "get_parameters",
|
9 |
| - "binary_search", |
| 10 | + "sigmoid", |
10 | 11 | ]
|
11 | 12 |
|
12 | 13 |
|
13 | 14 | import inspect
|
14 | 15 | from functools import lru_cache
|
| 16 | +from types import MappingProxyType |
| 17 | +from typing import Callable |
15 | 18 |
|
16 | 19 | import numpy as np
|
17 | 20 | from scipy import special
|
18 | 21 |
|
19 | 22 |
|
20 |
| -def sigmoid(x): |
21 |
| - return 1.0 / (1 + np.exp(-x)) |
| 23 | +def binary_search( |
| 24 | + function: Callable[[int | float], int | float], |
| 25 | + target: int | float, |
| 26 | + lower_bound: int | float, |
| 27 | + upper_bound: int | float, |
| 28 | + tolerance: int | float = 1e-4, |
| 29 | +) -> int | float | None: |
| 30 | + """Searches for a value in a range by repeatedly dividing the range in half. |
22 | 31 |
|
| 32 | + To be more precise, performs numerical binary search to determine the |
| 33 | + input to ``function``, between the bounds given, that outputs ``target`` |
| 34 | + to within ``tolerance`` (default of 0.0001). |
| 35 | + Returns ``None`` if no input can be found within the bounds. |
23 | 36 |
|
24 |
| -@lru_cache(maxsize=10) |
25 |
| -def choose(n, k): |
26 |
| - return special.comb(n, k, exact=True) |
27 |
| - |
| 37 | + Examples |
| 38 | + -------- |
28 | 39 |
|
29 |
| -def get_parameters(function): |
30 |
| - return inspect.signature(function).parameters |
| 40 | + Consider the polynomial :math:`x^2 + 3x + 1` where we search for |
| 41 | + a target value of :math:`11`. An exact solution is :math:`x = 2`. |
31 | 42 |
|
| 43 | + :: |
32 | 44 |
|
33 |
| -def clip(a, min_a, max_a): |
34 |
| - if a < min_a: |
35 |
| - return min_a |
36 |
| - elif a > max_a: |
37 |
| - return max_a |
38 |
| - return a |
| 45 | + >>> solution = binary_search(lambda x: x**2 + 3*x + 1, 11, 0, 5) |
| 46 | + >>> abs(solution - 2) < 1e-4 |
| 47 | + True |
| 48 | + >>> solution = binary_search(lambda x: x**2 + 3*x + 1, 11, 0, 5, tolerance=0.01) |
| 49 | + >>> abs(solution - 2) < 0.01 |
| 50 | + True |
39 | 51 |
|
| 52 | + Searching in the interval :math:`[0, 5]` for a target value of :math:`71` |
| 53 | + does not yield a solution:: |
40 | 54 |
|
41 |
| -def binary_search(function, target, lower_bound, upper_bound, tolerance=1e-4): |
| 55 | + >>> binary_search(lambda x: x**2 + 3*x + 1, 71, 0, 5) is None |
| 56 | + True |
| 57 | + """ |
42 | 58 | lh = lower_bound
|
43 | 59 | rh = upper_bound
|
| 60 | + mh = np.mean(np.array([lh, rh])) |
44 | 61 | while abs(rh - lh) > tolerance:
|
45 |
| - mh = np.mean([lh, rh]) |
| 62 | + mh = np.mean(np.array([lh, rh])) |
46 | 63 | lx, mx, rx = (function(h) for h in (lh, mh, rh))
|
47 | 64 | if lx == target:
|
48 | 65 | return lh
|
49 | 66 | if rx == target:
|
50 | 67 | return rh
|
51 | 68 |
|
52 |
| - if lx <= target and rx >= target: |
| 69 | + if lx <= target <= rx: |
53 | 70 | if mx > target:
|
54 | 71 | rh = mh
|
55 | 72 | else:
|
56 | 73 | lh = mh
|
57 |
| - elif lx > target and rx < target: |
| 74 | + elif lx > target > rx: |
58 | 75 | lh, rh = rh, lh
|
59 | 76 | else:
|
60 | 77 | return None
|
| 78 | + |
61 | 79 | return mh
|
| 80 | + |
| 81 | + |
| 82 | +@lru_cache(maxsize=10) |
| 83 | +def choose(n: int, k: int) -> int: |
| 84 | + r"""The binomial coefficient n choose k. |
| 85 | +
|
| 86 | + :math:`\binom{n}{k}` describes the number of possible choices of |
| 87 | + :math:`k` elements from a set of :math:`n` elements. |
| 88 | +
|
| 89 | + References |
| 90 | + ---------- |
| 91 | + - https://en.wikipedia.org/wiki/Combination |
| 92 | + - https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.comb.html |
| 93 | + """ |
| 94 | + return special.comb(n, k, exact=True) |
| 95 | + |
| 96 | + |
| 97 | +def clip(a, min_a, max_a): |
| 98 | + """Clips ``a`` to the interval [``min_a``, ``max_a``]. |
| 99 | +
|
| 100 | + Accepts any comparable objects (i.e. those that support <, >). |
| 101 | + Returns ``a`` if it is between ``min_a`` and ``max_a``. |
| 102 | + Otherwise, whichever of ``min_a`` and ``max_a`` is closest. |
| 103 | +
|
| 104 | + Examples |
| 105 | + -------- |
| 106 | + :: |
| 107 | +
|
| 108 | + >>> clip(15, 11, 20) |
| 109 | + 15 |
| 110 | + >>> clip('a', 'h', 'k') |
| 111 | + 'h' |
| 112 | + """ |
| 113 | + if a < min_a: |
| 114 | + return min_a |
| 115 | + elif a > max_a: |
| 116 | + return max_a |
| 117 | + return a |
| 118 | + |
| 119 | + |
| 120 | +def get_parameters(function: Callable) -> MappingProxyType[str, inspect.Parameter]: |
| 121 | + """Return the parameters of ``function`` as an ordered mapping of parameters' |
| 122 | + names to their corresponding ``Parameter`` objects. |
| 123 | +
|
| 124 | + Examples |
| 125 | + -------- |
| 126 | + :: |
| 127 | +
|
| 128 | + >>> get_parameters(get_parameters) |
| 129 | + mappingproxy(OrderedDict([('function', <Parameter "function: 'Callable'">)])) |
| 130 | +
|
| 131 | + >>> tuple(get_parameters(choose)) |
| 132 | + ('n', 'k') |
| 133 | + """ |
| 134 | + return inspect.signature(function).parameters |
| 135 | + |
| 136 | + |
| 137 | +def sigmoid(x: float) -> float: |
| 138 | + r"""Returns the output of the logistic function. |
| 139 | +
|
| 140 | + The logistic function, a common example of a sigmoid function, is defined |
| 141 | + as :math:`\frac{1}{1 + e^{-x}}`. |
| 142 | +
|
| 143 | + References |
| 144 | + ---------- |
| 145 | + - https://en.wikipedia.org/wiki/Sigmoid_function |
| 146 | + - https://en.wikipedia.org/wiki/Logistic_function |
| 147 | + """ |
| 148 | + return 1.0 / (1 + np.exp(-x)) |
0 commit comments