diff --git a/pyproject.toml b/pyproject.toml index 3b12182..ce72731 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "enzax" -version = "0.2.1" +version = "0.2.2" description = "Differentiable models of enzyme-catalysed reaction networks" authors = [ {name = "Teddy Groves", email = "tedgro@dtu.dk"}, @@ -12,10 +12,11 @@ dependencies = [ "arviz>=0.19.0", "equinox>=0.11.12", "python-libsbml>=5.20.4", - "sympy2jax>=0.0.5", + "sympy2jax>=0.0.7", "sbmlmath>=0.2.0", "jax>=0.5.2,<0.7.0", "typeguard>=2.13.3", + "requests>=2.32.3", ] requires-python = ">=3.12" readme = "README.md" diff --git a/src/enzax/rate_equations/mass_action.py b/src/enzax/rate_equations/mass_action.py new file mode 100644 index 0000000..325e309 --- /dev/null +++ b/src/enzax/rate_equations/mass_action.py @@ -0,0 +1,54 @@ +import equinox as eqx +from jax import numpy as jnp +import numpy as np +from numpy.typing import NDArray +from jaxtyping import PyTree, Scalar, Float, Array, ScalarLike + +from enzax.rate_equation import ConcArray, RateEquation +from enzax.rate_equations.thermodynamics import get_keq + + +class MassActionInput(eqx.Module): + kf: ScalarLike + dgf: Float[Array, " _"] + temperature: ScalarLike + ix_substrate: NDArray + ix_product: NDArray + + +class MassAction(RateEquation): + """A reaction with first order mass action kinetics.""" + + water_stoichiometry: float + + def get_input( + self, + parameters: PyTree, + reaction_id: str, + reaction_stoichiometry: NDArray[np.float64], + species_to_dgf_ix: NDArray[np.int16], + ): + ix_reactant = np.argwhere(reaction_stoichiometry != 0.0).flatten() + ix_substrate = np.argwhere(reaction_stoichiometry < 0.0).flatten() + ix_product = np.argwhere(reaction_stoichiometry > 0.0).flatten() + return MassActionInput( + kf=jnp.exp(parameters["log_kf"][reaction_id]), + ix_substrate=ix_substrate, + ix_product=ix_product, + dgf=parameters["dgf"][ix_reactant], + temperature=parameters["temperature"], + ) + + def __call__(self, conc: ConcArray, ma_input: PyTree) -> Scalar: + """Get the flux of a drain reaction.""" + + keq = get_keq( + ma_input.reaction_stoichiometry, + ma_input.dgf, + ma_input.temperature, + self.water_stoichiometry, + ) + kr = ma_input.kf / keq + return ma_input.kf * jnp.prod(conc[self.ix_substrate]) - kr * jnp.prod( + conc[self.ix_product] + ) diff --git a/src/enzax/rate_equations/michaelis_menten.py b/src/enzax/rate_equations/michaelis_menten.py index ca8581e..bfd6f3c 100644 --- a/src/enzax/rate_equations/michaelis_menten.py +++ b/src/enzax/rate_equations/michaelis_menten.py @@ -5,6 +5,7 @@ from numpy.typing import NDArray from enzax.rate_equation import RateEquation +from enzax.rate_equations.thermodynamics import get_reversibility class IrreversibleMichaelisMentenInput(eqx.Module): @@ -101,39 +102,6 @@ def numerator_mm( return jnp.prod((substrate_conc / substrate_kms)) -def get_reversibility( - reactant_conc: Float[Array, " n_reactant"], - dgf: Float[Array, " n_reactant"], - temperature: Scalar, - reactant_stoichiometry: NDArray[np.float64], - water_stoichiometry: float, -) -> Scalar: - """Get the reversibility of a reaction. - - Hard coded water dgf is taken from . - - The equation is - - 1 - exp(((dgr + (RT * quotient)) / RT)) - - but it's implemented a bit differently so as to be more numerically stable. - """ - RT = temperature * 0.008314 - conc_clipped = jnp.clip(reactant_conc, min=1e-9) - dgf_water = -150.9 - dgr_std = ( - reactant_stoichiometry.T @ dgf + water_stoichiometry * dgf_water - ).flatten() - quotient = jnp.clip( - reactant_stoichiometry.T @ jnp.log(conc_clipped), - min=-2e1, - max=2e1, - ).flatten() - expand = jnp.clip((dgr_std / RT) + quotient, min=-2.0, max=2.0) - out = -jnp.expm1(expand)[0] - return eqx.error_if(out, jnp.isnan(out), "Reversibility is nan!") - - def free_enzyme_ratio_imm( substrate_conc: Float[Array, " n_substrate"], substrate_km: Float[Array, " n_substrate"], diff --git a/src/enzax/rate_equations/thermodynamics.py b/src/enzax/rate_equations/thermodynamics.py new file mode 100644 index 0000000..3451744 --- /dev/null +++ b/src/enzax/rate_equations/thermodynamics.py @@ -0,0 +1,49 @@ +import equinox as eqx +import numpy as np +from jax import numpy as jnp +from jaxtyping import Array, Float, Scalar, ScalarLike +from numpy.typing import NDArray + + +def get_dgr_std(stoichiometry, dgf, temperature, water_stoichiometry): + RT = temperature * 0.008314 + dgf_water = -150.9 + dgr_std = ( + stoichiometry.T @ dgf + water_stoichiometry * dgf_water + ).flatten() + return jnp.exp(-dgr_std / RT) + + +def get_keq(stoichiometry, dgf, temperature: ScalarLike, water_stoichiometry): + minus_RT = -0.008314 * temperature + dgrs = get_dgr_std(stoichiometry, dgf, temperature, water_stoichiometry) + return jnp.exp(dgrs / minus_RT) + + +def get_reversibility( + reactant_conc: Float[Array, " n_reactant"], + dgf: Float[Array, " n_reactant"], + temperature: Scalar, + reactant_stoichiometry: NDArray[np.float64], + water_stoichiometry: float, +) -> Scalar: + """Get the reversibility of a reaction. + + Hard coded water dgf is taken from . + + The equation is + + 1 - exp(((dgr + (RT * quotient)) / RT)) + + but it's implemented a bit differently so as to be more numerically stable. + """ + RT = temperature * 0.008314 + conc_clipped = jnp.clip(reactant_conc, min=1e-9) + dgf_water = -150.9 + dgr_std = ( + reactant_stoichiometry.T @ dgf + water_stoichiometry * dgf_water + ).flatten() + quotient = (reactant_stoichiometry.T @ jnp.log(conc_clipped)).flatten() + expand = jnp.clip((dgr_std / RT) + quotient, min=-1e2, max=1e2) + out = -jnp.expm1(expand)[0] + return eqx.error_if(out, jnp.isnan(out), "Reversibility is nan!")