Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions brainpy/integrators/joint_eq.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,18 @@ def __init__(self, *eqs):
vars, _, _ = _get_args(eq)
for var in vars:
if var in vars_in_eqs:
raise DiffEqError(f'Variable "{var}" has been used, however we got a same '
f'variable name in {eq}. Please change another name.')
raise DiffEqError(
f'Variable "{var}" has been used, however we got a same '
f'variable name in {eq}.\n\n'
f'In JointEq, each state variable should appear as the first parameter '
f'before "t" in exactly one derivative function. If "{var}" is a state '
f'variable in another equation, it should be placed AFTER "t" in this '
f'function as a dependency.\n\n'
f'Correct signature pattern:\n'
f' def d{var}({var}, t, <dependencies>): ... # {var} is the state variable\n'
f' def dOther(other, t, {var}): ... # {var} is a dependency\n\n'
f'Current function signature: {inspect.signature(eq)}'
)
vars_in_eqs.extend(vars)
self.vars_in_eqs.append(vars)

Expand Down
43 changes: 43 additions & 0 deletions brainpy/integrators/tests/test_joint_eq.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,46 @@ def test_nested_joint_eq1(self):
EQ2 = JointEq(EQ1, dn)
EQ3 = JointEq(EQ2, dV)
print(EQ3(m=0.1, h=0.2, n=0.3, V=10., t=0., I=0.))

def test_second_order_ode(self):
"""Test second-order ODE system (e.g., harmonic oscillator)"""
# Second-order ODE: d²x/dt² = -k*x - c*dx/dt
# Split into: dx/dt = v, dv/dt = -k*x - c*v
k = 1.0 # spring constant
c = 0.1 # damping

def dx(x, t, v):
"""dx/dt = v"""
return v

def dv(v, t, x):
"""dv/dt = -k*x - c*v"""
return -k * x - c * v

# Create joint equation
eq = JointEq(dx, dv)

# Test call
result = eq(x=1.0, v=0.0, t=0.0)
self.assertEqual(len(result), 2)
self.assertEqual(result[0], 0.0) # dx/dt = v = 0
self.assertEqual(result[1], -k * 1.0) # dv/dt = -k*x

def test_second_order_ode_wrong_signature(self):
"""Test that wrong signature gives helpful error message"""
# WRONG: both x and v before t in dx function
def dx_wrong(x, v, t):
return v

def dv(v, t, x):
return -x

# This should raise an error with helpful message
with self.assertRaises(DiffEqError) as cm:
JointEq(dx_wrong, dv)

# Check that error message is helpful
error_msg = str(cm.exception)
self.assertIn('state variable', error_msg.lower())
self.assertIn('AFTER "t"', error_msg)
self.assertIn('dependency', error_msg.lower())
Loading
Loading