diff --git a/brainpy/integrators/joint_eq.py b/brainpy/integrators/joint_eq.py index 4d16b86a..55c1a3af 100644 --- a/brainpy/integrators/joint_eq.py +++ b/brainpy/integrators/joint_eq.py @@ -154,8 +154,18 @@ def __init__(self, *eqs): vars, _, _ = _get_args(eq) for var in vars: if var in vars_in_eqs: - raise DiffEqError(f'Variable "{var}" has been used, however we got a same ' - f'variable name in {eq}. Please change another name.') + raise DiffEqError( + f'Variable "{var}" has been used, however we got a same ' + f'variable name in {eq}.\n\n' + f'In JointEq, each state variable should appear as the first parameter ' + f'before "t" in exactly one derivative function. If "{var}" is a state ' + f'variable in another equation, it should be placed AFTER "t" in this ' + f'function as a dependency.\n\n' + f'Correct signature pattern:\n' + f' def d{var}({var}, t, ): ... # {var} is the state variable\n' + f' def dOther(other, t, {var}): ... # {var} is a dependency\n\n' + f'Current function signature: {inspect.signature(eq)}' + ) vars_in_eqs.extend(vars) self.vars_in_eqs.append(vars) diff --git a/brainpy/integrators/tests/test_joint_eq.py b/brainpy/integrators/tests/test_joint_eq.py index 9d9a7ba3..e617ab81 100644 --- a/brainpy/integrators/tests/test_joint_eq.py +++ b/brainpy/integrators/tests/test_joint_eq.py @@ -128,3 +128,46 @@ def test_nested_joint_eq1(self): EQ2 = JointEq(EQ1, dn) EQ3 = JointEq(EQ2, dV) print(EQ3(m=0.1, h=0.2, n=0.3, V=10., t=0., I=0.)) + + def test_second_order_ode(self): + """Test second-order ODE system (e.g., harmonic oscillator)""" + # Second-order ODE: d²x/dt² = -k*x - c*dx/dt + # Split into: dx/dt = v, dv/dt = -k*x - c*v + k = 1.0 # spring constant + c = 0.1 # damping + + def dx(x, t, v): + """dx/dt = v""" + return v + + def dv(v, t, x): + """dv/dt = -k*x - c*v""" + return -k * x - c * v + + # Create joint equation + eq = JointEq(dx, dv) + + # Test call + result = eq(x=1.0, v=0.0, t=0.0) + self.assertEqual(len(result), 2) + self.assertEqual(result[0], 0.0) # dx/dt = v = 0 + self.assertEqual(result[1], -k * 1.0) # dv/dt = -k*x + + def test_second_order_ode_wrong_signature(self): + """Test that wrong signature gives helpful error message""" + # WRONG: both x and v before t in dx function + def dx_wrong(x, v, t): + return v + + def dv(v, t, x): + return -x + + # This should raise an error with helpful message + with self.assertRaises(DiffEqError) as cm: + JointEq(dx_wrong, dv) + + # Check that error message is helpful + error_msg = str(cm.exception) + self.assertIn('state variable', error_msg.lower()) + self.assertIn('AFTER "t"', error_msg) + self.assertIn('dependency', error_msg.lower()) diff --git a/docs_classic/tutorial_toolbox/joint_equations.ipynb b/docs_classic/tutorial_toolbox/joint_equations.ipynb index 833a8e5b..5cc9be68 100644 --- a/docs_classic/tutorial_toolbox/joint_equations.ipynb +++ b/docs_classic/tutorial_toolbox/joint_equations.ipynb @@ -30,15 +30,10 @@ { "cell_type": "code", "id": "be08d171", - "metadata": { - "ExecuteTime": { - "end_time": "2025-10-06T05:15:55.420059Z", - "start_time": "2025-10-06T05:15:55.413203Z" - } - }, + "metadata": {}, "source": "import brainpy as bp", "outputs": [], - "execution_count": 9 + "execution_count": null }, { "cell_type": "markdown", @@ -61,8 +56,8 @@ "id": "2921b856", "metadata": { "ExecuteTime": { - "end_time": "2025-10-06T05:15:55.429070Z", - "start_time": "2025-10-06T05:15:55.423341Z" + "end_time": "2025-11-03T08:47:30.526844Z", + "start_time": "2025-11-03T08:47:30.524677Z" } }, "source": [ @@ -71,7 +66,7 @@ "du = lambda u, t, V: a * (b * V - u)" ], "outputs": [], - "execution_count": 10 + "execution_count": 2 }, { "cell_type": "markdown", @@ -86,15 +81,15 @@ "id": "08ac3b75", "metadata": { "ExecuteTime": { - "end_time": "2025-10-06T05:15:55.435687Z", - "start_time": "2025-10-06T05:15:55.432079Z" + "end_time": "2025-11-03T08:47:36.337738Z", + "start_time": "2025-11-03T08:47:36.334856Z" } }, "source": [ "joint_eq = bp.JointEq(dV, du)" ], "outputs": [], - "execution_count": 11 + "execution_count": 3 }, { "cell_type": "markdown", @@ -109,15 +104,15 @@ "id": "356cf60d", "metadata": { "ExecuteTime": { - "end_time": "2025-10-06T05:15:55.443030Z", - "start_time": "2025-10-06T05:15:55.439434Z" + "end_time": "2025-11-03T08:47:38.014070Z", + "start_time": "2025-11-03T08:47:38.009818Z" } }, "source": [ "itg = bp.odeint(joint_eq, method='rk2')" ], "outputs": [], - "execution_count": 12 + "execution_count": 4 }, { "cell_type": "markdown", @@ -152,8 +147,8 @@ "id": "4dec7537", "metadata": { "ExecuteTime": { - "end_time": "2025-10-06T05:15:55.450726Z", - "start_time": "2025-10-06T05:15:55.447321Z" + "end_time": "2025-11-03T08:47:40.678739Z", + "start_time": "2025-11-03T08:47:40.675701Z" } }, "source": [ @@ -165,7 +160,7 @@ "itg_V_u = bp.odeint(diff, method='rk2')" ], "outputs": [], - "execution_count": 13 + "execution_count": 5 }, { "cell_type": "markdown", @@ -180,8 +175,8 @@ "id": "12e5d88d", "metadata": { "ExecuteTime": { - "end_time": "2025-10-06T05:15:55.460116Z", - "start_time": "2025-10-06T05:15:55.455524Z" + "end_time": "2025-11-03T08:47:42.491024Z", + "start_time": "2025-11-03T08:47:42.487721Z" } }, "source": [ @@ -189,7 +184,7 @@ "int_u = bp.odeint(du, method='rk2')" ], "outputs": [], - "execution_count": 14 + "execution_count": 6 }, { "cell_type": "markdown", @@ -206,8 +201,8 @@ "id": "38101bec", "metadata": { "ExecuteTime": { - "end_time": "2025-10-06T05:15:55.473617Z", - "start_time": "2025-10-06T05:15:55.467195Z" + "end_time": "2025-11-03T08:47:43.649877Z", + "start_time": "2025-11-03T08:47:43.643994Z" } }, "source": "bp.odeint(dV, method='rk2', show_code=True)", @@ -216,7 +211,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "def brainpy_itg_of_ode10(V, t, u, Iext, dt=0.1):\n", + "def brainpy_itg_of_ode4(V, t, u, Iext, dt=0.1):\n", " dV_k1 = f(V, t, u, Iext)\n", " k2_V_arg = V + dt * dV_k1 * 0.6666666666666666\n", " k2_t_arg = t + dt * 0.6666666666666666\n", @@ -224,22 +219,22 @@ " V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75\n", " return V_new\n", "\n", - "{'f': at 0x0000015BBFC96C00>}\n", + "{'f': at 0x12ef725c0>}\n", "\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 15, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 15 + "execution_count": 7 }, { "cell_type": "markdown", @@ -256,8 +251,8 @@ "id": "32901ae6", "metadata": { "ExecuteTime": { - "end_time": "2025-10-06T05:15:55.494272Z", - "start_time": "2025-10-06T05:15:55.488118Z" + "end_time": "2025-11-03T08:47:47.051374Z", + "start_time": "2025-11-03T08:47:47.045364Z" } }, "source": [ @@ -269,7 +264,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "def brainpy_itg_of_ode11_joint_eq(V, u, t, Iext, dt=0.1):\n", + "def brainpy_itg_of_ode5_joint_eq(V, u, t, Iext, dt=0.1):\n", " dV_k1, du_k1 = f(V, u, t, Iext)\n", " k2_V_arg = V + dt * dV_k1 * 0.6666666666666666\n", " k2_u_arg = u + dt * du_k1 * 0.6666666666666666\n", @@ -279,22 +274,22 @@ " u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75\n", " return V_new, u_new\n", "\n", - "{'f': }\n", + "{'f': }\n", "\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 16, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 16 + "execution_count": 8 }, { "cell_type": "markdown", @@ -304,6 +299,207 @@ "It is shown in this output code that second differential values of $v$ and $u$ are calculated by using the updated values (`k2_V_arg` and `k2_u_arg`) at the same time. This will result in a more accurate integral." ] }, + { + "cell_type": "markdown", + "id": "second_order_ode_title", + "metadata": {}, + "source": [ + "## Second-Order ODEs with `brainpy.JointEq`\n", + "\n", + "A common use case for `JointEq` is solving second-order ordinary differential equations (ODEs). ", + "Second-order ODEs appear in many physical systems, such as the harmonic oscillator, pendulum, ", + "or neural mass models like the Jansen-Rit model.\n", + "\n", + "When using `JointEq` for second-order ODEs, it's important to follow the correct function signature pattern." + ] + }, + { + "cell_type": "markdown", + "id": "second_order_example", + "metadata": {}, + "source": [ + "### Example: Harmonic Oscillator\n", + "\n", + "Consider a damped harmonic oscillator described by:\n", + "\n", + "$$\\frac{d^2x}{dt^2} = -kx - c\\frac{dx}{dt}$$\n", + "\n", + "To solve this with `JointEq`, we split it into two first-order ODEs:\n", + "\n", + "$$\\frac{dx}{dt} = v$$\n", + "$$\\frac{dv}{dt} = -kx - cv$$\n", + "\n", + "Where $x$ is position and $v$ is velocity." + ] + }, + { + "cell_type": "code", + "id": "second_order_code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-03T08:47:50.381208Z", + "start_time": "2025-11-03T08:47:50.377324Z" + } + }, + "source": [ + "import brainpy as bp\n", + "import brainpy.math as bm\n", + "\n", + "# Parameters\n", + "k = 1.0 # spring constant\n", + "c = 0.1 # damping coefficient\n", + "\n", + "# Define derivative functions\n", + "# IMPORTANT: Each state variable appears as the FIRST parameter before 't'\n", + "# Other state variables appear AFTER 't' as dependencies\n", + "def dx(x, t, v):\n", + " \"\"\"dx/dt = v\"\"\"\n", + " return v\n", + "\n", + "def dv(v, t, x):\n", + " \"\"\"dv/dt = -k*x - c*v\"\"\"\n", + " return -k * x - c * v\n", + "\n", + "# Create joint equation\n", + "joint_eq = bp.JointEq(dx, dv)\n", + "print(f\"Joint equation signature: {joint_eq.__signature__}\")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Joint equation signature: (x, v, t)\n" + ] + } + ], + "execution_count": 9 + }, + { + "cell_type": "markdown", + "id": "second_order_signature", + "metadata": {}, + "source": [ + "### Important: Function Signature Pattern\n", + "\n", + "When defining derivative functions for `JointEq`, follow this pattern:\n", + "\n", + "**Correct:**\n", + "```python\n", + "def dx(x, t, v): # x is the state variable, v is a dependency\n", + " return v\n", + "\n", + "def dv(v, t, x): # v is the state variable, x is a dependency\n", + " return -k * x - c * v\n", + "```\n", + "\n", + "**Incorrect:**\n", + "```python\n", + "def dx(x, v, t): # WRONG: Both x and v before t\n", + " return v\n", + "```\n", + "\n", + "**Rule:** Each state variable should appear as the **first parameter before** `t` in exactly one derivative function. ", + "If a variable is needed as a dependency in another function, it should be placed **after** `t`.\n", + "\n", + "This ensures that `JointEq` knows which variable each function is differentiating and which variables are dependencies." + ] + }, + { + "cell_type": "markdown", + "id": "second_order_jansen_rit", + "metadata": {}, + "source": [ + "### Example: Jansen-Rit Model\n", + "\n", + "The Jansen-Rit model is a neural mass model with three coupled second-order ODEs. ", + "Here's how to implement it correctly with `JointEq`:" + ] + }, + { + "cell_type": "code", + "id": "jansen_rit_code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-03T08:47:55.398154Z", + "start_time": "2025-11-03T08:47:55.348911Z" + } + }, + "source": [ + "class JansenRitModel(bp.dyn.NeuDyn):\n", + " def __init__(self, size=1, A=3.25, te=10, B=22, ti=20, C=135, \n", + " e0=2.5, r=0.56, v0=6, method='rk4', **kwargs):\n", + " super().__init__(size=size, **kwargs)\n", + " self.A, self.te = A, te\n", + " self.B, self.ti = B, ti\n", + " self.C = C\n", + " self.e0, self.r, self.v0 = e0, r, v0\n", + " \n", + " # State variables: positions (y0, y1, y2) and velocities (y3, y4, y5)\n", + " self.y0 = bm.Variable(bm.zeros(self.num))\n", + " self.y1 = bm.Variable(bm.zeros(self.num))\n", + " self.y2 = bm.Variable(bm.zeros(self.num))\n", + " self.y3 = bm.Variable(bm.zeros(self.num)) # velocity for y0\n", + " self.y4 = bm.Variable(bm.zeros(self.num)) # velocity for y1\n", + " self.y5 = bm.Variable(bm.zeros(self.num)) # velocity for y2\n", + " \n", + " self.integral = bp.odeint(f=self.derivative, method=method)\n", + " \n", + " # Position derivatives: dx/dt = v\n", + " def dy0(self, y0, t, y3): # y0 is state, y3 is dependency\n", + " return y3 / 1000\n", + " \n", + " def dy1(self, y1, t, y4): # y1 is state, y4 is dependency\n", + " return y4 / 1000\n", + " \n", + " def dy2(self, y2, t, y5): # y2 is state, y5 is dependency\n", + " return y5 / 1000\n", + " \n", + " # Velocity derivatives: dv/dt = ...\n", + " def dy3(self, y3, t, y0, y1, y2): # y3 is state, others are dependencies\n", + " Sp = 2 * self.e0 / (1 + bm.exp(self.r * (self.v0 - y1 + y2)))\n", + " return (self.A * Sp - 2 * y3 - y0 / self.te * 1000) / self.te\n", + " \n", + " def dy4(self, y4, t, y0, y1, inp=0.): # y4 is state, others are dependencies\n", + " Se = 2 * self.e0 / (1 + bm.exp(self.r * (self.v0 - self.C * y0)))\n", + " return (self.A * (inp + 0.8 * self.C * Se) - 2 * y4 - y1 / self.te * 1000) / self.te\n", + " \n", + " def dy5(self, y5, t, y0, y2): # y5 is state, others are dependencies\n", + " Si = 2 * self.e0 / (1 + bm.exp(self.r * (self.v0 - 0.25 * self.C * y0)))\n", + " return (self.B * 0.25 * self.C * Si - 2 * y5 - y2 / self.ti * 1000) / self.ti\n", + " \n", + " @property\n", + " def derivative(self):\n", + " # Join all derivatives - order matches the state variables\n", + " return bp.JointEq([self.dy0, self.dy1, self.dy2, self.dy3, self.dy4, self.dy5])\n", + " \n", + " def update(self, inp=0.):\n", + " y0, y1, y2, y3, y4, y5 = self.integral(\n", + " self.y0, self.y1, self.y2, self.y3, self.y4, self.y5,\n", + " bp.share['t'], inp, bp.share['dt']\n", + " )\n", + " self.y0.value = y0\n", + " self.y1.value = y1\n", + " self.y2.value = y2\n", + " self.y3.value = y3\n", + " self.y4.value = y4\n", + " self.y5.value = y5\n", + "\n", + "# Create and test the model\n", + "model = JansenRitModel(size=1)\n", + "print(\"Jansen-Rit model created successfully!\")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Jansen-Rit model created successfully!\n" + ] + } + ], + "execution_count": 10 + }, { "cell_type": "markdown", "id": "73051bec",