See JAX Sharp Bits: Control Flow. There are a few ways you could express this (including

```python
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.special import gamma

def const_Tc(Ts, T0, pars, k):
    # Unpack the model parameters.
    A = pars[0]
    b = pars[1]
    C = pars[2]
    n = pars[3]
    s0 = pars[4]
    mu = 1
    As = A*k*Ts
    Cs = C*k*Ts
    intfac = 1/mu * jnp.power(Cs,n)*Ts/(n+1)*jnp.power(T0/Ts-s0,n+1)
    bnfr = (b+1)/(n+1)
    # res has a fixed length of 4, so this version assumes scalar Ts, T0 and k.
    res = jnp.zeros(4)
    # jnp.maximum clamps the base at 0 to avoid the nan from a negative base.
    res = res.at[0].set(1/mu * jnp.power(As*jnp.maximum(T0/Ts - s0, 0.0),b))
    res = res.at[1].set(1/mu * jnp.power(Cs*jnp.maximum(T0/Ts - s0, 0.0),n))
    res = res.at[2].set(jnp.exp(-res[1] * T0 + intfac) *jnp.power(As,b) / jnp.power(Cs,n*bnfr) * jnp.power(mu/Ts * (n+1),(b-n)/(n+1)) * jsp.special.gammainc(bnfr,intfac)*gamma(bnfr))
    res = res.at[3].set(jnp.exp(-res[1] * T0))
    Tc = -1/res[1] * jnp.log((res[0]/res[1]*res[3] + res[2]) / (1 + res[0]/res[1]))
    # Tc is always computed; jnp.where then returns 0 wherever T0/Ts - s0 <= 0.
    return jnp.where(T0/Ts-s0>0, Tc, 0)
```
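One caveat, offered as a hedged sketch: the final jnp.where gives the right forward value, but if you take jax.grad through it at a point where T0/Ts - s0 < 0, the gradient can still come out as nan. The derivative of the unused branch is nan, and the zero cotangent from jnp.where times nan is still nan. The usual workaround (the "double where" trick from the JAX FAQ) is to replace the bad input with a harmless value before the nan-producing expression. Here `core`, `single_where` and `double_where` are hypothetical names, and `core` is a simplified stand-in for the body of const_Tc, not the real model:

```python
import jax
import jax.numpy as jnp

def core(s, n):
    # Hypothetical stand-in for const_Tc: nan for s < 0 when n + 1 is non-integer.
    return jnp.power(s, n + 1)

def single_where(s, n):
    # Forward value is fine, but the gradient at s < 0 is 0 * nan = nan.
    return jnp.where(s > 0, core(s, n), 0.0)

def double_where(s, n):
    ok = s > 0
    # Swap in a safe input (1.0) where the condition fails, so core() never
    # sees a negative base and its derivative stays finite.
    safe_s = jnp.where(ok, s, 1.0)
    return jnp.where(ok, core(safe_s, n), 0.0)

print(jax.grad(single_where)(-0.5, 0.5))  # nan
print(jax.grad(double_where)(-0.5, 0.5))  # 0.0
print(jax.grad(double_where)(4.0, 0.5))   # 1.5 * 4**0.5, i.e. about 3.0
```

In const_Tc, the analogous change would be to compute something like `safe = jnp.where(T0/Ts - s0 > 0, T0/Ts - s0, 1.0)` and use it in place of T0/Ts - s0 inside the formulas, keeping the outer jnp.where as is.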
My code works if
It returns nan if
The reason is that T0/Ts - s0 < 0 in that case, so it returns nan.
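A minimal reproduction of that failure mode, assuming the nan comes from raising a negative base to a non-integer power (as in jnp.power(T0/Ts-s0, n+1) when n is not an integer). The values of x and n are made up for illustration:

```python
import jax.numpy as jnp

x = jnp.float32(-0.5)  # hypothetical value standing in for T0/Ts - s0 < 0
n = 0.5                # hypothetical non-integer exponent parameter

# A negative base to a non-integer power has no real result,
# so jnp.power returns nan (the same convention as NumPy).
print(jnp.power(x, n + 1))                   # nan

# Clamping the base at 0 first keeps this particular term finite.
print(jnp.power(jnp.maximum(x, 0.0), n + 1)) # 0.0
```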
I want const_Tc() to return 0 when T0/Ts - s0 < 0. How can I write that in JAX?
This is what I really want. How can I write it using lax.cond() or jnp.select()?
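A hedged sketch of both options on a simplified stand-in (`core`, `with_cond` and `with_select` are hypothetical names; substitute the body of const_Tc for `core`). lax.cond takes a scalar boolean predicate plus two branch functions that must return the same shape and dtype, and only the selected branch is executed. jnp.select evaluates every choice and then picks elementwise, so it behaves like jnp.where and also works on arrays:

```python
import jax
import jax.numpy as jnp
from jax import lax

def core(s, n):
    # Hypothetical stand-in for const_Tc: nan for s < 0 when n + 1 is non-integer.
    return jnp.power(s, n + 1)

@jax.jit
def with_cond(s, n):
    # Scalar predicate; the false branch returns a zero of matching shape/dtype.
    return lax.cond(s > 0, core, lambda s, n: jnp.zeros_like(s), s, n)

@jax.jit
def with_select(s, n):
    # One condition, one choice, and a default used where no condition holds.
    return jnp.select([s > 0], [core(s, n)], default=0.0)

print(with_cond(-0.5, 0.5), with_cond(4.0, 0.5))      # 0.0 and about 8.0
print(with_select(-0.5, 0.5), with_select(4.0, 0.5))  # 0.0 and about 8.0
```

If you jax.vmap with_cond over an array of inputs, the predicate becomes batched and JAX converts the cond into a select, so both branches run; the forward value is still masked correctly.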