Skip to content

Commit 73bee65

Browse files
committed
Reframe residual test
1 parent 6e9d8cf commit 73bee65

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

tests/unit/test_groupfit.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,29 +125,32 @@ def run_grouper_fit_test(irrad_type: str, grouper: Grouper,
125125
if irrad_type == 'saturation':
126126
assert np.isclose(func_counts[0], initial_count_rate, rtol=1e-4), "Initial count rate mismatch"
127127

128-
assert np.allclose(func_counts, counts, atol=1e-2, rtol=1e-2), f'{irrad_type.capitalize()} counts mismatch between hand calculation and function evaluation'
129128

130129
count_data = {
131130
'times': times,
132131
'counts': counts,
133132
'sigma counts': counts*1e-12
134133
}
135134

135+
assert np.allclose(func_counts, counts, atol=1e-2, rtol=1e-2), f'{irrad_type.capitalize()} counts mismatch between hand calculation and function evaluation'
136136
assert grouper.irrad_type == irrad_type
137137
data = grouper._nonlinear_least_squares(count_data=count_data, set_refined_fiss=False)
138138
test_yields = [data[key]['yield'] for key in range(grouper.num_groups)]
139139
test_half_lives = [data[key]['half_life'] for key in range(grouper.num_groups)]
140140
parameters = test_yields + test_half_lives
141141
adjusted_parameters = grouper._restructure_intermediate_yields(parameters)
142-
residual_known = np.linalg.norm(grouper._residual_function(adjusted_parameters, times, counts, counts*1e-12, [], [], [], fit_func))
143-
residual_previous = np.linalg.norm(grouper._residual_function(adjusted_parameters, times, func_counts, func_counts*1e-12, [], [], [], fit_func))
142+
residual_counts = np.linalg.norm(grouper._residual_function(adjusted_parameters, times, counts, counts*1e-12, [], [], [], fit_func))
143+
func_counts = fit_func(times, adjusted_parameters)
144+
residual_fit = np.linalg.norm(grouper._residual_function(adjusted_parameters, times, func_counts, func_counts*1e-12, [], [], [], fit_func))
144145
grouper.logger.error(f'{base_parameters = }')
145146
grouper.logger.error(f'{base_inter_parameters = }')
146147
grouper.logger.error(f'{parameters = }')
147148
grouper.logger.error(f'{adjusted_parameters = }')
148-
grouper.logger.error(f'{residual_known = }')
149+
grouper.logger.error(f'{np.mean(func_counts - counts) = }')
150+
grouper.logger.error(f'{residual_counts = }')
151+
grouper.logger.error(f'{residual_fit = }')
149152

150-
assert np.isclose(residual_known, residual_previous, rtol=1e-1), "Same counts should have the same residual"
153+
assert residual_counts > residual_fit, "Fit residual should be zero"
151154

152155
original_half_lives = np.asarray(half_lives)
153156
original_yields = np.asarray(yields)
@@ -175,31 +178,27 @@ def test_grouper_saturation_noex_fitting():
175178
grouper.t_ex = 0
176179
run_grouper_fit_test('saturation', grouper)
177180

178-
@pytest.mark.slow
179181
def test_grouper_saturation_noex_short_fitting_few():
180182
input_path = './tests/unit/input/input.json'
181183
grouper = Grouper(input_path)
182184
grouper.t_ex = 0
183185
grouper.t_net = 30
184186
run_grouper_fit_test('saturation', grouper, 'few_groups')
185187

186-
@pytest.mark.slow
187188
def test_grouper_saturation_ex_short_fitting_few():
188189
input_path = './tests/unit/input/input.json'
189190
grouper = Grouper(input_path)
190191
grouper.t_ex = 10
191192
grouper.t_net = 30
192193
run_grouper_fit_test('saturation_ex', grouper, 'few_groups')
193194

194-
@pytest.mark.slow
195195
def test_grouper_intermediate_noex_short_fitting_few():
196196
input_path = './tests/unit/input/input.json'
197197
grouper = Grouper(input_path)
198198
grouper.t_ex = 0
199199
grouper.t_net = 30
200200
run_grouper_fit_test('intermediate', grouper, 'few_groups')
201201

202-
@pytest.mark.slow
203202
def test_grouper_intermediate_ex_short_fitting_few():
204203
input_path = './tests/unit/input/input.json'
205204
grouper = Grouper(input_path)
@@ -242,7 +241,6 @@ def test_grouper_intermediate_ex_fitting_standard_params():
242241
grouper.t_ex = 10
243242
run_grouper_fit_test('intermediate_ex', grouper, 'standard')
244243

245-
@pytest.mark.slow
246244
def test_grouper_intermediate_ex_short_fitting_standard_params():
247245
input_path = './tests/unit/input/input.json'
248246
grouper = Grouper(input_path)

0 commit comments

Comments
 (0)