Skip to content

Commit f6bb9f1

Browse files
committed
TL: postData cache + new scripts
1 parent dd6ff34 commit f6bb9f1

File tree

4 files changed

+269
-40
lines changed

4 files changed

+269
-40
lines changed

pySDC/playgrounds/dedalus/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ demos/*.png
44
problems/test_*
55
problems/run_*
66
problems/runInit_*
7+
problems/initRun_*
78

89
scripts/*.sh
910
scripts/run_*
1011
scripts/init_*
1112
scripts/runInit_*
13+
scripts/initRun_*

pySDC/playgrounds/dedalus/problems/rbc.py

Lines changed: 154 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
import os
77
import socket
8+
import json
89
import glob
910
from datetime import datetime
1011
from time import sleep
@@ -26,6 +27,13 @@
2627
MPI_RANK = COMM_WORLD.Get_rank()
2728

2829

30+
class NumpyEncoder(json.JSONEncoder):
31+
def default(self, obj):
32+
if isinstance(obj, np.ndarray):
33+
return obj.tolist()
34+
return json.JSONEncoder.default(self, obj)
35+
36+
2937
class RBCProblem2D():
3038

3139
BASE_RESOLUTION = 64
@@ -442,6 +450,24 @@ def __init__(self, folder):
442450
self.Ra = float(lines[0].split(" : ")[-1])
443451
self.Pr = float(lines[1].split(" : ")[-1])
444452

453+
# prepare postData file
454+
infosFile = f"{self.folder}/00_infoSimu.txt"
455+
if os.path.isfile(infosFile):
456+
with open(infosFile, "r") as f:
457+
lines = f.readlines()
458+
data = {key: val for key, val in [l.split(" : ") for l in lines]}
459+
self.setPostData("infos", "Ra", float(data["Ra"]))
460+
self.setPostData("infos", "Pr", float(data["Pr"]))
461+
self.setPostData("infos", "Lx", int(data["Lx"]))
462+
self.setPostData("infos", "Ly", int(data["Ly"]))
463+
self.setPostData("infos", "Lz", int(data["Lz"]))
464+
self.setPostData("infos", "Nx", int(data["Nx"]))
465+
self.setPostData("infos", "Ny", int(data["Ny"]))
466+
self.setPostData("infos", "Nz", int(data["Nz"]))
467+
self.setPostData("infos", "tEnd", float(data["tEnd"]))
468+
self.setPostData("infos", "dt", float(data["dt"]))
469+
self.setPostData("infos", "nSteps", int(data["nSteps"]))
470+
445471
def __del__(self):
446472
try:
447473
self.file.close()
@@ -560,16 +586,54 @@ def readFields(self, name, start=0, stop=None, step=1):
560586
if self.VERBOSE: print(" -- done !")
561587
return data
562588

589+
@property
590+
def postFile(self):
591+
return f"{self.folder}/postData.json"
563592

564-
def getTimeSeries(self, which=["ke"], batchSize=None):
593+
@property
594+
def postData(self):
595+
try:
596+
with open(self.postFile, "r") as f:
597+
data = json.load(f)
598+
except:
599+
data = {}
600+
return data
565601

602+
def setPostData(self, *args):
603+
data = self.postData
604+
dico = data
605+
assert len(args) > 1, "requires at least two arguments for setPostData"
606+
*keys, value = args
607+
for key in keys[:-1]:
608+
if key not in dico:
609+
dico[key] = {}
610+
dico = dico[key]
611+
if isinstance(value, np.ndarray):
612+
value = value.tolist()
613+
dico[keys[-1]] = value
614+
with open(self.postFile, "w") as f:
615+
json.dump(data, f, cls=NumpyEncoder)
616+
617+
618+
def getTimeSeries(self, which=["ke"], batchSize=None):
566619

567620
if which == "all":
568621
which = ["ke", "keH", "keV", "NuV", "NuT", "NuB"]
569622
else:
570623
which = list(which)
571624

572-
series = {name: [] for name in which}
625+
data = self.postData.get("series", {})
626+
series = {
627+
name: data.get(name, []) for name in ["times"] + which
628+
}
629+
series["times"] = self.times
630+
631+
which = [name for name in which if name not in data]
632+
if len(which) == 0:
633+
for key, val in series.items():
634+
series[key] = np.array(val)
635+
return series
636+
573637
avgAxes = 1 if self.dim==2 else (1, 2)
574638

575639
approx = LagrangeApproximation(self.z)
@@ -622,6 +686,9 @@ def getTimeSeries(self, which=["ke"], batchSize=None):
622686
for key, val in series.items():
623687
series[key] = np.array(val).ravel()
624688

689+
# Save in postData
690+
self.setPostData("series", series)
691+
625692
return series
626693

627694

@@ -643,11 +710,27 @@ def getProfiles(self, which=["uRMS", "bRMS"],
643710
+ [var+"RMS" for var in formula.keys()]
644711
else:
645712
which = list(which)
713+
646714
if "bRMS" in which and "bMean" not in which:
647715
which.append("bMean")
648716
if "pRMS" in which and "pMean" not in which:
649717
which.append("pMean")
650-
profiles = {name: np.zeros(self.nZ) for name in which}
718+
719+
data = self.postData.get("profiles", {})
720+
if data and tuple(data.get("slice", [])) != (start, stop, step):
721+
print(f"WARNING : ovewriting profiles data with slice={(start, stop, step)}")
722+
data = {}
723+
profiles = {
724+
name: np.array(data.get(name, np.zeros(self.nZ)))
725+
for name in ["slice", "nSamples", "zVals"] + which
726+
}
727+
profiles["slice"] = (start, stop, step)
728+
profiles["nSamples"] = len(range(start, stop, step))
729+
profiles["zVals"] = self.z
730+
731+
which = [name for name in which if name not in data]
732+
if len(which) == 0:
733+
return profiles
651734

652735
nSamples = 0
653736
def addSamples(current, new, nNew):
@@ -714,7 +797,10 @@ def addSamples(current, new, nNew):
714797
if "RMS" in name:
715798
val **= 0.5
716799

717-
profiles["nSamples"] = nSamples
800+
# Save in postData
801+
assert profiles["nSamples"] == nSamples, "very weird error ..."
802+
self.setPostData("profiles", profiles)
803+
718804
return profiles
719805

720806

@@ -741,17 +827,33 @@ def getBoundaryLayers(self, which=["uRMS", "bRMS"], profiles=None,
741827

742828
return deltas
743829

744-
def getSpectrum(self, which=["uV", "uH"], zVal="all",
830+
831+
def getSpectrum(self, which=["uv", "uh"], zVal="all",
745832
start=0, stop=None, step=1, batchSize=None):
746833
if stop is None:
747834
stop = self.nFields
748835
if which == "all":
749-
which = ["u", "uV", "uH", "b", "p"]
836+
which = ["u", "uv", "uh", "b", "p"]
750837
else:
751838
which = list(which)
752839

753840
kappa = self.kappa
754-
spectrum = {name: np.zeros(kappa.size) for name in which}
841+
842+
data = self.postData.get("spectrum", {})
843+
if data and tuple(data.get("slice", [])) != (start, stop, step):
844+
print(f"WARNING : ovewriting profiles data with slice={(start, stop, step)}")
845+
data = {}
846+
spectrum = {
847+
name: np.array(data.get(name, np.zeros(kappa.size)))
848+
for name in ["slice", "nSamples", "kappa"] + which
849+
}
850+
spectrum["slice"] = (start, stop, step)
851+
spectrum["nSamples"] = len(range(start, stop, step))
852+
spectrum["kappa"] = kappa
853+
854+
which = [name for name in which if name not in data]
855+
if len(which) == 0:
856+
return spectrum
755857

756858
approx = LagrangeApproximation(self.z, weightComputation="STABLE")
757859
if zVal == "all":
@@ -764,16 +866,16 @@ def getSpectrum(self, which=["uV", "uH"], zVal="all",
764866

765867
bSize = len(r)
766868

767-
if set(which).intersection(["u", "uV", "uH"]):
869+
if set(which).intersection(["u", "uv", "uh"]):
768870
u = self.readFields("velocity", r.start, r.stop, r.step)
769871

770872
for name in which:
771873

772874
# define fields with shape (nT,nComp,nX[,nY],nZ)
773-
if name in ["u", "uV", "uH"]:
774-
if name == "uV":
875+
if name in ["u", "uv", "uh"]:
876+
if name == "uv":
775877
field = u[:, -1:]
776-
if name == "uH":
878+
if name == "uh":
777879
field = u[:, :-1]
778880
if name == "u":
779881
field = u
@@ -786,7 +888,7 @@ def getSpectrum(self, which=["uV", "uH"], zVal="all",
786888
elif name == "p":
787889
field = self.readFields("pressure", r.start, r.stop, r.step)[:, None, ...]
788890
else:
789-
raise NotImplementedError(f"{name} in which ...")
891+
raise NotImplementedError(f"spectrum computation for {name}")
790892

791893
if zVal != "all":
792894
field = (mPz @ field[..., None])[..., 0]
@@ -872,6 +974,10 @@ def getSpectrum(self, which=["uV", "uH"], zVal="all",
872974

873975
nSamples += bSize
874976

977+
# Save in postData
978+
assert spectrum["nSamples"] == nSamples, "very weird error ..."
979+
self.setPostData("spectrum", spectrum)
980+
875981
return spectrum
876982

877983

@@ -1100,15 +1206,15 @@ def checkDNS(spectrum:np.ndarray, kappa:np.ndarray, sRatio:int=4, nThrow:int=0):
11001206
x = np.log(kTail)
11011207

11021208
def fun(coeffs):
1103-
a, b, c = coeffs
1104-
return np.linalg.norm(y - a*x**2 - b*x - c)
1209+
c2, c1, c0 = coeffs
1210+
return np.linalg.norm(y - c2*x**2 - c1*x - c0)
11051211

11061212
res = sco.minimize(fun, [0, 0, 0])
1107-
a, b, c = res.x
1213+
c2, c1, c0 = [float(c) for c in res.x]
11081214

11091215
results = {
1110-
"DNS": not a > 0,
1111-
"coeffs": (a, b, c),
1216+
"DNS": not c2 > 0,
1217+
"coeffs": (c2, c1, c0),
11121218
"kTail": kTail,
11131219
"sTail": sTail,
11141220
}
@@ -1119,24 +1225,23 @@ def fun(coeffs):
11191225
import matplotlib.pyplot as plt
11201226

11211227
# dirName = "run_3D_A4_M0.5_R1_Ra1e6"
1122-
dirName = "run_3D_A4_M0.5_R1_Ra1e5"
1228+
dirName = "run_3D_A4_M0.5_R1_Ra5e3"
11231229
# dirName = "run_M4_R2"
11241230
# dirName = "test_M4_R2"
11251231
OutputFiles.VERBOSE = True
11261232
output = OutputFiles(dirName)
11271233

1128-
if True:
1234+
if False:
11291235
series = output.getTimeSeries(which=["ke", "keH", "keV", "NuV"])
11301236

11311237
plt.figure("series")
1132-
for name, values in series.items():
1133-
plt.plot(output.times, values, label=name)
1238+
plt.plot(output.times, series["NuV"], label=dirName)
11341239
plt.legend()
11351240

1136-
start = 60
1241+
start = 20
11371242

1138-
if True:
1139-
which = ["bRMS"]
1243+
if False:
1244+
which = ["bRMS", "uRMS", "uMean"]
11401245

11411246
Nu = series["NuV"][start:].mean()
11421247

@@ -1161,7 +1266,7 @@ def fun(coeffs):
11611266
plt.xlabel("profile")
11621267
plt.ylabel("z coord")
11631268

1164-
zLog = np.logspace(np.log10(1/(100*Nu)), np.log(0.5), num=200)
1269+
zLog = np.logspace(np.log10(1/(100*Nu)), np.log10(0.5), num=200)
11651270
approx = LagrangeApproximation(output.z)
11661271
mPz = approx.getInterpolationMatrix(zLog)
11671272

@@ -1171,33 +1276,42 @@ def fun(coeffs):
11711276
bRMS = (profiles["bRMS"] + profiles["bRMS"][-1::-1])/2
11721277
bRMS = mPz @ bRMS
11731278

1174-
plt.figure("mean-log")
1279+
plt.figure("bmean-log")
11751280
plt.semilogx(zLog*Nu, bMean, label=dirName)
11761281
plt.legend()
11771282

1178-
plt.figure("rms-log")
1283+
plt.figure("RMS-log")
11791284
plt.semilogx(zLog*Nu, bRMS, label=dirName)
11801285
plt.legend()
11811286

1287+
uRMS = (profiles["uRMS"] + profiles["uRMS"][-1::-1])/2
1288+
uRMS = mPz @ uRMS
1289+
1290+
plt.figure("RMS-log")
1291+
plt.semilogx(zLog*Nu, uRMS, label=dirName)
1292+
plt.legend()
1293+
11821294
if True:
11831295
spectrum = output.getSpectrum(
1184-
which=["uH"],
1185-
zVal="all", start=start, batchSize=None)
1186-
kappa = output.kappa
1187-
1188-
check = checkDNS(spectrum["uH"], kappa)
1189-
a, b, c = check["coeffs"]
1190-
print(f"DNS : {check['DNS']} ({a=})")
1191-
kTail = check["kTail"]
1192-
sTail = check["sTail"]
1296+
which="all", zVal="all",
1297+
start=start, batchSize=None)
11931298

1299+
kappa = output.kappa
11941300
plt.figure("spectrum")
1195-
for name, vals in spectrum.items():
1301+
for name in ["u", "uv", "uh", "b", "p"]:
1302+
vals = spectrum[name]
1303+
check = checkDNS(vals, kappa)
1304+
a, b, c = check["coeffs"]
1305+
c2 = float(a)
1306+
print(f"DNS (on {name}): {check['DNS']} ({c2=})")
1307+
kTail = check["kTail"]
1308+
sTail = check["sTail"]
1309+
11961310
plt.loglog(kappa[1:], vals[1:], label=name)
11971311

1198-
# plt.loglog(kTail, sTail, '.', c="black")
1199-
# kTL = np.log(kTail)
1200-
# plt.loglog(kTail, np.exp(a*kTL**2 + b*kTL + c), c="gray")
1312+
plt.loglog(kTail, sTail, '.', c="black")
1313+
kTL = np.log(kTail)
1314+
plt.loglog(kTail, np.exp(a*kTL**2 + b*kTL + c), c="gray")
12011315

12021316
plt.loglog(kappa[1:], kappa[1:]**(-5/3), '--k')
12031317
plt.text(10, 0.1, r"$\kappa^{-5/3}$", fontsize=16)

0 commit comments

Comments
 (0)