Skip to content
This repository was archived by the owner on Feb 1, 2024. It is now read-only.

Commit d375472

Browse files
committed
update hofx plots for ioda v2
1 parent 1fa1c44 commit d375472

File tree

3 files changed

+169
-137
lines changed

3 files changed

+169
-137
lines changed

src/fv3jeditools/diag_da_convergence.py

Lines changed: 94 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -43,84 +43,6 @@ def da_convergence(datetime, conf):
4343
except:
4444
output_path = './'
4545

46-
# Create output path
47-
if not os.path.exists(output_path):
48-
os.makedirs(output_path)
49-
50-
# Replace datetime in logfile name
51-
isodatestr = datetime.strftime("%Y-%m-%dT%H:%M:%S")
52-
log_file = utils.stringReplaceDatetimeTemplate(isodatestr, log_file)
53-
54-
55-
# Read file and gather norm information
56-
print(" Reading convergence from ", log_file)
57-
58-
59-
# Open the file ready for reading
60-
if os.path.exists(log_file):
61-
file = open(log_file, "r")
62-
else:
63-
utils.abort('Log file not found.')
64-
65-
66-
# Search for the type of minimizer used for the assimilation
67-
for line in file:
68-
if "Minimizer algorithm=" in line:
69-
minimizer = line.split('=')[1].rstrip()
70-
break
71-
72-
73-
# Patterns to search for from the file
74-
search_patterns = []
75-
search_patterns.append(" Norm reduction .")
76-
search_patterns.append(" Quadratic cost function: J .")
77-
search_patterns.append(" Quadratic cost function: Jb .")
78-
search_patterns.append(" Quadratic cost function: JoJc.")
79-
search_patterns.append("GMRESR end of iteration .")
80-
81-
# Labels for the figures
82-
ylabels = []
83-
ylabels.append(minimizer+" normalized gradient reduction")
84-
ylabels.append("Quadratic cost function J ")
85-
ylabels.append("Quadratic cost function Jb ")
86-
ylabels.append("Quadratic cost function JoJc")
87-
ylabels.append("GMRESR norm reduction")
88-
89-
# Get all lines that match the search patterns
90-
matches = []
91-
for line in file:
92-
for search_pattern in search_patterns:
93-
reg = re.compile(search_pattern)
94-
if bool(re.match(reg, line.rstrip())):
95-
matches.append(line.rstrip())
96-
97-
# Close the file
98-
file.close()
99-
100-
# Loop over stats to be searched on
101-
maxiterations = 10000
102-
count = np.zeros(len(search_patterns), dtype=int)
103-
stats = np.zeros((len(search_patterns), maxiterations))
104-
for search_pattern in search_patterns:
105-
106-
index = [i for i, s in enumerate(search_patterns) if search_pattern in s]
107-
108-
# Loop over the matches and fill stats
109-
for match in matches:
110-
111-
reg = re.compile(search_pattern)
112-
if bool(re.match(reg, match)):
113-
114-
stats[index,count[index]] = match.split()[-1]
115-
count[index] = count[index] + 1
116-
117-
niter = count[0]
118-
stat = np.zeros(niter)
119-
120-
121-
# Create figures
122-
# --------------
123-
12446
# Scale for y-axis
12547
try:
12648
yscale = conf['yscale']
@@ -133,28 +55,104 @@ def da_convergence(datetime, conf):
13355
except:
13456
plotformat = 'png'
13557

136-
for ylabel in ylabels:
58+
# Create output path
59+
if not os.path.exists(output_path):
60+
os.makedirs(output_path)
61+
62+
# Replace datetime in logfile name
63+
isodatestr = datetime.strftime("%Y-%m-%dT%H:%M:%S")
64+
log_file = utils.stringReplaceDatetimeTemplate(isodatestr, log_file)
65+
13766

138-
index = [i for i, s in enumerate(ylabels) if ylabel in s]
139-
savename = ylabel.lower().strip()
140-
savename = savename.replace(" ", "-")
141-
savename = savename+"_"+datetime.strftime("%Y%m%d_%H%M%S")+"."+plotformat
142-
savename = os.path.join(output_path,savename)
67+
# Read file and gather norm information
68+
print(" Reading convergence from ", log_file)
14369

144-
stat[0:niter] = stats[index,0:niter]
145-
stat_plot = stat[np.nonzero(stat)]
14670

147-
iter = np.arange(1, len(stat_plot)+1)
71+
# Check the file exists
72+
if not os.path.exists(log_file):
73+
utils.abort('Log file not found.')
14874

149-
fig, ax = plt.subplots(figsize=(15, 7.5))
150-
ax.plot(iter, stat_plot, linestyle='-', marker='x')
151-
ax.tick_params(labelbottom=True, labeltop=False, labelleft=True, labelright=True)
152-
plt.title("JEDI variational assimilation convergence statistics | "+isodatestr)
153-
plt.xlabel("Iteration number")
154-
plt.ylabel(ylabel)
155-
plt.yscale(yscale)
75+
# Convert file to list
76+
lines = []
77+
with open(log_file) as file:
78+
for line in file:
79+
lines.append(line)
80+
81+
# Get unique minimizers used in run, e.g. DRIPCG + GMRES
82+
minimizers = []
83+
for line in lines:
84+
if " end of iteration " in line:
85+
minimizers.append(line.split()[0])
86+
minimizers=set(minimizers) # Unique only
87+
88+
89+
# Loop over minimizers
90+
for minimizer in minimizers:
91+
92+
print('Processing ', minimizer)
93+
94+
grad_red = []
95+
norm_red = []
96+
quad_j = []
97+
quad_jb = []
98+
quad_JoJc = []
99+
for num, line in enumerate(lines, 1):
100+
if minimizer+" end of iteration " in line:
101+
grad_red.append(lines[num].split()[-1])
102+
norm_red.append(lines[num+1].split()[-1])
103+
if (lines[num+3].split()[0] == 'Quadratic'):
104+
quad_j.append(lines[num+3].split()[-1])
105+
if (lines[num+3].split()[0] == 'Quadratic'):
106+
quad_jb.append(lines[num+4].split()[-1])
107+
if (lines[num+3].split()[0] == 'Quadratic'):
108+
quad_JoJc.append(lines[num+5].split()[-1])
109+
110+
111+
# Loop over metrics
112+
for s in range(5):
113+
114+
if (s==0):
115+
stat_str = grad_red
116+
ylabel = 'Gradient reduction'
117+
elif (s==1):
118+
stat_str = norm_red
119+
ylabel = 'Norm reduction'
120+
elif (s==2):
121+
stat_str = quad_j
122+
ylabel = 'Quadratic cost function: J'
123+
elif (s==3):
124+
stat_str = quad_jb
125+
ylabel = 'Quadratic cost function: Jb'
126+
elif (s==4):
127+
stat_str = quad_JoJc
128+
ylabel = 'Quadratic cost function: JoJc'
129+
130+
niter = len(stat_str)
131+
132+
if niter > 1:
133+
134+
stat = np.zeros(len(stat_str))
135+
stat[0:niter] = stat_str
136+
137+
stat_plot = stat[np.nonzero(stat)]
138+
iter = np.arange(1, len(stat_plot)+1)
139+
140+
savename = minimizer.lower()+"-"+ylabel.lower().strip()
141+
savename = savename.replace(":", "")
142+
savename = savename.replace(" ", "-")
143+
savename = savename+"_"+datetime.strftime("%Y%m%d_%H%M%S")+"."+plotformat
144+
savename = os.path.join(output_path,savename)
145+
146+
fig, ax = plt.subplots(figsize=(15, 7.5))
147+
ax.plot(iter, stat_plot, linestyle='-', marker='x')
148+
ax.tick_params(labelbottom=True, labeltop=False, labelleft=True, labelright=True)
149+
plt.title("JEDI variational assimilation convergence statistics | "+isodatestr)
150+
plt.xlabel("Iteration number")
151+
plt.ylabel(ylabel)
152+
plt.yscale(yscale)
153+
plt.xlim([0.9, niter+0.1])
154+
print(" Saving figure as", savename, "\n")
155+
plt.savefig(savename)
156156

157-
print(" Saving figure as", savename, "\n")
158-
plt.savefig(savename)
159157

160158
# --------------------------------------------------------------------------------------------------

src/fv3jeditools/diag_hofx_innovations.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def hofx_innovations(datetime, conf):
5555

5656

5757
# Get variable to plot
58-
variable = utils.configGetOrFail(conf, 'variable')
58+
variable = utils.configGetOrFail(conf, 'field')
5959

6060

6161
# Get number of outer loops used in the assimilation
@@ -117,11 +117,6 @@ def hofx_innovations(datetime, conf):
117117
vmetric = 'innovations'
118118

119119

120-
# Figure filename
121-
# ---------------
122-
savename = os.path.join(output_path, varname+"_"+vmetric+"_"+datetime.strftime("%Y%m%d_%H%M%S")+"."+plotformat)
123-
124-
125120
# Compute window begin time
126121
# -------------------------
127122
window_begin = datetime + time_offset - window_length/2
@@ -155,26 +150,52 @@ def hofx_innovations(datetime, conf):
155150
# Open file for reading
156151
fh = netCDF4.Dataset(hofx_file)
157152

153+
# Check for channels
154+
try:
155+
nchans = fh.dimensions["nchans"].size
156+
except:
157+
nchans = 0
158+
159+
# User must provide channel number
160+
if nchans != 0:
161+
chan = utils.configGetOrFail(conf, 'channel')
162+
158163
# Number of locations in this file
159164
nlocs_final = nlocs_start + fh.dimensions['nlocs'].size
160165

161166
# Background
162-
obs[nlocs_start:nlocs_final] = fh.variables[variable+'@ObsValue'][:]
167+
if nchans != 0:
168+
obs[nlocs_start:nlocs_final] = fh.groups['ObsValue'].variables[variable][:,chan-1]
169+
else:
170+
obs[nlocs_start:nlocs_final] = fh.groups['ObsValue'].variables[variable][:]
171+
163172

164173
# Set missing values to nans
165174
obs[nlocs_start:nlocs_final] = np.where(np.abs(obs[nlocs_start:nlocs_final]) < missing,
166175
obs[nlocs_start:nlocs_final], float("NaN"))
167176

168177
# Loop over outer loops
169178
for n in range(nouter+1):
170-
hofx[nlocs_start:nlocs_final,n] = fh.variables[variable+'@hofx'+str(n)][:] - \
171-
obs[nlocs_start:nlocs_final]
179+
if nchans != 0:
180+
hofx[nlocs_start:nlocs_final,n] = fh.groups['hofx'+str(n)].variables[variable][:,chan-1] - \
181+
obs[nlocs_start:nlocs_final]
182+
else:
183+
hofx[nlocs_start:nlocs_final,n] = fh.groups['hofx'+str(n)].variables[variable][:] - \
184+
obs[nlocs_start:nlocs_final]
172185

173186
# Set start ready for next file
174187
nlocs_start = nlocs_final
175188

176189
fh.close()
177190

191+
# Figure filename
192+
# ---------------
193+
if nchans != 0:
194+
savename = os.path.join(output_path, varname+"-channel"+str(chan)+"_"+vmetric+"_"+datetime.strftime("%Y%m%d_%H%M%S")+"."+plotformat)
195+
else:
196+
savename = os.path.join(output_path, varname+"_"+vmetric+"_"+datetime.strftime("%Y%m%d_%H%M%S")+"."+plotformat)
197+
198+
178199
# Statistics arrays
179200
hist = np.zeros((nbins, nouter+1))
180201
edges = np.zeros((nbins, nouter+1))

0 commit comments

Comments
 (0)