Skip to content

Commit 4394fd9

Browse files
Weierstrass Method (#12877)
* Add weierstrass_method for approximating complex roots - Implements Durand-Kerner (Weierstrass) method for polynomial root finding - Accepts user-defined polynomial function and degree - Uses random perturbation of complex roots of unity for initial guesses - Handles validation, overflow clipping, and includes doctest * Update weierstrass_method.py * add more tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update weierstrass_method.py * Update weierstrass_method.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 54aa73f commit 4394fd9

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from collections.abc import Callable
2+
3+
import numpy as np
4+
5+
6+
def weierstrass_method(
7+
polynomial: Callable[[np.ndarray], np.ndarray],
8+
degree: int,
9+
roots: np.ndarray | None = None,
10+
max_iter: int = 100,
11+
) -> np.ndarray:
12+
"""
13+
Approximates all complex roots of a polynomial using the
14+
Weierstrass (Durand-Kerner) method.
15+
Args:
16+
polynomial: A function that takes a NumPy array of complex numbers and returns
17+
the polynomial values at those points.
18+
degree: Degree of the polynomial (number of roots to find). Must be ≥ 1.
19+
roots: Optional initial guess as a NumPy array of complex numbers.
20+
Must have length equal to 'degree'.
21+
If None, perturbed complex roots of unity are used.
22+
max_iter: Number of iterations to perform (default: 100).
23+
24+
Returns:
25+
np.ndarray: Array of approximated complex roots.
26+
27+
Raises:
28+
ValueError: If degree < 1, or if initial roots length doesn't match the degree.
29+
30+
Note:
31+
- Root updates are clipped to prevent numerical overflow.
32+
33+
Example:
34+
>>> import numpy as np
35+
>>> def check(poly, degree, expected):
36+
... roots = weierstrass_method(poly, degree)
37+
... return np.allclose(np.sort(roots), np.sort(expected))
38+
39+
>>> check(
40+
... lambda x: x**2 - 1,
41+
... 2,
42+
... np.array([-1, 1]))
43+
True
44+
45+
>>> check(
46+
... lambda x: x**3 - 4.5*x**2 + 5.75*x - 1.875,
47+
... 3,
48+
... np.array([1.5, 0.5, 2.5])
49+
... )
50+
True
51+
52+
See Also:
53+
https://en.wikipedia.org/wiki/Durand%E2%80%93Kerner_method
54+
"""
55+
56+
if degree < 1:
57+
raise ValueError("Degree of the polynomial must be at least 1.")
58+
59+
if roots is None:
60+
# Use perturbed complex roots of unity as initial guesses
61+
rng = np.random.default_rng()
62+
roots = np.array(
63+
[
64+
np.exp(2j * np.pi * i / degree) * (1 + 1e-3 * rng.random())
65+
for i in range(degree)
66+
],
67+
dtype=np.complex128,
68+
)
69+
70+
else:
71+
roots = np.asarray(roots, dtype=np.complex128)
72+
if roots.shape[0] != degree:
73+
raise ValueError(
74+
"Length of initial roots must match the degree of the polynomial."
75+
)
76+
77+
for _ in range(max_iter):
78+
# Construct the product denominator for each root
79+
denominator = np.array([root - roots for root in roots], dtype=np.complex128)
80+
np.fill_diagonal(denominator, 1.0) # Avoid zero in diagonal
81+
denominator = np.prod(denominator, axis=1)
82+
83+
# Evaluate polynomial at each root
84+
numerator = polynomial(roots).astype(np.complex128)
85+
86+
# Compute update and clip to prevent overflow
87+
delta = numerator / denominator
88+
delta = np.clip(delta, -1e10, 1e10)
89+
roots -= delta
90+
91+
return roots
92+
93+
94+
if __name__ == "__main__":
95+
import doctest
96+
97+
doctest.testmod()

0 commit comments

Comments
 (0)