
@marcofrancis I tried to simplify things a bit and it seems to work for me.

```python
import jax
import jax.numpy as jnp

# Parameters.
γ = 1.5
k = 0.1
μY = 0.03
σ = 0.03
λ = 0.1
ωb = μY/λ

# PDE params.
σω = σ

dt = 0.01

f = lambda ω: jnp.exp(-(1-γ)*ω)

f_x = jax.grad(f)             # first derivative
f_xx = jax.grad(jax.grad(f))  # second derivative

# One update step of size 100*dt (= 1.0 here).
f_next = lambda ω: f(ω) + 100*dt * (
    (0.5*σω**2)*f_xx(ω) - λ*(ω-ωb)*f_x(ω)
    + (1-k)*f(ω))

print(f_next(0.))
print(jax.grad(f_next)(0.))  # first derivative of f_next
```

Output:

```
1.9151125
0.9075562
```
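Reading the snippet as a single explicit (forward-Euler) step with step size $100\,\Delta t = 1$ (an interpretation, not something stated above), the update is

$$
f_{\text{next}}(\omega) = f(\omega) + 100\,\Delta t\left[\tfrac{1}{2}\sigma_\omega^2 f''(\omega) - \lambda(\omega - \omega_b) f'(\omega) + (1-k) f(\omega)\right], \qquad \omega_b = \mu_Y/\lambda .
$$

Because $f(\omega) = e^{-(1-\gamma)\omega}$, every derivative is a multiple of $f$, so both printed values can be checked by hand. With $a = -(1-\gamma) = 0.5$ we get $f(0) = 1$, $f'(0) = 0.5$, $f''(0) = 0.25$, $f'''(0) = 0.125$ and $\omega_b = 0.3$:

$$
f_{\text{next}}(0) = 1 + \left[\tfrac{1}{2}(0.03)^2(0.25) + (0.1)(0.3)(0.5) + (0.9)(1)\right] = 1 + 0.0001125 + 0.015 + 0.9 = 1.9151125
$$

$$
f_{\text{next}}'(\omega) = f'(\omega) + 100\,\Delta t\left[\tfrac{1}{2}\sigma_\omega^2 f'''(\omega) - \lambda f'(\omega) - \lambda(\omega - \omega_b) f''(\omega) + (1-k) f'(\omega)\right]
$$

$$
f_{\text{next}}'(0) = 0.5 + \left[0.00005625 - 0.05 + 0.0075 + 0.45\right] = 0.90755625
$$

Both agree with the printed values to float32 precision.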

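As a further sanity check, here is a minimal sketch (assuming JAX is installed; `a`, `f_next_closed_form` and `grid` are hypothetical names introduced only for this illustration) that compares the autodiff update with a hand-derived closed form on a small grid of ω values. `jax.grad` needs a function with a scalar output, and `f_x`/`f_xx` are built from the scalar `f`, so `f_next` takes one ω at a time; `jax.vmap` maps it over an array.

```python
import jax
import jax.numpy as jnp

# Parameters copied from the snippet above.
γ, k, μY, σ, λ = 1.5, 0.1, 0.03, 0.03, 0.1
ωb = μY / λ
σω = σ
dt = 0.01

a = -(1 - γ)  # hypothetical helper: f(ω) = exp(a·ω), so f' = a·f and f'' = a²·f

f = lambda ω: jnp.exp(a * ω)
f_x = jax.grad(f)    # first derivative via autodiff
f_xx = jax.grad(f_x)  # second derivative via autodiff


def f_next(ω):
    # Same update as in the snippet: f + 100·dt·(diffusion − drift + (1−k)·f)
    return f(ω) + 100 * dt * (
        0.5 * σω**2 * f_xx(ω) - λ * (ω - ωb) * f_x(ω) + (1 - k) * f(ω)
    )


def f_next_closed_form(ω):
    # Hypothetical hand-derived version using f' = a·f and f'' = a²·f.
    return f(ω) * (1 + 100 * dt * (0.5 * σω**2 * a**2 - λ * (ω - ωb) * a + (1 - k)))


# f_next works on one scalar ω at a time, so map it over a grid with jax.vmap.
grid = jnp.linspace(-1.0, 1.0, 5)
auto = jax.vmap(f_next)(grid)
manual = jax.vmap(f_next_closed_form)(grid)

print(jnp.max(jnp.abs(auto - manual)))    # should be float32 round-off
print(jax.vmap(jax.grad(f_next))(grid))  # d f_next / dω at each grid point
```

Keeping `f_next` scalar-in, scalar-out and vectorizing with `jax.vmap` leaves the per-point definitions exactly as in the original snippet.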

Answer selected by marcofrancis
@soraros

@marcofrancis

Category: Q&A

3 participants