Skip to content

Commit 070eed6

Browse files
committed
TL: added interpolation routines
1 parent 6106b4f commit 070eed6

File tree

1 file changed

+136
-3
lines changed
  • pySDC/playgrounds/dedalus/problems

1 file changed

+136
-3
lines changed

pySDC/playgrounds/dedalus/problems/rbc.py

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import scipy.optimize as sco
1717

1818
from qmat.lagrange import LagrangeApproximation
19+
from qmat.nodes import NodesGenerator
1920
from pySDC.helpers.fieldsIO import Rectilinear
2021
from pySDC.helpers.blocks import BlockDecomposition
2122
from pySDC.playgrounds.dedalus.timestepper import SDCIMEX, SDCIMEX_MPI, SDCIMEX_MPI2
@@ -889,6 +890,138 @@ def toVTR(self, idxFormat="{:06d}"):
889890
u = self.readFieldAt(i)
890891
writeToVTR(template.format(i), u, coords, varNames)
891892

893+
894+
def toPySDC(self, fileName:str, iField:int=-1, interpolate=False):
895+
"""
896+
Convert a given field into a pySDC format,
897+
eventually interpolating to twice the space resolution.
898+
899+
900+
Parameters
901+
----------
902+
fileName : str
903+
Name of the output file.
904+
iField : int, optional
905+
Index of the field to write in file. The default is -1.
906+
interpolate : TYPE, optional
907+
Wether or not interpolate the field. The default is False.
908+
"""
909+
fields = self.file["tasks"]
910+
911+
velocity = fields["velocity"][iField]
912+
buoyancy = fields["buoyancy"][iField]
913+
pressure = fields["pressure"][iField]
914+
915+
uAll = np.concat([velocity, buoyancy[None, ...], pressure[None, ...]])
916+
917+
if self.dim == 3:
918+
if interpolate:
919+
uAll = rbc3dInterpolation(uAll)
920+
nX, nY, nZ = uAll.shape[1:]
921+
xCoord = np.linspace(0, 1, nX, endpoint=False)
922+
yCoord = np.linspace(0, 1, nY, endpoint=False)
923+
zCoord = NodesGenerator("CHEBY-1", "GAUSS").getNodes(nZ)
924+
header = (5, [xCoord, yCoord, zCoord])
925+
926+
elif self.dim == 2:
927+
if interpolate:
928+
uAll = rbc2dInterpolation(uAll)
929+
nX, nZ = uAll.shape[1:]
930+
xCoord = np.linspace(0, 1, nX, endpoint=False)
931+
zCoord = NodesGenerator("CHEBY-1", "GAUSS").getNodes(nZ)
932+
header = (4, [xCoord, zCoord])
933+
else:
934+
raise NotImplementedError(f"dim={self.dim}")
935+
936+
output = Rectilinear(np.float64, fileName)
937+
output.setHeader(*header)
938+
output.initialize()
939+
output.addField(0, uAll)
940+
941+
942+
def rbc3dInterpolation(coarseFields):
943+
"""
944+
Interpolate a RBC 3D field to twice its space resolution
945+
946+
Parameters
947+
----------
948+
coarseFields : np.4darray
949+
The fields values on the coarse grid, with shape [nV,nX,nY,nZ].
950+
The last dimension (z) uses a chebychev grid, while x and y are
951+
uniform periodic.
952+
953+
Returns
954+
-------
955+
fineFields : np.4darray
956+
The interpolated fields, with shape [nV,2*nX,2*nY,2*nZ]
957+
"""
958+
coarseFields = np.asarray(coarseFields)
959+
assert coarseFields.ndim == 4, "requires 4D array"
960+
961+
nV, nX, nY, nZ = coarseFields.shape
962+
963+
# Chebychev grids and interpolation matrix for z
964+
zC = NodesGenerator("CHEBY-1", "GAUSS").getNodes(nZ)
965+
zF = NodesGenerator("CHEBY-1", "GAUSS").getNodes(2*nZ)
966+
Pz = LagrangeApproximation(zC, weightComputation="STABLE").getInterpolationMatrix(zF)
967+
968+
# Fourier interpolation in x and y
969+
print(" -- computing 2D FFT ...")
970+
uFFT = np.fft.fftshift(np.fft.fft2(coarseFields, axes=(1, 2)), axes=(1, 2))
971+
print(" -- padding in Fourier space ...")
972+
uPadded = np.zeros_like(uFFT, shape=(nV, 2*nX, 2*nY, nZ))
973+
uPadded[:, nX//2:-nX//2, nY//2:-nY//2] = uFFT
974+
print(" -- computing 2D IFFT ...")
975+
uXY = np.fft.ifft2(np.fft.ifftshift(uPadded, axes=(1, 2)), axes=(1, 2)).real*4
976+
977+
# Polynomial interpolation in z
978+
print(" -- interpolating in z direction ...")
979+
fineFields = (Pz @ uXY.reshape(-1, nZ).T).T.reshape(nV, 2*nX, 2*nY, 2*nZ)
980+
981+
return fineFields
982+
983+
984+
def rbc2dInterpolation(coarseFields):
985+
"""
986+
Interpolate a RBC 2D field to twice its space resolution
987+
988+
Parameters
989+
----------
990+
coarseFields : np.3darray
991+
The fields values on the coarse grid, with shape [nV,nX,nZ].
992+
The last dimension (z) uses a chebychev grid,
993+
while x is uniform periodic.
994+
995+
Returns
996+
-------
997+
fineFields : np.4darray
998+
The interpolated fields, with shape [nV,2*nX,2*nZ]
999+
"""
1000+
coarseFields = np.asarray(coarseFields)
1001+
assert coarseFields.ndim == 3, "requires 3D array"
1002+
1003+
nV, nX, nZ = coarseFields.shape
1004+
1005+
# Chebychev grids and interpolation matrix for z
1006+
zC = NodesGenerator("CHEBY-1", "GAUSS").getNodes(nZ)
1007+
zF = NodesGenerator("CHEBY-1", "GAUSS").getNodes(2*nZ)
1008+
Pz = LagrangeApproximation(zC, weightComputation="STABLE").getInterpolationMatrix(zF)
1009+
1010+
# Fourier interpolation in x
1011+
print(" -- computing 1D FFT ...")
1012+
uFFT = np.fft.fftshift(np.fft.fft(coarseFields, axis=1), axes=1)
1013+
print(" -- padding in Fourier space ...")
1014+
uPadded = np.zeros_like(uFFT, shape=(nV, 2*nX, nZ))
1015+
uPadded[:, nX//2:-nX//2] = uFFT
1016+
print(" -- computing 1D IFFT ...")
1017+
uXY = np.fft.ifft(np.fft.ifftshift(uPadded, axes=1), axis=1).real*2
1018+
1019+
# Polynomial interpolation in z
1020+
print(" -- interpolating in z direction ...")
1021+
fineFields = (Pz @ uXY.reshape(-1, nZ).T).T.reshape(nV, 2*nX, 2*nZ)
1022+
1023+
return fineFields
1024+
8921025
def checkDNS(spectrum:np.ndarray, kappa:np.ndarray, sRatio:int=4, nThrow:int=0):
8931026
r"""
8941027
Check for a well-resolved DNS, by looking at an energy spectrum
@@ -967,21 +1100,21 @@ def fun(coeffs):
9671100
import matplotlib.pyplot as plt
9681101

9691102
# dirName = "run_3D_A4_M0.5_R1_Ra1e6"
970-
dirName = "run_3D_A4_M1_R1_Ra1e6"
1103+
dirName = "run_3D_A4_M1_R1_Ra2e5"
9711104
# dirName = "run_M4_R2"
9721105
# dirName = "test_M4_R2"
9731106
OutputFiles.VERBOSE = True
9741107
output = OutputFiles(dirName)
9751108

976-
if False:
1109+
if True:
9771110
series = output.getTimeSeries(which=["NuV", "NuT", "NuB"])
9781111

9791112
plt.figure("series")
9801113
for name, values in series.items():
9811114
plt.plot(output.times, values, label=name)
9821115
plt.legend()
9831116

984-
start = 60
1117+
start = 20
9851118

9861119
if False:
9871120
which = ["bRMS"]

0 commit comments

Comments
 (0)