Skip to content

Commit 86a721c

Browse files
committed
AC example up and running
1 parent ce8694d commit 86a721c

File tree

2 files changed

+97
-30
lines changed

2 files changed

+97
-30
lines changed

pySDC/projects/TOMS/AllenCahn_contracting_circle.py

Lines changed: 91 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def setup_parameters():
4444
sweeper_params['Q1'] = ['LU']
4545
sweeper_params['Q2'] = ['LU']
4646
sweeper_params['QI'] = ['LU']
47+
sweeper_params['QE'] = ['EE']
4748
sweeper_params['spread'] = False
4849

4950
# This comes as read-in for the problem class
@@ -101,27 +102,33 @@ def run_SDC_variant(variant=None, inexact=False):
101102
description['problem_class'] = allencahn_fullyimplicit
102103
description['dtype_f'] = mesh
103104
description['sweeper_class'] = generic_implicit
105+
if inexact:
106+
description['problem_params']['newton_maxiter'] = 1
104107
elif variant == 'semi-implicit':
105108
description['problem_class'] = allencahn_semiimplicit
106109
description['dtype_f'] = rhs_imex_mesh
107110
description['sweeper_class'] = imex_1st_order
111+
if inexact:
112+
description['problem_params']['lin_maxiter'] = 10
108113
elif variant == 'multi-implicit':
109114
description['problem_class'] = allencahn_multiimplicit
110115
description['dtype_f'] = rhs_comp2_mesh
111116
description['sweeper_class'] = multi_implicit
117+
if inexact:
118+
description['problem_params']['newton_maxiter'] = 1
119+
description['problem_params']['lin_maxiter'] = 10
112120
else:
113121
raise NotImplemented('Wrong variant specified, got %s' % variant)
114122

115123
if inexact:
116-
description['problem_params']['newton_maxiter'] = 1
117124
out = 'Working on inexact %s variant...' % variant
118125
else:
119126
out = 'Working on exact %s variant...' % variant
120127
print(out)
121128

122129
# setup parameters "in time"
123130
t0 = 0
124-
Tend = 0.001
131+
Tend = 0.032
125132

126133
# instantiate controller
127134
controller = allinclusive_multigrid_nonMPI(num_procs=1, controller_params=controller_params,
@@ -152,18 +159,16 @@ def run_SDC_variant(variant=None, inexact=False):
152159
out = ' Std and var for number of iterations: %4.2f -- %4.2f' % (float(np.std(niters)), float(np.var(niters)))
153160
print(out)
154161

155-
print('Iteration count (nonlinear/linear): %i / %i' % (P.newton_itercount, P.lin_itercount))
156-
print('Mean Iteration count per call: %4.2f / %4.2f' % (P.newton_itercount / max(P.newton_ncalls, 1),
157-
P.lin_itercount / max(P.lin_ncalls, 1)))
162+
print(' Iteration count (nonlinear/linear): %i / %i' % (P.newton_itercount, P.lin_itercount))
163+
print(' Mean Iteration count per call: %4.2f / %4.2f' % (P.newton_itercount / max(P.newton_ncalls, 1),
164+
P.lin_itercount / max(P.lin_ncalls, 1)))
158165

159166
timing = sort_stats(filter_stats(stats, type='timing_run'), sortby='time')
160167

161168
print('Time to solution: %6.4f sec.' % timing[0][1])
162169
print()
163170

164-
# assert np.mean(niters) <= 23, 'ERROR: number of iterations is too high, got %s' % np.mean(niters)
165-
166-
return timing[0][1], np.mean(niters)
171+
return stats
167172

168173

169174
def show_results(fname):
@@ -181,10 +186,15 @@ def show_results(fname):
181186
# plt_helper.mpl.style.use('classic')
182187
plt_helper.setup_mpl()
183188

189+
# set up plot for timings
184190
plt_helper.newfig(textwidth=238.96, scale=1.0)
185191

186-
xcoords = [i for i in range(len(results))]
187-
sorted_data = sorted([(key, results[key][0]) for key in results], reverse=True, key=lambda tup: tup[1])
192+
timings = {}
193+
for key, item in results.items():
194+
timings[key] = sort_stats(filter_stats(item, type='timing_run'), sortby='time')[0][1]
195+
196+
xcoords = [i for i in range(len(timings))]
197+
sorted_data = sorted([(key, timings[key]) for key in timings], reverse=True, key=lambda tup: tup[1])
188198
heights = [item[1] for item in sorted_data]
189199
keys = [(item[0][1] + ' ' + item[0][0]).replace('-', '\n') for item in sorted_data]
190200

@@ -193,33 +203,97 @@ def show_results(fname):
193203
plt_helper.plt.xticks(xcoords, keys, rotation=90)
194204
plt_helper.plt.ylabel('time (sec)')
195205

206+
# # save plot, beautify
207+
f = fname + '_timings'
208+
plt_helper.savefig(f)
209+
210+
assert os.path.isfile(f + '.pdf'), 'ERROR: plotting did not create PDF file'
211+
assert os.path.isfile(f + '.pgf'), 'ERROR: plotting did not create PGF file'
212+
assert os.path.isfile(f + '.png'), 'ERROR: plotting did not create PNG file'
213+
214+
# set up plot for radii
215+
plt_helper.newfig(textwidth=238.96, scale=1.0)
216+
217+
exact_radii = []
218+
for key, item in results.items():
219+
computed_radii = sort_stats(filter_stats(item, type='computed_radius'), sortby='time')
220+
221+
xcoords = [item[0] for item in computed_radii]
222+
radii = [item[1] for item in computed_radii]
223+
plt_helper.plt.plot(xcoords, radii, label=key[0] + ' ' + key[1])
224+
225+
exact_radii = sort_stats(filter_stats(item, type='exact_radius'), sortby='time')
226+
227+
diff = np.array([abs(item0[1] - item1[1]) for item0, item1 in zip(exact_radii, computed_radii)])
228+
max_pos = int(np.argmax(diff))
229+
assert max(diff) < 0.07, 'ERROR: computed radius is too far away from exact radius, got %s' % max(diff)
230+
assert 0.028 < computed_radii[max_pos][0] < 0.03, \
231+
'ERROR: largest difference is at wrong time, got %s' % computed_radii[max_pos][0]
232+
233+
xcoords = [item[0] for item in exact_radii]
234+
radii = [item[1] for item in exact_radii]
235+
plt_helper.plt.plot(xcoords, radii, color='k', linestyle='--', linewidth=1, label='exact')
236+
237+
plt_helper.plt.ylabel('radius')
238+
plt_helper.plt.xlabel('time')
239+
plt_helper.plt.grid()
240+
plt_helper.plt.legend()
241+
242+
# save plot, beautify
243+
f = fname + '_radii'
244+
plt_helper.savefig(f)
245+
246+
assert os.path.isfile(f + '.pdf'), 'ERROR: plotting did not create PDF file'
247+
assert os.path.isfile(f + '.pgf'), 'ERROR: plotting did not create PGF file'
248+
assert os.path.isfile(f + '.png'), 'ERROR: plotting did not create PNG file'
249+
250+
# set up plot for interface width
251+
plt_helper.newfig(textwidth=238.96, scale=1.0)
252+
253+
interface_width = []
254+
for key, item in results.items():
255+
interface_width = sort_stats(filter_stats(item, type='interface_width'), sortby='time')
256+
xcoords = [item[0] for item in interface_width]
257+
width = [item[1] for item in interface_width]
258+
plt_helper.plt.plot(xcoords, width, label=key[0] + ' ' + key[1])
259+
260+
xcoords = [item[0] for item in interface_width]
261+
init_width = [interface_width[0][1]] * len(xcoords)
262+
plt_helper.plt.plot(xcoords, init_width, color='k', linestyle='--', linewidth=1, label='exact')
263+
264+
plt_helper.plt.ylabel('interface width')
265+
plt_helper.plt.xlabel('time')
266+
plt_helper.plt.grid()
267+
plt_helper.plt.legend()
268+
196269
# save plot, beautify
197-
plt_helper.savefig(fname)
270+
f = fname + '_interface'
271+
plt_helper.savefig(f)
198272

199-
assert os.path.isfile(fname + '.pdf'), 'ERROR: plotting did not create PDF file'
200-
assert os.path.isfile(fname + '.pgf'), 'ERROR: plotting did not create PGF file'
201-
assert os.path.isfile(fname + '.png'), 'ERROR: plotting did not create PNG file'
273+
assert os.path.isfile(f + '.pdf'), 'ERROR: plotting did not create PDF file'
274+
assert os.path.isfile(f + '.pgf'), 'ERROR: plotting did not create PGF file'
275+
assert os.path.isfile(f + '.png'), 'ERROR: plotting did not create PNG file'
202276

203277
return None
204278

205279

206-
def main():
280+
def main(cwd=''):
207281
"""
208282
Main driver
209283
210284
Args:
211285
cwd (str): current working directory (need this for testing)
212286
"""
213287

214-
# Loop over variants, exact and inexact solves
288+
# # Loop over variants, exact and inexact solves
215289
results = {}
216290
for variant in ['multi-implicit', 'semi-implicit', 'fully-implicit']:
217291

218292
results[(variant, 'exact')] = run_SDC_variant(variant=variant, inexact=False)
219293
results[(variant, 'inexact')] = run_SDC_variant(variant=variant, inexact=True)
220294

221295
# dump result
222-
fname = 'data/timings_SDC_variants_AllenCahn'
296+
fname = cwd + 'data/results_SDC_variants_AllenCahn_1E-03'
223297
file = open(fname + '.pkl', 'wb')
224298
dill.dump(results, file)
225299
file.close()

pySDC/projects/TOMS/AllenCahn_monitor.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,18 @@ def pre_run(self, step, level_number):
3333
for r in rows:
3434
radius1 = max(radius1, abs(L.prob.xvalues[r]))
3535

36-
rows1 = np.where(L.u[0].values[int((L.prob.init[0])/2), :int((L.prob.init[0])/2)] > -0.99)
37-
rows2 = np.where(L.u[0].values[int((L.prob.init[0])/2), :int((L.prob.init[0])/2)] < 0.99)
36+
rows1 = np.where(L.u[0].values[int((L.prob.init[0]) / 2), :int((L.prob.init[0]) / 2)] > -0.99)
37+
rows2 = np.where(L.u[0].values[int((L.prob.init[0]) / 2), :int((L.prob.init[0]) / 2)] < 0.99)
3838
interface_width = (rows2[0][-1] - rows1[0][0]) * L.prob.dx / L.prob.params.eps
3939

40-
# print(radius, radius1)
4140
self.init_radius = L.prob.params.radius
4241

4342
self.add_to_stats(process=step.status.slot, time=L.time, level=-1, iter=step.status.iter,
4443
sweep=L.status.sweep, type='computed_radius', value=radius)
4544
self.add_to_stats(process=step.status.slot, time=L.time, level=-1, iter=step.status.iter,
4645
sweep=L.status.sweep, type='exact_radius', value=self.init_radius)
4746
self.add_to_stats(process=step.status.slot, time=L.time, level=-1, iter=step.status.iter,
48-
sweep=L.status.sweep, type='interfact_width', value=interface_width)
47+
sweep=L.status.sweep, type='interface_width', value=interface_width)
4948

5049
def post_step(self, step, level_number):
5150
"""
@@ -64,19 +63,13 @@ def post_step(self, step, level_number):
6463
radius = np.sqrt(c / np.pi) * L.prob.dx
6564

6665
exact_radius = np.sqrt(max(self.init_radius ** 2 - 2.0 * (L.time + L.dt), 0))
67-
# print(radius, exact_radius)
6866
rows1 = np.where(L.uend.values[int((L.prob.init[0]) / 2), :int((L.prob.init[0]) / 2)] > -0.99)
6967
rows2 = np.where(L.uend.values[int((L.prob.init[0]) / 2), :int((L.prob.init[0]) / 2)] < 0.99)
7068
interface_width = (rows2[0][-1] - rows1[0][0]) * L.prob.dx / L.prob.params.eps
7169

72-
self.add_to_stats(process=step.status.slot, time=L.time, level=-1, iter=step.status.iter,
70+
self.add_to_stats(process=step.status.slot, time=L.time + L.dt, level=-1, iter=step.status.iter,
7371
sweep=L.status.sweep, type='computed_radius', value=radius)
74-
self.add_to_stats(process=step.status.slot, time=L.time, level=-1, iter=step.status.iter,
72+
self.add_to_stats(process=step.status.slot, time=L.time + L.dt, level=-1, iter=step.status.iter,
7573
sweep=L.status.sweep, type='exact_radius', value=exact_radius)
76-
self.add_to_stats(process=step.status.slot, time=L.time, level=-1, iter=step.status.iter,
74+
self.add_to_stats(process=step.status.slot, time=L.time + L.dt, level=-1, iter=step.status.iter,
7775
sweep=L.status.sweep, type='interface_width', value=interface_width)
78-
79-
# def post_run(self, step, level_number):
80-
# super(monitor, self).post_run(step, level_number)
81-
# plt.show()
82-
# plt.savefig('allen-cahn.png')

0 commit comments

Comments
 (0)