Skip to content

Commit a8bdeb4

Browse files
alexfiklinducer
authored andcommitted
fix: allow sequences for center
1 parent 580ef67 commit a8bdeb4

File tree

3 files changed

+10
-25
lines changed

3 files changed

+10
-25
lines changed

.basedpyright/baseline.json

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13687,22 +13687,6 @@
1368713687
"lineCount": 1
1368813688
}
1368913689
},
13690-
{
13691-
"code": "reportUnknownMemberType",
13692-
"range": {
13693-
"startColumn": 24,
13694-
"endColumn": 50,
13695-
"lineCount": 1
13696-
}
13697-
},
13698-
{
13699-
"code": "reportUnknownMemberType",
13700-
"range": {
13701-
"startColumn": 25,
13702-
"endColumn": 52,
13703-
"lineCount": 1
13704-
}
13705-
},
1370613690
{
1370713691
"code": "reportUnknownArgumentType",
1370813692
"range": {

doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"Array1D": "class:numpy.ndarray",
3838
"Array2D": "class:numpy.ndarray",
3939
"ArrayND": "class:numpy.ndarray",
40+
"ToArray1D": "class:numpy.ndarray",
4041
"np.floating": "class:numpy.floating",
4142
"np.complexfloating": "class:numpy.complexfloating",
4243
"np.inexact": "class:numpy.inexact",

sumpy/point_calculus.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
if TYPE_CHECKING:
3636
from collections.abc import Callable, Sequence
3737

38-
from optype.numpy import Array1D, Array2D, ArrayND
38+
from optype.numpy import Array1D, Array2D, ArrayND, ToArray1D
3939

4040

4141
__doc__ = """
@@ -99,28 +99,28 @@ class CalculusPatch:
9999
_pshape: tuple[int, ...]
100100

101101
def __init__(self,
102-
center: Array1D[np.floating[Any]],
102+
center: ToArray1D[np.floating[Any]],
103103
h: float = 1e-1,
104104
order: int = 4,
105105
nodes: NodesKind = "chebyshev") -> None:
106-
self.center = center
107-
dtype = center.dtype
106+
center = np.asarray(center)
107+
assert center.ndim == 1
108108

109109
npoints = order + 1
110110
if nodes == "equispaced":
111-
points_1d = np.linspace(-h/2, h/2, npoints, dtype=dtype)
111+
points_1d = np.linspace(-h/2, h/2, npoints)
112112
weights_1d = None
113113

114114
elif nodes == "chebyshev":
115-
a = np.arange(npoints, dtype=dtype)
115+
a = np.arange(npoints)
116116
points_1d = (h/2)*np.cos((2*(a+1)-1)/(2*npoints)*np.pi)
117117
weights_1d = None
118118

119119
elif nodes == "legendre":
120120
from scipy.special import legendre
121121
points_1d, weights_1d, _ = legendre(npoints).weights.T
122-
points_1d = (points_1d * (h/2)).astype(dtype)
123-
weights_1d = (weights_1d * (h/2)).astype(dtype)
122+
points_1d = points_1d * (h/2)
123+
weights_1d = weights_1d * (h/2)
124124

125125
else:
126126
raise ValueError(f"invalid node set: {nodes}")
@@ -130,8 +130,8 @@ def __init__(self,
130130
self._points_1d = points_1d
131131
self._weights_1d = weights_1d
132132

133-
self.dim = dim = len(self.center)
134133
self.center = center
134+
self.dim = dim = len(self.center)
135135

136136
points_shaped = np.array(np.meshgrid(
137137
*[center[i] + points_1d for i in range(dim)],

0 commit comments

Comments
 (0)