Skip to content

Commit 7b9de90

Browse files
committed
Correct SeLaLib matrix class usage
1 parent ddc221d commit 7b9de90

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

pygyro/splines/spline_interpolators.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,26 @@ def __init__(self, basis, dtype=float):
2727
self._imat = self.collocation_matrix(
2828
basis.nbasis, basis.knots, basis.degree, basis.greville, basis.periodic, basis.cubic_uniform)
2929
if basis.periodic:
30-
dmat = dia_matrix(self._imat[:-basis.degree,:-basis.degree])
31-
l = abs(dmat.offsets.min())
32-
u = dmat.offsets.max()
33-
ku = np.int32(max(l,u))
30+
self._offset = self._basis.degree // 2
31+
max_ku = basis.degree
3432
n = np.int32(basis.nbasis)
35-
if 2*ku + 1 <= n:
33+
top_size = n-max_ku
34+
if (max_ku * (n + top_size) + (3 * max_ku + 1) * top_size >= n * n):
35+
self._splu = None
36+
else:
37+
dmat = dia_matrix(self._imat[max_ku:n-max_ku,max_ku:n-max_ku])
38+
l = abs(dmat.offsets.min()-self._offset)
39+
u = dmat.offsets.max()-self._offset
40+
ku = np.int32(max(l,u))
41+
3642
self._splu = SLL.PeriodicBandedMatrix(n, ku, ku)
3743
for i in range(basis.nbasis):
38-
for j in range(basis.nbasis):
39-
if self._imat[i,j] != 0:
40-
self._splu.set_element(np.int32(i+1), np.int32(j+1), self._imat[i,j])
44+
for jmin in range(basis.nbasis):
45+
j = (jmin - self._offset) % basis.nbasis
46+
if self._imat[i,jmin] != 0:
47+
self._splu.set_element(np.int32(i+1), np.int32(j+1), self._imat[i,jmin])
4148
self._splu.factorize()
42-
else:
43-
self._splu = None
49+
4450
else:
4551
dmat = dia_matrix(self._imat)
4652
self._l = abs(dmat.offsets.min())
@@ -97,9 +103,10 @@ def _solve_system_periodic(self, ug, c):
97103
p = self._basis.degree
98104

99105
if self._splu:
100-
c[0:n] = ug
101-
self._splu.solve_inplace(c[:n])
102-
c[n:n+p] = c[0:p]
106+
c[self._offset:n+self._offset] = ug
107+
self._splu.solve_inplace(c[self._offset:n+self._offset])
108+
c[:self._offset] = c[n:n+self._offset]
109+
c[n+self._offset:] = c[self._offset:p]
103110
else:
104111
c[0:n] = np.linalg.solve(self._imat, ug)
105112
c[n:n+p] = c[0:p]

0 commit comments

Comments
 (0)