Skip to content

Commit 91d6418

Browse files
TB: added option for periodic grids to sparse interpolation matrices (#19)
* Added option for periodic grids to sparse interpolation matrices * Using camel case
1 parent 672d621 commit 91d6418

File tree

2 files changed

+45
-4
lines changed

2 files changed

+45
-4
lines changed

qmat/lagrange.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def getDerivationMatrix(self, *args, **kwargs):
537537
return self.getDerivativeMatrix(*args, **kwargs)
538538

539539

540-
def getSparseInterpolationMatrix(inPoints, outPoints, order):
540+
def getSparseInterpolationMatrix(inPoints, outPoints, order, gridPeriod=-1):
541541
"""
542542
Get a sparse interpolation matrix from `inPoints` to `outPoints` of order
543543
`order` using barycentric Lagrange interpolation.
@@ -553,6 +553,8 @@ def getSparseInterpolationMatrix(inPoints, outPoints, order):
553553
The points you want to interpolate to
554554
order : int
555555
Order of the interpolation
556+
grid_period : float
557+
Period of the grid. Negative values indicate non-periodic grids
556558
557559
Returns
558560
-------
@@ -568,10 +570,21 @@ def getSparseInterpolationMatrix(inPoints, outPoints, order):
568570
lastClosestPoints = None
569571

570572
for i in range(len(outPoints)):
571-
closestPointsIdx = np.sort(np.argsort(np.abs(inPoints - outPoints[i]))[:order])
572-
closestPoints = inPoints[closestPointsIdx] - outPoints[i]
573+
if gridPeriod > 0:
574+
pathL = (inPoints - gridPeriod - outPoints[i] % gridPeriod)
575+
pathR = (inPoints + gridPeriod - outPoints[i] % gridPeriod)
576+
pathC = (inPoints - outPoints[i] % gridPeriod)
577+
path = np.append(np.append(pathR, pathL), pathC)
578+
dist = np.abs(path)
579+
_closestPointsIdx = np.sort(np.argsort(dist)[:order])
580+
closestPointsIdx = _closestPointsIdx % len(inPoints)
581+
closestPoints, sorting = np.unique(path[_closestPointsIdx], return_index=True)
582+
closestPointsIdx = closestPointsIdx[sorting]
583+
else:
584+
closestPointsIdx = np.sort(np.argsort(np.abs(inPoints - outPoints[i]))[:order])
585+
closestPoints = inPoints[closestPointsIdx] - outPoints[i]
573586

574-
if lastClosestPoints is not None and np.allclose(closestPoints, lastClosestPoints):
587+
if lastClosestPoints is not None and len(closestPoints) == len(lastClosestPoints) and np.allclose(closestPoints, lastClosestPoints):
575588
interpolationLine = lastInterpolationLine
576589
else:
577590
interpolator = LagrangeApproximation(points = closestPoints)

tests/test_2_lagrange.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,31 @@ def testSparseInterpolation(inPoints, outPoints, order):
203203
inPolynomial = np.polyval(polyCoeffs,inPoints)
204204
interpolated = interpolationMatrix @ inPolynomial
205205
assert not np.allclose(np.polyval(polyCoeffs, outPoints), interpolated), f'Interpolation of order {order+1} polynomial is unexpectedly exact'
206+
207+
208+
@pytest.mark.parametrize('inPoints', [np.linspace(0, 1, 64, endpoint=False)])
209+
@pytest.mark.parametrize('outPoints', [np.linspace(0.123, 4.1214, 256, endpoint=False), np.linspace(0, 1, 128, endpoint=False)])
210+
@pytest.mark.parametrize('order', [2, 3, 4])
211+
def testSparseInterpolationPeriodic(inPoints, outPoints, order, gridPeriod=1):
212+
"""
213+
In this test, we do "extrapolation", which is really interpolation because of the periodicity of the grid.
214+
"""
215+
from qmat.lagrange import getSparseInterpolationMatrix
216+
import scipy.sparse as sp
217+
218+
data = np.sin(inPoints * 2 * np.pi / gridPeriod)
219+
220+
interpolationMatrix = getSparseInterpolationMatrix(inPoints, outPoints, order, gridPeriod=gridPeriod)
221+
assert isinstance(interpolationMatrix, sp.csc_matrix)
222+
223+
224+
interpolated = interpolationMatrix @ data
225+
ref = np.sin(outPoints * 2* np.pi / gridPeriod)
226+
error = np.linalg.norm(interpolated - ref)
227+
228+
max_error = {
229+
4: 2e-5,
230+
3: 4e-4,
231+
2: 8e-3,
232+
}
233+
assert error < max_error[order], error

0 commit comments

Comments
 (0)