File tree Expand file tree Collapse file tree 1 file changed +14
-14
lines changed
Expand file tree Collapse file tree 1 file changed +14
-14
lines changed Original file line number Diff line number Diff line change @@ -96,21 +96,21 @@ def test_jax_llh(benchmark_problem):
9696 r_amici = simulate_amici ()
9797 llh_amici = r_amici [LLH ]
9898
99- jax_problem = import_petab_problem (
100- petab_problem ,
101- output_dir = benchmark_outdir / (problem_id + "_jax" ),
102- jax = True ,
103- )
104- if problem_parameters :
105- jax_problem = eqx .tree_at (
106- lambda x : x .parameters ,
107- jax_problem ,
108- jnp .array (
109- [problem_parameters [pid ] for pid in jax_problem .parameter_ids ]
110- ),
99+ try :
100+ jax_problem = import_petab_problem (
101+ petab_problem ,
102+ output_dir = benchmark_outdir / (problem_id + "_jax" ),
103+ jax = True ,
111104 )
105+ if problem_parameters :
106+ jax_problem = eqx .tree_at (
107+ lambda x : x .parameters ,
108+ jax_problem ,
109+ jnp .array (
110+ [problem_parameters [pid ] for pid in jax_problem .parameter_ids ]
111+ ),
112+ )
112113
113- try :
114114 if problem_id in problems_for_gradient_check :
115115 if problem_id == "Weber_BMC2015" :
116116 atol = cur_settings .atol_sim
@@ -154,6 +154,6 @@ def test_jax_llh(benchmark_problem):
154154 except (NotImplementedError , TypeError ) as err :
155155 if "run_simulations does not support PEtab v1 problems" in str (err ):
156156 pytest .skip (str (err ))
157- if "The JAX backend does not support simultaneous events" in str (err ):
157+ elif "The JAX backend does not support simultaneous events" in str (err ):
158158 pytest .skip (str (err ))
159159 raise err
You can’t perform that action at this time.
0 commit comments