|
| 1 | +import contextlib |
| 2 | +from typing import Union |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import sympy as sym |
| 6 | +from numba import njit |
| 7 | + |
| 8 | +from ..utils import VectorObj |
| 9 | +from ..utils.solvers import RootFinder |
| 10 | + |
| 11 | +EPS = np.finfo("float64").resolution |
| 12 | +OMEGA = sym.Symbol(r"\omega", complex=True) |
| 13 | + |
| 14 | + |
| 15 | +def real_deocrator(fn): |
| 16 | + """Using numpy real cast is way faster than sympy.""" |
| 17 | + |
| 18 | + def wrap_fn(*args): |
| 19 | + return np.real(fn(*args)) |
| 20 | + |
| 21 | + return wrap_fn |
| 22 | + |
| 23 | + |
| 24 | +def _lu_decomposition_det(matrix: sym.Matrix): |
| 25 | + _, U, perm = matrix.LUdecomposition() |
| 26 | + # sgn = Permutation(perm).signature() |
| 27 | + sgn = 1 |
| 28 | + return sgn * U.det() |
| 29 | + |
| 30 | + |
| 31 | +def _berkowitz_det(matrix: sym.Matrix): |
| 32 | + return matrix.det(method="berkowitz") |
| 33 | + |
| 34 | + |
| 35 | +def _root_solver_numerical( |
| 36 | + root_expr, |
| 37 | + n_layers: int = None, |
| 38 | + *args, |
| 39 | + normalise_roots_by_2pi: bool = False, |
| 40 | + ftol: float = 0.01e9, |
| 41 | + max_freq: float = 80e9, |
| 42 | +): |
| 43 | + xtol = 1e-4 |
| 44 | + if n_layers is None or n_layers <= 3: |
| 45 | + # makes it faster for small systems, otherise jit cost too high |
| 46 | + y = real_deocrator(njit(sym.lambdify(OMEGA, root_expr, "math"))) |
| 47 | + else: |
| 48 | + y = real_deocrator(sym.lambdify(OMEGA, root_expr, "math")) |
| 49 | + mfreq = max_freq * 2 * np.pi if normalise_roots_by_2pi else max_freq |
| 50 | + r = RootFinder(xtol, mfreq, step=ftol, xtol=xtol, root_dtype="float16") |
| 51 | + roots = r.find(y) |
| 52 | + # convert to GHz |
| 53 | + # reduce unique solutions to 2 decimal places |
| 54 | + roots = np.unique(np.around(roots / 1e9, 2)) |
| 55 | + return roots / (2 * np.pi) if normalise_roots_by_2pi else roots |
| 56 | + |
| 57 | + |
| 58 | +# we add n_layers and args to make the signature compatible with numerical solver |
| 59 | +def _root_solver_analytical( |
| 60 | + root_expr, |
| 61 | + n_layers: int = None, |
| 62 | + normalise_roots_by_2pi: bool = False, |
| 63 | + solve_direct: bool = False, |
| 64 | + **kwargs, |
| 65 | +): |
| 66 | + # Try multiple approaches to find roots |
| 67 | + all_solutions = [] |
| 68 | + |
| 69 | + # # Approach 1: Direct solving |
| 70 | + if solve_direct: |
| 71 | + with contextlib.suppress(Exception): |
| 72 | + direct_solutions = sym.solve(root_expr, OMEGA) |
| 73 | + all_solutions.extend(direct_solutions) |
| 74 | + |
| 75 | + # Approach 2: Factorized solving |
| 76 | + with contextlib.suppress(Exception): |
| 77 | + factorised = sym.factor(root_expr) |
| 78 | + factored_solutions = sym.solve(factorised, OMEGA) |
| 79 | + all_solutions.extend(factored_solutions) |
| 80 | + |
| 81 | + # Remove duplicates |
| 82 | + all_solutions = list(set(all_solutions)) |
| 83 | + |
| 84 | + # More robust real check - evaluate numerically and check if imaginary part is negligible |
| 85 | + real_solutions = [] |
| 86 | + for sol in all_solutions: |
| 87 | + try: |
| 88 | + numeric_val = complex(sol.evalf()) |
| 89 | + if abs(numeric_val.imag) < 1e-10: # More lenient than is_real |
| 90 | + real_solutions.append(numeric_val.real) |
| 91 | + except Exception: |
| 92 | + continue |
| 93 | + |
| 94 | + # Convert to numpy and filter |
| 95 | + if not real_solutions: |
| 96 | + return np.array([]) |
| 97 | + |
| 98 | + all_roots = np.asarray(real_solutions) / 1e9 |
| 99 | + |
| 100 | + # Filter: positive and within frequency range |
| 101 | + all_roots = all_roots[(all_roots > 0)] |
| 102 | + |
| 103 | + # Remove duplicates with tolerance |
| 104 | + if len(all_roots) > 1: |
| 105 | + all_roots = np.unique(np.around(all_roots, 6)) # Higher precision than numerical |
| 106 | + |
| 107 | + return all_roots / (2 * np.pi) if normalise_roots_by_2pi else all_roots |
| 108 | + |
| 109 | + |
| 110 | +def _default_matrix_conversion( |
| 111 | + vector_or_matrix: Union[VectorObj, list[float], sym.Matrix], |
| 112 | +) -> sym.ImmutableMatrix: |
| 113 | + if isinstance(vector_or_matrix, VectorObj): |
| 114 | + vector_or_matrix = vector_or_matrix.get_cartesian() |
| 115 | + elif isinstance(vector_or_matrix, list): |
| 116 | + vector_or_matrix = sym.ImmutableMatrix(vector_or_matrix) |
| 117 | + return sym.ImmutableMatrix(vector_or_matrix) |
0 commit comments