Skip to content

Commit e864d87

Browse files
authored
Merge pull request #28 from JR-1991/fix-optimization
Fix `lmfit` optimization function
2 parents 260acca + 16a8715 commit e864d87

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

catalax/dataset/measurement.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,11 @@ def to_jax_arrays(
295295
Raises:
296296
ValueError: If any state in state_order is missing from the measurement data.
297297
"""
298+
if len(self.data) == 0:
299+
raise ValueError(
300+
f"The measurement data is empty. Please add data to the measurement {self.id}"
301+
)
302+
298303
unused_states = [
299304
state for state in self.data.keys() if state not in state_order
300305
]
@@ -307,7 +312,7 @@ def to_jax_arrays(
307312

308313
if missing_states:
309314
raise ValueError(
310-
f"The measurement state are inconsistent with the dataset state. "
315+
f"The measurement states are inconsistent with the dataset states."
311316
f"Missing {missing_states} in measurement {self.id}"
312317
)
313318

catalax/tools/optimization.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,15 @@ def optimize(
5050
"""
5151

5252
params = _initialize_params(model, global_upper_bound, global_lower_bound)
53-
observables = jnp.array(
54-
[
55-
index
56-
for index, ode in enumerate(model.odes.values())
57-
if ode.observable is True
58-
]
59-
)
53+
dataset_observables = dataset.get_observable_states_order()
54+
observables = [
55+
index
56+
for index, state in enumerate(model.get_state_order())
57+
if state in dataset_observables
58+
]
6059

6160
# Extract data arrays for the residual computation
62-
data, times, _ = dataset.to_jax_arrays(model.get_observable_state_order())
61+
data, times, _ = dataset.to_jax_arrays(model.get_state_order())
6362

6463
# Create simulation config from dataset
6564
config = dataset.to_config()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import catalax as ctx
2+
3+
4+
class TestOptimization:
5+
def test_optimize_runs(self):
6+
# Create a Michaelis-Menten enzyme kinetics model
7+
model = ctx.Model(name="Enzyme Kinetics")
8+
9+
# Define states
10+
model.add_states(S="Substrate", P="Product")
11+
12+
# Add reaction kinetics via schemes
13+
model.add_reaction(
14+
"S -> P",
15+
symbol="r1",
16+
equation="v_max * S / (K_m + S)",
17+
)
18+
19+
# Set parameter values
20+
model.parameters["v_max"].value = 0.2
21+
model.parameters["v_max"].initial_value = 0.1
22+
model.parameters["K_m"].value = 0.1
23+
model.parameters["K_m"].initial_value = 0.05
24+
25+
# Create dataset and simulate
26+
dataset = ctx.Dataset.from_model(model)
27+
dataset.add_initial(S=1.0, P=0.0)
28+
config = ctx.SimulationConfig(t1=10, nsteps=10)
29+
dataset = model.simulate(dataset=dataset, config=config)
30+
31+
# Run optimization and check that it completes and returns expected types
32+
result, new_model = ctx.optimize(model, dataset)
33+
assert hasattr(result, "params")
34+
assert isinstance(new_model, type(model))

0 commit comments

Comments
 (0)