11import pySDC .helpers .plot_helper as plt_helper
2+ import matplotlib .ticker as ticker
23
34import dill
45import os
@@ -51,7 +52,7 @@ def setup_parameters():
5152 problem_params = dict ()
5253 problem_params ['nu' ] = 2
5354 problem_params ['nvars' ] = [(128 , 128 )]
54- problem_params ['eps' ] = [0.0390625 ]
55+ problem_params ['eps' ] = [0.04 ]
5556 problem_params ['newton_maxiter' ] = 100
5657 problem_params ['newton_tol' ] = 1E-09
5758 problem_params ['lin_tol' ] = 1E-10
@@ -110,13 +111,25 @@ def run_SDC_variant(variant=None, inexact=False):
110111 description ['sweeper_class' ] = imex_1st_order
111112 if inexact :
112113 description ['problem_params' ]['lin_maxiter' ] = 10
114+ elif variant == 'semi-implicit_v2' :
115+ description ['problem_class' ] = allencahn_semiimplicit_v2
116+ description ['dtype_f' ] = rhs_imex_mesh
117+ description ['sweeper_class' ] = imex_1st_order
118+ if inexact :
119+ description ['problem_params' ]['newton_maxiter' ] = 1
113120 elif variant == 'multi-implicit' :
114121 description ['problem_class' ] = allencahn_multiimplicit
115122 description ['dtype_f' ] = rhs_comp2_mesh
116123 description ['sweeper_class' ] = multi_implicit
117124 if inexact :
118125 description ['problem_params' ]['newton_maxiter' ] = 1
119126 description ['problem_params' ]['lin_maxiter' ] = 10
127+ elif variant == 'multi-implicit_v2' :
128+ description ['problem_class' ] = allencahn_multiimplicit_v2
129+ description ['dtype_f' ] = rhs_comp2_mesh
130+ description ['sweeper_class' ] = multi_implicit
131+ if inexact :
132+ description ['problem_params' ]['newton_maxiter' ] = 1
120133 else :
121134 raise NotImplemented ('Wrong variant specified, got %s' % variant )
122135
@@ -188,23 +201,38 @@ def show_results(fname, cwd=''):
188201 plt_helper .setup_mpl ()
189202
190203 # set up plot for timings
191- plt_helper .newfig (textwidth = 238.96 , scale = 1.0 )
204+ fig , ax1 = plt_helper .newfig (textwidth = 238.96 , scale = 1.5 , ratio = 0.4 )
192205
193206 timings = {}
207+ niters = {}
194208 for key , item in results .items ():
195209 timings [key ] = sort_stats (filter_stats (item , type = 'timing_run' ), sortby = 'time' )[0 ][1 ]
210+ iter_counts = sort_stats (filter_stats (item , type = 'niter' ), sortby = 'time' )
211+ niters [key ] = np .mean (np .array ([item [1 ] for item in iter_counts ]))
196212
197213 xcoords = [i for i in range (len (timings ))]
198- sorted_data = sorted ([(key , timings [key ]) for key in timings ], reverse = True , key = lambda tup : tup [1 ])
199- heights = [item [1 ] for item in sorted_data ]
200- keys = [(item [0 ][1 ] + ' ' + item [0 ][0 ]).replace ('-' , '\n ' ) for item in sorted_data ]
214+ sorted_timings = sorted ([(key , timings [key ]) for key in timings ], reverse = True , key = lambda tup : tup [1 ])
215+ sorted_niters = [(k , niters [k ]) for k in [key [0 ] for key in sorted_timings ]]
216+ heights_timings = [item [1 ] for item in sorted_timings ]
217+ heights_niters = [item [1 ] for item in sorted_niters ]
218+ keys = [(item [0 ][1 ] + ' ' + item [0 ][0 ]).replace ('-' , '\n ' ).replace ('_v2' , ' mod.' ) for item in sorted_timings ]
219+
220+ ax1 .bar (xcoords , heights_timings , align = 'edge' , width = - 0.3 , label = 'timings (left axis)' )
221+ ax1 .set_ylabel ('time (sec)' )
201222
202- plt_helper .plt .bar (xcoords , heights , align = 'center' )
223+ ax2 = ax1 .twinx ()
224+ ax2 .bar (xcoords , heights_niters , color = 'r' , align = 'edge' , width = 0.3 , label = 'iterations (right axis)' )
225+ ax2 .set_ylabel ('mean number of iterations' )
203226
204- plt_helper . plt . xticks (xcoords , keys , rotation = 90 )
205- plt_helper . plt . ylabel ( 'time (sec) ' )
227+ ax1 . set_xticks (xcoords )
228+ ax1 . set_xticklabels ( keys , rotation = 90 , ha = 'center ' )
206229
207- # # save plot, beautify
230+ # ask matplotlib for the plotted objects and their labels
231+ lines , labels = ax1 .get_legend_handles_labels ()
232+ lines2 , labels2 = ax2 .get_legend_handles_labels ()
233+ ax2 .legend (lines + lines2 , labels + labels2 , loc = 0 )
234+
235+ # save plot, beautify
208236 f = fname + '_timings'
209237 plt_helper .savefig (f )
210238
@@ -213,15 +241,16 @@ def show_results(fname, cwd=''):
213241 assert os .path .isfile (f + '.png' ), 'ERROR: plotting did not create PNG file'
214242
215243 # set up plot for radii
216- plt_helper .newfig (textwidth = 238.96 , scale = 1.0 )
244+ fig , ax = plt_helper .newfig (textwidth = 238.96 , scale = 1.0 )
217245
218246 exact_radii = []
219247 for key , item in results .items ():
220248 computed_radii = sort_stats (filter_stats (item , type = 'computed_radius' ), sortby = 'time' )
221249
222250 xcoords = [item0 [0 ] for item0 in computed_radii ]
223251 radii = [item0 [1 ] for item0 in computed_radii ]
224- plt_helper .plt .plot (xcoords , radii , label = key [0 ] + ' ' + key [1 ])
252+ if key [0 ] + ' ' + key [1 ] == 'fully-implicit exact' :
253+ ax .plot (xcoords , radii , label = (key [0 ] + ' ' + key [1 ]).replace ('_v2' , ' mod.' ))
225254
226255 exact_radii = sort_stats (filter_stats (item , type = 'exact_radius' ), sortby = 'time' )
227256
@@ -233,12 +262,13 @@ def show_results(fname, cwd=''):
233262
234263 xcoords = [item [0 ] for item in exact_radii ]
235264 radii = [item [1 ] for item in exact_radii ]
236- plt_helper . plt .plot (xcoords , radii , color = 'k' , linestyle = '--' , linewidth = 1 , label = 'exact' )
265+ ax .plot (xcoords , radii , color = 'k' , linestyle = '--' , linewidth = 1 , label = 'exact' )
237266
238- plt_helper .plt .ylabel ('radius' )
239- plt_helper .plt .xlabel ('time' )
240- plt_helper .plt .grid ()
241- plt_helper .plt .legend ()
267+ ax .yaxis .set_major_formatter (ticker .FormatStrFormatter ('%1.2f' ))
268+ ax .set_ylabel ('radius' )
269+ ax .set_xlabel ('time' )
270+ ax .grid ()
271+ ax .legend (loc = 3 )
242272
243273 # save plot, beautify
244274 f = fname + '_radii'
@@ -249,23 +279,25 @@ def show_results(fname, cwd=''):
249279 assert os .path .isfile (f + '.png' ), 'ERROR: plotting did not create PNG file'
250280
251281 # set up plot for interface width
252- plt_helper .newfig (textwidth = 238.96 , scale = 1.0 )
282+ fig , ax = plt_helper .newfig (textwidth = 238.96 , scale = 1.0 )
253283
254284 interface_width = []
255285 for key , item in results .items ():
256286 interface_width = sort_stats (filter_stats (item , type = 'interface_width' ), sortby = 'time' )
257287 xcoords = [item [0 ] for item in interface_width ]
258288 width = [item [1 ] for item in interface_width ]
259- plt_helper .plt .plot (xcoords , width , label = key [0 ] + ' ' + key [1 ])
289+ if key [0 ] + ' ' + key [1 ] == 'fully-implicit exact' :
290+ ax .plot (xcoords , width , label = key [0 ] + ' ' + key [1 ])
260291
261292 xcoords = [item [0 ] for item in interface_width ]
262293 init_width = [interface_width [0 ][1 ]] * len (xcoords )
263- plt_helper . plt .plot (xcoords , init_width , color = 'k' , linestyle = '--' , linewidth = 1 , label = 'exact' )
294+ ax .plot (xcoords , init_width , color = 'k' , linestyle = '--' , linewidth = 1 , label = 'exact' )
264295
265- plt_helper .plt .ylabel ('interface width' )
266- plt_helper .plt .xlabel ('time' )
267- plt_helper .plt .grid ()
268- plt_helper .plt .legend ()
296+ ax .yaxis .set_major_formatter (ticker .FormatStrFormatter ('%1.2f' ))
297+ ax .set_ylabel ('interface' )
298+ ax .set_xlabel ('time' )
299+ ax .grid ()
300+ ax .legend (loc = 3 )
269301
270302 # save plot, beautify
271303 f = fname + '_interface'
@@ -286,19 +318,19 @@ def main(cwd=''):
286318 cwd (str): current working directory (need this for testing)
287319 """
288320
289- # Loop over variants, exact and inexact solves
290- results = {}
291- for variant in ['multi-implicit' , 'semi-implicit' , 'fully-implicit' ]:
292-
293- results [(variant , 'exact' )] = run_SDC_variant (variant = variant , inexact = False )
294- results [(variant , 'inexact' )] = run_SDC_variant (variant = variant , inexact = True )
295-
296- # dump result
321+ # # Loop over variants, exact and inexact solves
322+ # results = {}
323+ # for variant in ['multi-implicit', 'semi-implicit', 'fully-implicit', 'semi-implicit_v2', 'multi-implicit_v2 ']:
324+ #
325+ # results[(variant, 'exact')] = run_SDC_variant(variant=variant, inexact=False)
326+ # results[(variant, 'inexact')] = run_SDC_variant(variant=variant, inexact=True)
327+ #
328+ # # dump result
297329 fname = 'data/results_SDC_variants_AllenCahn_1E-03'
298- file = open (cwd + fname + '.pkl' , 'wb' )
299- dill .dump (results , file )
300- file .close ()
301- assert os .path .isfile (cwd + fname + '.pkl' ), 'ERROR: dill did not create file'
330+ # file = open(cwd + fname + '.pkl', 'wb')
331+ # dill.dump(results, file)
332+ # file.close()
333+ # assert os.path.isfile(cwd + fname + '.pkl'), 'ERROR: dill did not create file'
302334
303335 # visualize
304336 show_results (fname , cwd = cwd )
0 commit comments