Skip to content

Commit 3171b39

Browse files
committed
Split the tests
1 parent cdb15f3 commit 3171b39

File tree

1 file changed

+31
-43
lines changed

1 file changed

+31
-43
lines changed

pySDC/projects/Resilience/tests/test_fault_injection.py

Lines changed: 31 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -153,29 +153,29 @@ def test_fault_injection():
153153

154154
@pytest.mark.mpi4py
155155
@pytest.mark.slow
156-
@pytest.mark.parametrize("numprocs", [4])
157-
def test_fault_stats(numprocs):
156+
@pytest.mark.parametrize('strategy_name', ['base', 'adaptivity', 'kAdaptivity', 'HotRod'])
157+
def test_fault_stats(strategy_name):
158158
"""
159159
Test generation of fault statistics and their recovery rates
160160
"""
161161
import numpy as np
162+
from pySDC.projects.Resilience.strategies import (
163+
BaseStrategy,
164+
AdaptivityStrategy,
165+
kAdaptivityStrategy,
166+
HotRodStrategy,
167+
)
162168

163-
# Set python path once
164-
my_env = os.environ.copy()
165-
my_env['PYTHONPATH'] = '../../..:.'
166-
my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml'
167-
168-
cmd = f"mpirun -np {numprocs} python {__file__} --test-fault-stats".split()
169+
strategies = {
170+
'base': BaseStrategy,
171+
'adaptivity': AdaptivityStrategy,
172+
'kAdaptivity': kAdaptivityStrategy,
173+
'HotRod': HotRodStrategy,
174+
}
169175

170-
p = subprocess.Popen(cmd, env=my_env, cwd=".")
176+
strategy = strategies[strategy_name]()
171177

172-
p.wait()
173-
assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (
174-
p.returncode,
175-
numprocs,
176-
)
177-
178-
stats = generate_stats(True)
178+
stats = generate_stats(strategy, True)
179179

180180
# test number of possible combinations for faults
181181
expected_max_combinations = 3840
@@ -193,26 +193,25 @@ def test_fault_stats(numprocs):
193193
}
194194
stats.get_recovered()
195195

196-
for strategy in stats.strategies:
197-
dat = stats.load(strategy=strategy, faults=True)
198-
fixable_mask = stats.get_fixable_faults_only(strategy)
199-
recovered_mask = stats.get_mask(strategy=strategy, key='recovered', op='eq', val=True)
200-
index = stats.get_index(mask=fixable_mask)
196+
dat = stats.load(strategy=strategy, faults=True)
197+
fixable_mask = stats.get_fixable_faults_only(strategy)
198+
recovered_mask = stats.get_mask(strategy=strategy, key='recovered', op='eq', val=True)
199+
index = stats.get_index(mask=fixable_mask)
201200

202-
assert all(fixable_mask == [False, True]), "Error in generating mask of fixable faults"
203-
assert all(index == [1]), "Error when converting to index"
201+
assert all(fixable_mask == [False, True]), "Error in generating mask of fixable faults"
202+
assert all(index == [1]), "Error when converting to index"
204203

205-
combinations = np.array(stats.get_combination_counts(dat, keys=['bit'], mask=fixable_mask))
206-
assert all(combinations == [1.0, 1.0]), "Error when counting combinations"
204+
combinations = np.array(stats.get_combination_counts(dat, keys=['bit'], mask=fixable_mask))
205+
assert all(combinations == [1.0, 1.0]), "Error when counting combinations"
207206

208-
recovered = len(dat['recovered'][recovered_mask])
209-
crashed = len(dat['error'][dat['error'] == np.inf]) # on some systems the last run crashes...
210-
assert (
211-
recovered >= recovered_reference[strategy.name] - crashed
212-
), f'Expected {recovered_reference[strategy.name]} recovered faults, but got {recovered} recovered faults in {strategy.name} strategy!'
207+
recovered = len(dat['recovered'][recovered_mask])
208+
crashed = len(dat['error'][dat['error'] == np.inf]) # on some systems the last run crashes...
209+
assert (
210+
recovered >= recovered_reference[strategy.name] - crashed
211+
), f'Expected {recovered_reference[strategy.name]} recovered faults, but got {recovered} recovered faults in {strategy.name} strategy!'
213212

214213

215-
def generate_stats(load=False):
214+
def generate_stats(strategy, load=False):
216215
"""
217216
Generate stats to check the recovery rate
218217
@@ -222,12 +221,6 @@ def generate_stats(load=False):
222221
Returns:
223222
Object containing the stats
224223
"""
225-
from pySDC.projects.Resilience.strategies import (
226-
BaseStrategy,
227-
AdaptivityStrategy,
228-
kAdaptivityStrategy,
229-
HotRodStrategy,
230-
)
231224
from pySDC.projects.Resilience.fault_stats import (
232225
FaultStats,
233226
)
@@ -242,12 +235,7 @@ def generate_stats(load=False):
242235
recovery_thresh=1.1,
243236
num_procs=1,
244237
mode='random',
245-
strategies=[
246-
BaseStrategy(),
247-
AdaptivityStrategy(),
248-
kAdaptivityStrategy(),
249-
HotRodStrategy(),
250-
],
238+
strategies=[strategy],
251239
stats_path='data',
252240
)
253241
stats.run_stats_generation(runs=2, step=1)

0 commit comments

Comments
 (0)