Skip to content

Commit 38f874a

Browse files
committed
Fix sx_ss for log-transformed parameters (#2864)
* So far, `sx_ss` was incorrectly the sensitivities w.r.t. the unscaled parameters. Fixed here. * So far, pre-equilibration steady-state sensitivities were not included in the finite difference checks. Now they are. * Add option to skip specific ReturnData fields in the finite difference check
1 parent 83c432c commit 38f874a

File tree

3 files changed

+24
-26
lines changed

3 files changed

+24
-26
lines changed

python/sdist/amici/gradient_check.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def check_finite_difference(
116116
for field in fields:
117117
sensi_raw = rdata[f"s{field}"]
118118
fd = (rdataf[field] - rdatab[field]) / (pf[ip] - pb[ip])
119-
if len(sensi_raw.shape) == 1:
119+
if len(sensi_raw.shape) == 1 or field == "x_ss":
120120
sensi = sensi_raw[0]
121121
elif len(sensi_raw.shape) == 2:
122122
sensi = sensi_raw[:, 0]
@@ -153,34 +153,21 @@ def check_derivatives(
153153
epsilon: float | None = 1e-3,
154154
check_least_squares: bool = True,
155155
skip_zero_pars: bool = False,
156+
skip_fields: list[str] | None = None,
156157
) -> None:
157158
"""
158159
Finite differences check for likelihood gradient.
159160
160-
:param model:
161-
amici model
162-
163-
:param solver:
164-
amici solver
165-
166-
:param edata:
167-
exp data
168-
169-
:param atol:
170-
absolute tolerance for comparison
171-
172-
:param rtol:
173-
relative tolerance for comparison
174-
175-
:param epsilon:
176-
finite difference step-size
177-
178-
:param check_least_squares:
179-
whether to check least squares related values.
180-
181-
:param skip_zero_pars:
182-
whether to perform FD checks for parameters that are zero
183-
161+
:param model: amici model
162+
:param solver: amici solver
163+
:param edata: exp data
164+
:param atol: absolute tolerance for comparison
165+
:param rtol: relative tolerance for comparison
166+
:param epsilon: finite difference step-size
167+
:param check_least_squares: whether to check least squares related values.
168+
:param skip_zero_pars: whether to perform FD checks for parameters that
169+
are zero
170+
:param skip_fields: list of fields to skip
184171
"""
185172
p = np.array(model.getParameters())
186173

@@ -200,6 +187,9 @@ def check_derivatives(
200187
solver.getSensitivityMethod() == SensitivityMethod.forward
201188
and solver.getSensitivityOrder() <= SensitivityOrder.first
202189
):
190+
if rdata.sx_ss is not None:
191+
fields.append("x_ss")
192+
203193
fields.append("x")
204194

205195
leastsquares_applicable = (
@@ -236,6 +226,8 @@ def check_derivatives(
236226
if edata is not None:
237227
fields.append("llh")
238228

229+
fields = [f for f in fields if f not in (skip_fields or [])]
230+
239231
# only check the sensitivities w.r.t. the selected parameters
240232
plist = model.getParameterList()
241233
if edata and edata.plist:

python/tests/test_pregenerated_models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,12 @@ def test_pregenerated_model(sub_test, case):
9696
and not model_name.startswith("model_neuron")
9797
and not case.endswith("byhandpreeq")
9898
):
99-
check_derivatives(model, solver, edata, **check_derivative_opts)
99+
check_derivatives(
100+
model,
101+
solver,
102+
edata,
103+
**check_derivative_opts,
104+
)
100105

101106
verify_simulation_opts = dict()
102107

src/rdata.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,7 @@ void ReturnData::applyChainRuleFactorToSimulationResults(Model const& model) {
767767
chain_rule(ssigmaz, nztrue, nz, nmaxevent);
768768
chain_rule(srz, nztrue, nz, nmaxevent);
769769
chain_rule(sx0, nxtrue, nx, 1);
770+
chain_rule(sx_ss, nxtrue, nx, 1);
770771
}
771772

772773
if (o2mode == SecondOrderMode::full) {

0 commit comments

Comments
 (0)