Skip to content

Conversation

@Aksh8t
Copy link

@Aksh8t Aksh8t commented Jun 21, 2025

Checklist

  • [NO] Appropriate tests were added
  • [YES] Any code changes were done in a way that does not break public API
  • [-] All documentation related to code changes were updated
  • [YES] The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • [-] Any new documentation only uses public API

Additional context

I have implemented the ODE-LSTM code of python to Julia in single file with all functions working right.

src/Odelstm.jl Outdated
Original paper: https://arxiv.org/abs/2006.04418
"""

module ODELSTM
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't be done in a submodule

src/Odelstm.jl Outdated

module ODELSTM

using DifferentialEquations
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just take the ODE solver from the user, no new deps needed here

src/Odelstm.jl Outdated
module ODELSTM

using DifferentialEquations
using Flux
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should use Lux

src/Odelstm.jl Outdated

using DifferentialEquations
using Flux
using DiffEqFlux
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cannot import the module from within the same module

src/Odelstm.jl Outdated
Comment on lines 48 to 56
function get_solver(solver_type::Symbol)
solver_map = Dict(
:dopri5 => Tsit5(),
:tsit5 => Tsit5(),
:euler => Euler(),
:heun => Heun(),
:rk4 => RK4()
)
return get(solver_map, solver_type, Tsit5())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just take the solver instead of doing this

src/Odelstm.jl Outdated
end

function ODELSTMCell(input_size::Int, hidden_size::Int, solver_type::Symbol=:dopri5)
lstm_cell = Flux.LSTMCell(input_size => hidden_size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is using the recurrent model, which is not the algorithm and would break adaptivity

src/Odelstm.jl Outdated
Comment on lines 72 to 85
function solve_fixed_step(cell::ODELSTMCell, h, ts)
dt = ts / 3.0
h_evolved = h
for i in 1:3
if cell.solver_type == :euler
h_evolved = euler_step(cell.f_node, h_evolved, dt)
elseif cell.solver_type == :heun
h_evolved = heun_step(cell.f_node, h_evolved, dt)
elseif cell.solver_type == :rk4
h_evolved = rk4_step(cell.f_node, h_evolved, dt)
end
end
return h_evolved
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary: just use adaptive=false

@Aksh8t
Copy link
Author

Aksh8t commented Jun 23, 2025

Updated as per feedback, please review the changes

end
return results, st
else
t_span = (0.0f0, Float32(ts))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use the types set by the user's input.

else
t_span = (0.0f0, Float32(ts))
prob = ODEProblem((u,p,t)->cell.f_node(u,p,st)[1], h, t_span)
sol = solve(prob, cell.solver, saveat=[t_span[2]], dense=false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sol = solve(prob, cell.solver, saveat=[t_span[2]], dense=false)
sol = solve(prob, cell.solver, saveat=[t_span[2]])

redundant, since if saveat is used then it's false.


function solve_fixed_step(cell::ODELSTMCell, h, ts, p, st)
prob = ODEProblem((u,p,t)->cell.f_node(u,p,st)[1], h, (0.0f0, Float32(ts)))
sol = solve(prob, cell.solver; adaptive=false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to make completely separate, just allow a keyword argument in ODELSTMCell and just use the bool in the other function, delete the extra code.

for i in 1:batch_size
h_i = h[:, i]
ts_i = ts isa AbstractVector ? ts[i] : ts
t_span = (0.0f0, Float32(ts_i))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't make sense, the tspan can only be 2 values.

t_span = (0.0f0, Float32(ts_i))

prob = ODEProblem((u,p,t)->cell.f_node(u,p,st)[1], h_i, t_span)
sol = solve(prob, cell.solver, saveat=[t_span[2]], dense=false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to put saveat = ts?

export train_model!, evaluate_model, load_dataset

mutable struct ODELSTMCell{F,S}
lstm_cell::Lux.LSTMCell
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is actually the algorithm? Derive it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants