Skip to content

Commit 8841e39

Browse files
tameTNTpre-commit-ci[bot]behackl
authored
Document and type simple_functions.py (#2674)
* 🏷️ Add types to simple_functions.py * πŸ’„ Neaten binary_search() Add spacing between signature and code. Remove and expressions and address IDE warnings * πŸ“ Add docstrings for functions in simple_functions.py * 🎨 Reorder functions alphabetically * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * πŸ› Reformat code examples These were causing checks to fail due to missing spaces after `>>>` I had wanted to change these to be more consistent with iterables.py anyway. * 🎨 Change single tics to double Change \` to \`` - this ensures that the variable names are actually displayed as code (and not italics) * improved docstrings, rewrote examples as doctests * fix (???) unrelated failing doctest * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed typo * Update manim/utils/simple_functions.py Co-authored-by: Luca <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Benjamin Hackl <[email protected]>
1 parent d5f08a5 commit 8841e39

File tree

2 files changed

+110
-21
lines changed

2 files changed

+110
-21
lines changed

β€Žmanim/mobject/opengl/opengl_vectorized_mobject.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1629,7 +1629,9 @@ class OpenGLVGroup(OpenGLVMobject):
16291629
can subtract elements of a OpenGLVGroup via :meth:`~.OpenGLVGroup.remove` method, or
16301630
`-` and `-=` operators:
16311631
1632-
>>> from manim import Triangle, Square, OpenGLVGroup
1632+
>>> from manim import Triangle, Square, config
1633+
>>> config.renderer = "opengl"
1634+
>>> from manim.opengl import OpenGLVGroup
16331635
>>> vg = OpenGLVGroup()
16341636
>>> triangle, square = Triangle(), Square()
16351637
>>> vg.add(triangle)

β€Žmanim/utils/simple_functions.py

Lines changed: 107 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,59 +3,146 @@
33
from __future__ import annotations
44

55
__all__ = [
6-
"sigmoid",
6+
"binary_search",
77
"choose",
8+
"clip",
89
"get_parameters",
9-
"binary_search",
10+
"sigmoid",
1011
]
1112

1213

1314
import inspect
1415
from functools import lru_cache
16+
from types import MappingProxyType
17+
from typing import Callable
1518

1619
import numpy as np
1720
from scipy import special
1821

1922

20-
def sigmoid(x):
21-
return 1.0 / (1 + np.exp(-x))
23+
def binary_search(
24+
function: Callable[[int | float], int | float],
25+
target: int | float,
26+
lower_bound: int | float,
27+
upper_bound: int | float,
28+
tolerance: int | float = 1e-4,
29+
) -> int | float | None:
30+
"""Searches for a value in a range by repeatedly dividing the range in half.
2231
32+
To be more precise, performs numerical binary search to determine the
33+
input to ``function``, between the bounds given, that outputs ``target``
34+
to within ``tolerance`` (default of 0.0001).
35+
Returns ``None`` if no input can be found within the bounds.
2336
24-
@lru_cache(maxsize=10)
25-
def choose(n, k):
26-
return special.comb(n, k, exact=True)
27-
37+
Examples
38+
--------
2839
29-
def get_parameters(function):
30-
return inspect.signature(function).parameters
40+
Consider the polynomial :math:`x^2 + 3x + 1` where we search for
41+
a target value of :math:`11`. An exact solution is :math:`x = 2`.
3142
43+
::
3244
33-
def clip(a, min_a, max_a):
34-
if a < min_a:
35-
return min_a
36-
elif a > max_a:
37-
return max_a
38-
return a
45+
>>> solution = binary_search(lambda x: x**2 + 3*x + 1, 11, 0, 5)
46+
>>> abs(solution - 2) < 1e-4
47+
True
48+
>>> solution = binary_search(lambda x: x**2 + 3*x + 1, 11, 0, 5, tolerance=0.01)
49+
>>> abs(solution - 2) < 0.01
50+
True
3951
52+
Searching in the interval :math:`[0, 5]` for a target value of :math:`71`
53+
does not yield a solution::
4054
41-
def binary_search(function, target, lower_bound, upper_bound, tolerance=1e-4):
55+
>>> binary_search(lambda x: x**2 + 3*x + 1, 71, 0, 5) is None
56+
True
57+
"""
4258
lh = lower_bound
4359
rh = upper_bound
60+
mh = np.mean(np.array([lh, rh]))
4461
while abs(rh - lh) > tolerance:
45-
mh = np.mean([lh, rh])
62+
mh = np.mean(np.array([lh, rh]))
4663
lx, mx, rx = (function(h) for h in (lh, mh, rh))
4764
if lx == target:
4865
return lh
4966
if rx == target:
5067
return rh
5168

52-
if lx <= target and rx >= target:
69+
if lx <= target <= rx:
5370
if mx > target:
5471
rh = mh
5572
else:
5673
lh = mh
57-
elif lx > target and rx < target:
74+
elif lx > target > rx:
5875
lh, rh = rh, lh
5976
else:
6077
return None
78+
6179
return mh
80+
81+
82+
@lru_cache(maxsize=10)
83+
def choose(n: int, k: int) -> int:
84+
r"""The binomial coefficient n choose k.
85+
86+
:math:`\binom{n}{k}` describes the number of possible choices of
87+
:math:`k` elements from a set of :math:`n` elements.
88+
89+
References
90+
----------
91+
- https://en.wikipedia.org/wiki/Combination
92+
- https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.comb.html
93+
"""
94+
return special.comb(n, k, exact=True)
95+
96+
97+
def clip(a, min_a, max_a):
98+
"""Clips ``a`` to the interval [``min_a``, ``max_a``].
99+
100+
Accepts any comparable objects (i.e. those that support <, >).
101+
Returns ``a`` if it is between ``min_a`` and ``max_a``.
102+
Otherwise, whichever of ``min_a`` and ``max_a`` is closest.
103+
104+
Examples
105+
--------
106+
::
107+
108+
>>> clip(15, 11, 20)
109+
15
110+
>>> clip('a', 'h', 'k')
111+
'h'
112+
"""
113+
if a < min_a:
114+
return min_a
115+
elif a > max_a:
116+
return max_a
117+
return a
118+
119+
120+
def get_parameters(function: Callable) -> MappingProxyType[str, inspect.Parameter]:
121+
"""Return the parameters of ``function`` as an ordered mapping of parameters'
122+
names to their corresponding ``Parameter`` objects.
123+
124+
Examples
125+
--------
126+
::
127+
128+
>>> get_parameters(get_parameters)
129+
mappingproxy(OrderedDict([('function', <Parameter "function: 'Callable'">)]))
130+
131+
>>> tuple(get_parameters(choose))
132+
('n', 'k')
133+
"""
134+
return inspect.signature(function).parameters
135+
136+
137+
def sigmoid(x: float) -> float:
138+
r"""Returns the output of the logistic function.
139+
140+
The logistic function, a common example of a sigmoid function, is defined
141+
as :math:`\frac{1}{1 + e^{-x}}`.
142+
143+
References
144+
----------
145+
- https://en.wikipedia.org/wiki/Sigmoid_function
146+
- https://en.wikipedia.org/wiki/Logistic_function
147+
"""
148+
return 1.0 / (1 + np.exp(-x))

0 commit comments

Comments
Β (0)