11import sys
2-
32from mpi4py import MPI
4- import time
5- from petsc4py import PETSc
6-
73import numpy as np
84
9- from pySDC .implementations .problem_classes .GrayScott_2D_PETSc_implicit_periodic import petsc_grayscott
10- from pySDC .implementations .datatype_classes .petsc_dmda_grid import petsc_data
5+ from pySDC .implementations .problem_classes .HeatEquation_2D_PETSc_forced import heat2d_petsc_forced
6+ from pySDC .implementations .datatype_classes .petsc_dmda_grid import petsc_data , rhs_imex_petsc_data
117from pySDC .implementations .collocation_classes .gauss_radau_right import CollGaussRadau_Right
12- from pySDC .implementations .sweeper_classes .generic_implicit import generic_implicit
8+ from pySDC .implementations .sweeper_classes .imex_1st_order import imex_1st_order
139from pySDC .implementations .transfer_classes .TransferPETScDMDA import mesh_to_mesh_petsc_dmda
1410from pySDC .implementations .controller_classes .allinclusive_multigrid_MPI import allinclusive_multigrid_MPI
15- from pySDC .implementations .controller_classes .allinclusive_multigrid_nonMPI import allinclusive_multigrid_nonMPI
1611
1712from pySDC .helpers .stats_helper import filter_stats , sort_stats
1813
19- def main ():
2014
15+ def main ():
16+ """
17+ Program to demonstrate usage of PETSc data structures and spatial parallelization,
18+ combined with parallelization in time.
19+ """
2120 # set MPI communicator
2221 comm = MPI .COMM_WORLD
2322
2423 world_rank = comm .Get_rank ()
2524 world_size = comm .Get_size ()
2625
26+ # split world communicator to create space-communicators
2727 if len (sys .argv ) == 2 :
2828 color = int (world_rank / int (sys .argv [1 ]))
2929 else :
3030 color = int (world_rank / 1 )
31-
3231 space_comm = comm .Split (color = color )
33- space_rank = space_comm .Get_rank ()
34- space_size = space_comm .Get_size ()
3532
33+ # split world communicator to create time-communicators
3634 if len (sys .argv ) == 2 :
3735 color = int (world_rank % int (sys .argv [1 ]))
3836 else :
3937 color = int (world_rank / world_size )
40-
4138 time_comm = comm .Split (color = color )
42- time_rank = time_comm .Get_rank ()
43- time_size = time_comm .Get_size ()
44-
45- print ("IDs (world, space, time): %i / %i -- %i / %i -- %i / %i" % (world_rank , world_size , space_rank , space_size ,
46- time_rank , time_size ))
4739
4840 # initialize level parameters
4941 level_params = dict ()
5042 level_params ['restol' ] = 1E-08
51- level_params ['dt' ] = 1.0
43+ level_params ['dt' ] = 0.125
5244 level_params ['nsweeps' ] = [1 ]
5345
5446 # initialize sweeper parameters
@@ -60,51 +52,47 @@ def main():
6052
6153 # initialize problem parameters
6254 problem_params = dict ()
63- problem_params ['Du' ] = 1.0
64- problem_params ['Dv' ] = 0.01
65- problem_params ['A' ] = 0.09
66- problem_params ['B' ] = 0.086
67- problem_params ['nvars' ] = [(127 , 127 )] # number of degrees of freedom for each level
68- problem_params ['comm' ] = space_comm
69- problem_params ['sol_tol' ] = 1E-10
70- problem_params ['sol_maxiter' ] = 100
55+ problem_params ['nu' ] = 1.0 # diffusion coefficient
56+ problem_params ['freq' ] = 2 # frequency for the test value
57+ problem_params ['nvars' ] = [(129 , 129 ), (65 , 65 )] # number of degrees of freedom for each level
58+ problem_params ['comm' ] = space_comm # pass space-communicator to problem class
59+ problem_params ['sol_tol' ] = 1E-12 # set tolerance to PETSc' linear solver
7160
7261 # initialize step parameters
7362 step_params = dict ()
7463 step_params ['maxiter' ] = 50
7564
7665 # initialize space transfer parameters
77- # space_transfer_params = dict()
78- # space_transfer_params['rorder'] = 2
79- # space_transfer_params['iorder'] = 2
80- # space_transfer_params['periodic'] = True
66+ space_transfer_params = dict ()
67+ space_transfer_params ['rorder' ] = 2
68+ space_transfer_params ['iorder' ] = 2
69+ space_transfer_params ['periodic' ] = False
8170
8271 # initialize controller parameters
8372 controller_params = dict ()
8473 controller_params ['logger_level' ] = 20
8574 # controller_params['predict'] = False
86- # controller_params['hook_class'] = error_output
8775
8876 # fill description dictionary for easy step instantiation
8977 description = dict ()
90- description ['problem_class' ] = petsc_grayscott # pass problem class
78+ description ['problem_class' ] = heat2d_petsc_forced # pass problem class
9179 description ['problem_params' ] = problem_params # pass problem parameters
92- description ['dtype_u' ] = petsc_data # pass data type for u
93- description ['dtype_f' ] = petsc_data # pass data type for f
94- description ['sweeper_class' ] = generic_implicit # pass sweeper (see part B)
80+ description ['dtype_u' ] = petsc_data # pass PETSc data type for u
81+ description ['dtype_f' ] = rhs_imex_petsc_data # pass PETSc data type for f
82+ description ['sweeper_class' ] = imex_1st_order # pass sweeper (see part B)
9583 description ['sweeper_params' ] = sweeper_params # pass sweeper parameters
9684 description ['level_params' ] = level_params # pass level parameters
9785 description ['step_params' ] = step_params # pass step parameters
9886 description ['space_transfer_class' ] = mesh_to_mesh_petsc_dmda # pass spatial transfer class
99- # description['space_transfer_params'] = space_transfer_params # pass paramters for spatial transfer
87+ description ['space_transfer_params' ] = space_transfer_params # pass paramters for spatial transfer
10088
10189 # set time parameters
10290 t0 = 0.0
103- Tend = 1.0
91+ Tend = 0.25
10492
10593 # instantiate controller
106- controller = allinclusive_multigrid_MPI (controller_params = controller_params , description = description , comm = time_comm )
107- # controller = allinclusive_multigrid_nonMPI(num_procs=2, controller_params=controller_params, description=description )
94+ controller = allinclusive_multigrid_MPI (controller_params = controller_params , description = description ,
95+ comm = time_comm )
10896
10997 # get initial values on finest level
11098 P = controller .S .levels [0 ].prob
@@ -117,8 +105,6 @@ def main():
117105 uex = P .u_exact (Tend )
118106 err = abs (uex - uend )
119107
120- print (err )
121-
122108 # filter statistics by type (number of iterations)
123109 filtered_stats = filter_stats (stats , type = 'niter' )
124110
@@ -143,7 +129,12 @@ def main():
143129
144130 timing = sort_stats (filter_stats (stats , type = 'timing_run' ), sortby = 'time' )
145131
146- print (timing )
132+ print ('Time to solution: %6.4f sec.' % timing [0 ][1 ])
133+ print ('Error vs. PDE solution: %6.4e' % err )
134+ print ()
135+
136+ assert err < 2E-04 , 'ERROR: did not match error tolerance, got %s' % err
137+ assert np .mean (niters ) <= 12 , 'ERROR: number of iterations is too high, got %s' % np .mean (niters )
147138
148139
149140if __name__ == "__main__" :
0 commit comments