Skip to content

Commit 7cd3928

Browse files
committed
plot helper modifications
1 parent 27f12d2 commit 7cd3928

File tree

2 files changed

+71
-41
lines changed

2 files changed

+71
-41
lines changed

pySDC/helpers/plot_helper.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
import numpy as np
66

77

8-
def figsize(textwidth, scale):
8+
def figsize(textwidth, scale, ratio):
99
fig_width_pt = textwidth # Get this from LaTeX using \the\textwidth
1010
inches_per_pt = 1.0 / 72.27 # Convert pt to inch
11-
golden_mean = (np.sqrt(5.0) - 1.0) / 2.0 # Aesthetic ratio (you could change this)
1211
fig_width = fig_width_pt * inches_per_pt * scale # width in inches
13-
fig_height = fig_width * golden_mean # height in inches
12+
fig_height = fig_width * ratio # height in inches
1413
fig_size = [fig_width, fig_height]
1514
return fig_size
1615

@@ -51,10 +50,9 @@ def setup_mpl(font_size=8):
5150
mpl.rcParams.update(pgf_with_latex)
5251

5352

54-
def newfig(textwidth, scale):
53+
def newfig(textwidth, scale, ratio=0.6180339887):
5554
plt.clf()
56-
fig = plt.figure(figsize=figsize(textwidth, scale))
57-
ax = fig.add_subplot()
55+
fig, ax = plt.subplots(figsize=figsize(textwidth, scale, ratio))
5856
return fig, ax
5957

6058

pySDC/projects/TOMS/AllenCahn_contracting_circle.py

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pySDC.helpers.plot_helper as plt_helper
2+
import matplotlib.ticker as ticker
23

34
import dill
45
import os
@@ -51,7 +52,7 @@ def setup_parameters():
5152
problem_params = dict()
5253
problem_params['nu'] = 2
5354
problem_params['nvars'] = [(128, 128)]
54-
problem_params['eps'] = [0.0390625]
55+
problem_params['eps'] = [0.04]
5556
problem_params['newton_maxiter'] = 100
5657
problem_params['newton_tol'] = 1E-09
5758
problem_params['lin_tol'] = 1E-10
@@ -110,13 +111,25 @@ def run_SDC_variant(variant=None, inexact=False):
110111
description['sweeper_class'] = imex_1st_order
111112
if inexact:
112113
description['problem_params']['lin_maxiter'] = 10
114+
elif variant == 'semi-implicit_v2':
115+
description['problem_class'] = allencahn_semiimplicit_v2
116+
description['dtype_f'] = rhs_imex_mesh
117+
description['sweeper_class'] = imex_1st_order
118+
if inexact:
119+
description['problem_params']['newton_maxiter'] = 1
113120
elif variant == 'multi-implicit':
114121
description['problem_class'] = allencahn_multiimplicit
115122
description['dtype_f'] = rhs_comp2_mesh
116123
description['sweeper_class'] = multi_implicit
117124
if inexact:
118125
description['problem_params']['newton_maxiter'] = 1
119126
description['problem_params']['lin_maxiter'] = 10
127+
elif variant == 'multi-implicit_v2':
128+
description['problem_class'] = allencahn_multiimplicit_v2
129+
description['dtype_f'] = rhs_comp2_mesh
130+
description['sweeper_class'] = multi_implicit
131+
if inexact:
132+
description['problem_params']['newton_maxiter'] = 1
120133
else:
121134
raise NotImplemented('Wrong variant specified, got %s' % variant)
122135

@@ -188,23 +201,38 @@ def show_results(fname, cwd=''):
188201
plt_helper.setup_mpl()
189202

190203
# set up plot for timings
191-
plt_helper.newfig(textwidth=238.96, scale=1.0)
204+
fig, ax1 = plt_helper.newfig(textwidth=238.96, scale=1.5, ratio=0.4)
192205

193206
timings = {}
207+
niters = {}
194208
for key, item in results.items():
195209
timings[key] = sort_stats(filter_stats(item, type='timing_run'), sortby='time')[0][1]
210+
iter_counts = sort_stats(filter_stats(item, type='niter'), sortby='time')
211+
niters[key] = np.mean(np.array([item[1] for item in iter_counts]))
196212

197213
xcoords = [i for i in range(len(timings))]
198-
sorted_data = sorted([(key, timings[key]) for key in timings], reverse=True, key=lambda tup: tup[1])
199-
heights = [item[1] for item in sorted_data]
200-
keys = [(item[0][1] + ' ' + item[0][0]).replace('-', '\n') for item in sorted_data]
214+
sorted_timings = sorted([(key, timings[key]) for key in timings], reverse=True, key=lambda tup: tup[1])
215+
sorted_niters = [(k, niters[k]) for k in [key[0] for key in sorted_timings]]
216+
heights_timings = [item[1] for item in sorted_timings]
217+
heights_niters = [item[1] for item in sorted_niters]
218+
keys = [(item[0][1] + ' ' + item[0][0]).replace('-', '\n').replace('_v2', ' mod.') for item in sorted_timings]
219+
220+
ax1.bar(xcoords, heights_timings, align='edge', width=-0.3, label='timings (left axis)')
221+
ax1.set_ylabel('time (sec)')
201222

202-
plt_helper.plt.bar(xcoords, heights, align='center')
223+
ax2 = ax1.twinx()
224+
ax2.bar(xcoords, heights_niters, color='r', align='edge', width=0.3, label='iterations (right axis)')
225+
ax2.set_ylabel('mean number of iterations')
203226

204-
plt_helper.plt.xticks(xcoords, keys, rotation=90)
205-
plt_helper.plt.ylabel('time (sec)')
227+
ax1.set_xticks(xcoords)
228+
ax1.set_xticklabels(keys, rotation=90, ha='center')
206229

207-
# # save plot, beautify
230+
# ask matplotlib for the plotted objects and their labels
231+
lines, labels = ax1.get_legend_handles_labels()
232+
lines2, labels2 = ax2.get_legend_handles_labels()
233+
ax2.legend(lines + lines2, labels + labels2, loc=0)
234+
235+
# save plot, beautify
208236
f = fname + '_timings'
209237
plt_helper.savefig(f)
210238

@@ -213,15 +241,16 @@ def show_results(fname, cwd=''):
213241
assert os.path.isfile(f + '.png'), 'ERROR: plotting did not create PNG file'
214242

215243
# set up plot for radii
216-
plt_helper.newfig(textwidth=238.96, scale=1.0)
244+
fig, ax = plt_helper.newfig(textwidth=238.96, scale=1.0)
217245

218246
exact_radii = []
219247
for key, item in results.items():
220248
computed_radii = sort_stats(filter_stats(item, type='computed_radius'), sortby='time')
221249

222250
xcoords = [item0[0] for item0 in computed_radii]
223251
radii = [item0[1] for item0 in computed_radii]
224-
plt_helper.plt.plot(xcoords, radii, label=key[0] + ' ' + key[1])
252+
if key[0] + ' ' + key[1] == 'fully-implicit exact':
253+
ax.plot(xcoords, radii, label=(key[0] + ' ' + key[1]).replace('_v2', ' mod.'))
225254

226255
exact_radii = sort_stats(filter_stats(item, type='exact_radius'), sortby='time')
227256

@@ -233,12 +262,13 @@ def show_results(fname, cwd=''):
233262

234263
xcoords = [item[0] for item in exact_radii]
235264
radii = [item[1] for item in exact_radii]
236-
plt_helper.plt.plot(xcoords, radii, color='k', linestyle='--', linewidth=1, label='exact')
265+
ax.plot(xcoords, radii, color='k', linestyle='--', linewidth=1, label='exact')
237266

238-
plt_helper.plt.ylabel('radius')
239-
plt_helper.plt.xlabel('time')
240-
plt_helper.plt.grid()
241-
plt_helper.plt.legend()
267+
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.2f'))
268+
ax.set_ylabel('radius')
269+
ax.set_xlabel('time')
270+
ax.grid()
271+
ax.legend(loc=3)
242272

243273
# save plot, beautify
244274
f = fname + '_radii'
@@ -249,23 +279,25 @@ def show_results(fname, cwd=''):
249279
assert os.path.isfile(f + '.png'), 'ERROR: plotting did not create PNG file'
250280

251281
# set up plot for interface width
252-
plt_helper.newfig(textwidth=238.96, scale=1.0)
282+
fig, ax = plt_helper.newfig(textwidth=238.96, scale=1.0)
253283

254284
interface_width = []
255285
for key, item in results.items():
256286
interface_width = sort_stats(filter_stats(item, type='interface_width'), sortby='time')
257287
xcoords = [item[0] for item in interface_width]
258288
width = [item[1] for item in interface_width]
259-
plt_helper.plt.plot(xcoords, width, label=key[0] + ' ' + key[1])
289+
if key[0] + ' ' + key[1] == 'fully-implicit exact':
290+
ax.plot(xcoords, width, label=key[0] + ' ' + key[1])
260291

261292
xcoords = [item[0] for item in interface_width]
262293
init_width = [interface_width[0][1]] * len(xcoords)
263-
plt_helper.plt.plot(xcoords, init_width, color='k', linestyle='--', linewidth=1, label='exact')
294+
ax.plot(xcoords, init_width, color='k', linestyle='--', linewidth=1, label='exact')
264295

265-
plt_helper.plt.ylabel('interface width')
266-
plt_helper.plt.xlabel('time')
267-
plt_helper.plt.grid()
268-
plt_helper.plt.legend()
296+
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.2f'))
297+
ax.set_ylabel('interface')
298+
ax.set_xlabel('time')
299+
ax.grid()
300+
ax.legend(loc=3)
269301

270302
# save plot, beautify
271303
f = fname + '_interface'
@@ -286,19 +318,19 @@ def main(cwd=''):
286318
cwd (str): current working directory (need this for testing)
287319
"""
288320

289-
# Loop over variants, exact and inexact solves
290-
results = {}
291-
for variant in ['multi-implicit', 'semi-implicit', 'fully-implicit']:
292-
293-
results[(variant, 'exact')] = run_SDC_variant(variant=variant, inexact=False)
294-
results[(variant, 'inexact')] = run_SDC_variant(variant=variant, inexact=True)
295-
296-
# dump result
321+
# # Loop over variants, exact and inexact solves
322+
# results = {}
323+
# for variant in ['multi-implicit', 'semi-implicit', 'fully-implicit', 'semi-implicit_v2', 'multi-implicit_v2']:
324+
#
325+
# results[(variant, 'exact')] = run_SDC_variant(variant=variant, inexact=False)
326+
# results[(variant, 'inexact')] = run_SDC_variant(variant=variant, inexact=True)
327+
#
328+
# # dump result
297329
fname = 'data/results_SDC_variants_AllenCahn_1E-03'
298-
file = open(cwd + fname + '.pkl', 'wb')
299-
dill.dump(results, file)
300-
file.close()
301-
assert os.path.isfile(cwd + fname + '.pkl'), 'ERROR: dill did not create file'
330+
# file = open(cwd + fname + '.pkl', 'wb')
331+
# dill.dump(results, file)
332+
# file.close()
333+
# assert os.path.isfile(cwd + fname + '.pkl'), 'ERROR: dill did not create file'
302334

303335
# visualize
304336
show_results(fname, cwd=cwd)

0 commit comments

Comments
 (0)