|
1 | 1 | __all__ = ["FirstDerivative"] |
2 | 2 |
|
3 | | -from typing import Union |
| 3 | +from typing import Callable, Union |
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | from numpy.core.multiarray import normalize_axis_index |
@@ -100,40 +100,41 @@ def __init__( |
100 | 100 | self.kind = kind |
101 | 101 | self.edge = edge |
102 | 102 | self.order = order |
| 103 | + self._register_multiplications(self.kind, self.order) |
103 | 104 |
|
104 | | - def _matvec(self, x: NDArray) -> NDArray: |
105 | | - if self.kind == "forward": |
106 | | - return self._matvec_forward(x) |
107 | | - elif self.kind == "backward": |
108 | | - return self._matvec_backward(x) |
109 | | - elif self.kind == "centered": |
110 | | - if self.order == 3: |
111 | | - return self._matvec_centered3(x) |
112 | | - elif self.order == 5: |
113 | | - return self._matvec_centered5(x) |
| 105 | + def _register_multiplications( |
| 106 | + self, |
| 107 | + kind: str, |
| 108 | + order: int, |
| 109 | + ) -> None: |
| 110 | + # choose _matvec and _rmatvec kind |
| 111 | + self._hmatvec: Callable |
| 112 | + self._hrmatvec: Callable |
| 113 | + if kind == "forward": |
| 114 | + self._hmatvec = self._matvec_forward |
| 115 | + self._hrmatvec = self._rmatvec_forward |
| 116 | + elif kind == "centered": |
| 117 | + if order == 3: |
| 118 | + self._hmatvec = self._matvec_centered3 |
| 119 | + self._hrmatvec = self._rmatvec_centered3 |
| 120 | + elif order == 5: |
| 121 | + self._hmatvec = self._matvec_centered5 |
| 122 | + self._hrmatvec = self._rmatvec_centered5 |
114 | 123 | else: |
115 | 124 | raise NotImplementedError("'order' must be '3, or '5'") |
| 125 | + elif kind == "backward": |
| 126 | + self._hmatvec = self._matvec_backward |
| 127 | + self._hrmatvec = self._rmatvec_backward |
116 | 128 | else: |
117 | 129 | raise NotImplementedError( |
118 | | - "'kind' must be 'forward', 'centered' or 'backward'" |
| 130 | + "'kind' must be 'forward', 'centered', or 'backward'" |
119 | 131 | ) |
120 | 132 |
|
| 133 | + def _matvec(self, x: NDArray) -> NDArray: |
| 134 | + return self._hmatvec(x) |
| 135 | + |
121 | 136 | def _rmatvec(self, x: NDArray) -> NDArray: |
122 | | - if self.kind == "forward": |
123 | | - return self._rmatvec_forward(x) |
124 | | - elif self.kind == "backward": |
125 | | - return self._rmatvec_backward(x) |
126 | | - elif self.kind == "centered": |
127 | | - if self.order == 3: |
128 | | - return self._rmatvec_centered3(x) |
129 | | - elif self.order == 5: |
130 | | - return self._rmatvec_centered5(x) |
131 | | - else: |
132 | | - raise NotImplementedError("'order' must be '3, or '5'") |
133 | | - else: |
134 | | - raise NotImplementedError( |
135 | | - "'kind' must be 'forward', 'centered' or 'backward'" |
136 | | - ) |
| 137 | + return self._hrmatvec(x) |
137 | 138 |
|
138 | 139 | @reshaped(swapaxis=True) |
139 | 140 | def _matvec_forward(self, x: NDArray) -> NDArray: |
|
0 commit comments