Skip to content

Commit 16a8715

Browse files
committed
Update optimization observable selection and add integration test
Refactored observable state selection in optimize() to use dataset observables and model state order for consistency. Added an integration test for the optimization workflow using a Michaelis-Menten model to verify correct execution and output types.
1 parent 94e247e commit 16a8715

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

catalax/tools/optimization.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,15 @@ def optimize(
5050
"""
5151

5252
params = _initialize_params(model, global_upper_bound, global_lower_bound)
53-
observables = jnp.array(
54-
[
55-
index
56-
for index, ode in enumerate(model.odes.values())
57-
if ode.observable is True
58-
]
59-
)
53+
dataset_observables = dataset.get_observable_states_order()
54+
observables = [
55+
index
56+
for index, state in enumerate(model.get_state_order())
57+
if state in dataset_observables
58+
]
6059

6160
# Extract data arrays for the residual computation
62-
data, times, _ = dataset.to_jax_arrays(model.get_observable_state_order())
61+
data, times, _ = dataset.to_jax_arrays(model.get_state_order())
6362

6463
# Create simulation config from dataset
6564
config = dataset.to_config()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import catalax as ctx
2+
3+
4+
class TestOptimization:
5+
def test_optimize_runs(self):
6+
# Create a Michaelis-Menten enzyme kinetics model
7+
model = ctx.Model(name="Enzyme Kinetics")
8+
9+
# Define states
10+
model.add_states(S="Substrate", P="Product")
11+
12+
# Add reaction kinetics via schemes
13+
model.add_reaction(
14+
"S -> P",
15+
symbol="r1",
16+
equation="v_max * S / (K_m + S)",
17+
)
18+
19+
# Set parameter values
20+
model.parameters["v_max"].value = 0.2
21+
model.parameters["v_max"].initial_value = 0.1
22+
model.parameters["K_m"].value = 0.1
23+
model.parameters["K_m"].initial_value = 0.05
24+
25+
# Create dataset and simulate
26+
dataset = ctx.Dataset.from_model(model)
27+
dataset.add_initial(S=1.0, P=0.0)
28+
config = ctx.SimulationConfig(t1=10, nsteps=10)
29+
dataset = model.simulate(dataset=dataset, config=config)
30+
31+
# Run optimization and check that it completes and returns expected types
32+
result, new_model = ctx.optimize(model, dataset)
33+
assert hasattr(result, "params")
34+
assert isinstance(new_model, type(model))

0 commit comments

Comments
 (0)