Skip to content

Commit ec6c1af

Browse files
committed
Respect ExpData settings in gradient checks
Previously, Model settings were used, independently of whether ExpData had different settings.
1 parent 25adebf commit ec6c1af

File tree

2 files changed

+23
-16
lines changed

2 files changed

+23
-16
lines changed

python/sdist/amici/gradient_check.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,29 @@ def check_finite_difference(
6666
finite difference step-size
6767
6868
"""
69+
p = copy.deepcopy(x0)
70+
plist = [ip]
71+
72+
# store original settings and apply new ones
6973
og_sensitivity_order = solver.get_sensitivity_order()
7074
og_parameters = model.get_parameters()
7175
og_plist = model.get_parameter_list()
7276
if edata:
7377
og_eplist = edata.plist
78+
og_eparameters = edata.parameters
7479

75-
# sensitivity
76-
p = copy.deepcopy(x0)
77-
plist = [ip]
80+
edata.plist = plist
81+
# we always set parameters via the model below
82+
edata.parameters = []
83+
pscale = (
84+
edata.pscale if len(edata.pscale) else model.get_parameter_scale()
85+
)
86+
else:
87+
pscale = model.get_parameter_scale()
88+
model.set_parameter_list(plist)
7889

90+
model.set_parameter_scale(pscale)
7991
model.set_parameters(p)
80-
model.set_parameter_list(plist)
81-
if edata:
82-
edata.plist = plist
8392

8493
# simulation with gradient
8594
if int(og_sensitivity_order) < int(SensitivityOrder.first):
@@ -93,8 +102,7 @@ def check_finite_difference(
93102

94103
pf = copy.deepcopy(x0)
95104
pb = copy.deepcopy(x0)
96-
pscale = model.get_parameter_scale()[ip]
97-
if x0[ip] == 0 or pscale != int(ParameterScaling.none):
105+
if x0[ip] == 0 or pscale[ip] != int(ParameterScaling.none):
98106
pf[ip] += epsilon / 2
99107
pb[ip] -= epsilon / 2
100108
else:
@@ -142,6 +150,7 @@ def check_finite_difference(
142150
model.set_parameter_list(og_plist)
143151
if edata:
144152
edata.plist = og_eplist
153+
edata.parameters = og_eparameters
145154

146155

147156
def check_derivatives(
@@ -160,7 +169,8 @@ def check_derivatives(
160169
161170
:param model: amici model
162171
:param solver: amici solver
163-
:param edata: exp data
172+
:param edata: ExpData instance. If provided, ExpData settings will
173+
override model settings where applicable (`plist`, `parmeters`, ...).
164174
:param atol: absolute tolerance for comparison
165175
:param rtol: relative tolerance for comparison
166176
:param epsilon: finite difference step-size
@@ -169,7 +179,10 @@ def check_derivatives(
169179
are zero
170180
:param skip_fields: list of fields to skip
171181
"""
172-
p = np.array(model.get_parameters())
182+
if edata and edata.parameters:
183+
p = np.array(edata.parameters)
184+
else:
185+
p = np.array(model.get_parameters())
173186

174187
og_sens_order = solver.get_sensitivity_order()
175188

tests/petab_test_suite/test_petab_suite.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,6 @@ def check_derivatives(
203203
petab_problem=problem,
204204
problem_parameters=problem_parameters,
205205
):
206-
# check_derivatives does currently not support parameters in ExpData
207-
# set parameter scales before setting parameter values!
208-
model.set_parameter_scale(edata.pscale)
209-
model.set_parameters(edata.parameters)
210-
edata.parameters = []
211-
edata.pscale = amici.parameter_scaling_from_int_vector([])
212206
amici_check_derivatives(model, solver, edata)
213207

214208

0 commit comments

Comments
 (0)