@@ -38,6 +38,10 @@ def create_directories():
3838def test_run (tmpdir ):
3939 from pySDC .projects .GPU .configs .base_config import get_config
4040 from pySDC .projects .GPU .run_experiment import run_experiment
41+ from pySDC .helpers .stats_helper import get_sorted
42+ from pySDC .helpers .fieldsIO import FieldsIO
43+ import pickle
44+ import numpy as np
4145
4246 args = {
4347 'config' : 'RBC' ,
@@ -50,6 +54,7 @@ def test_run(tmpdir):
5054 'restart_idx' : 0 ,
5155 }
5256 config = get_config (args )
57+ type(config ).base_path = args ['o' ]
5358
5459 def get_LogToFile (self , * args , ** kwargs ):
5560 if self .comms [1 ].rank > 0 :
@@ -63,89 +68,39 @@ def get_LogToFile(self, *args, **kwargs):
6368 return LogToFile
6469
6570 type(config ).get_LogToFile = get_LogToFile
71+ stats_path = f'{ config .base_path } /data/{ config .get_path ()} -stats-whole-run.pickle'
72+ file_path = config .get_file_name ()
6673
6774 # first run for a short time
6875 dt = config .get_description ()['level_params' ]['dt' ]
6976 config .Tend = 2 * dt
7077 run_experiment (args , config )
7178
79+ # check data
80+ data = FieldsIO .fromFile (file_path )
81+ with open (stats_path , 'rb' ) as file :
82+ stats = pickle .load (file )
7283
73- @pytest .mark .order (1 )
74- def test_run_experiment (restart_idx = 0 ):
75- from pySDC .projects .GPU .configs .base_config import Config
76- from pySDC .projects .GPU .run_experiment import run_experiment , parse_args
77- from pySDC .helpers .stats_helper import get_sorted
78- import pickle
79- import numpy as np
80-
81- create_directories ()
82-
83- class VdPConfig (Config ):
84- sweeper_type = 'generic_implicit'
85- Tend = 1
86-
87- def get_description (self , * args , ** kwargs ):
88- from pySDC .implementations .problem_classes .Van_der_Pol_implicit import vanderpol
89-
90- desc = super ().get_description (* args , ** kwargs )
91- desc ['problem_class' ] = vanderpol
92- desc ['problem_params' ].pop ('useGPU' )
93- desc ['problem_params' ].pop ('comm' )
94- desc ['sweeper_params' ]['num_nodes' ] = 2
95- desc ['sweeper_params' ]['quad_type' ] = 'RADAU-RIGHT'
96- desc ['sweeper_params' ]['QI' ] = 'LU'
97- desc ['level_params' ]['dt' ] = 0.1
98- desc ['step_params' ]['maxiter' ] = 3
99- return desc
100-
101- def get_LogToFile (self , ranks = None ):
102- from pySDC .implementations .hooks .log_solution import LogToPickleFileAfterXS as LogToFile
103-
104- LogToFile .path = './data/'
105- LogToFile .file_name = f'{ self .get_path (ranks = ranks )} -solution'
106- LogToFile .time_increment = 2e-1
107-
108- def logging_condition (L ):
109- sweep = L .sweep
110- if hasattr (sweep , 'comm' ):
111- if sweep .comm .rank == sweep .comm .size - 1 :
112- return True
113- else :
114- return False
115- else :
116- return True
117-
118- LogToFile .logging_condition = logging_condition
119- return LogToFile
84+ dts = get_sorted (stats , type = 'dt' )
85+ assert len (dts ) == len (data .times ) - 1
86+ assert np .allclose (data .times , 0.1 * np .arange (3 )), 'Did not record solutions at expected times'
12087
121- args = {
122- 'procs' : [1 , 1 , 1 ],
123- 'useGPU' : False ,
124- 'res' : - 1 ,
125- 'logger_level' : 15 ,
126- 'restart_idx' : restart_idx ,
127- 'mode' : 'run' ,
128- }
129- config = VdPConfig (args )
88+ # restart run
89+ args ['restart_idx' ] = - 1
90+ config .Tend = 4 * dt
13091 run_experiment (args , config )
13192
132- with open (f'data/{ config .get_path ()} -stats-whole-run.pickle' , 'rb' ) as file :
93+ # check data
94+ data = FieldsIO .fromFile (file_path )
95+ with open (stats_path , 'rb' ) as file :
13396 stats = pickle .load (file )
97+ dts = get_sorted (stats , type = 'dt' )
13498
135- k_Newton = get_sorted (stats , type = 'work_newton' )
136- assert len (k_Newton ) == 10
137- assert sum ([me [1 ] for me in k_Newton ]) == 91
138-
139-
140- @pytest .mark .order (2 )
141- def test_restart ():
142- test_run_experiment (3 )
99+ assert len (dts ) == len (data .times ) - 1
100+ assert np .allclose (data .times , 0.1 * np .arange (5 )), 'Did not record solutions at expected times after restart'
143101
144102
145103if __name__ == '__main__' :
146- test_run ('.' )
147- exit ()
148-
149104 import argparse
150105
151106 parser = argparse .ArgumentParser ()
@@ -155,9 +110,5 @@ def test_restart():
155110
156111 if args .test == 'get_comms' :
157112 test_get_comms (False )
158- elif args .test == 'run_experiment' :
159- test_run_experiment ()
160- elif args .test == 'restart' :
161- test_restart ()
162113 else :
163114 raise NotImplementedError
0 commit comments