Replies: 2 comments
I revised my code to avoid "if" by using lax.cond and jnp.where, but now I get a new error.
def compare_loss(loss, max_loss):
    return jnp.where((loss - max_loss) > 0, 1, 0)
is_update_grad = compare_loss(loss, max_loss)
max_loss, max_params = jax.lax.cond(is_update_grad, true_fun1(loss, noised_model, max_loss, max_params), false_fun1(max_loss, max_params))
Can you help me? Thanks
Hi - thanks for the question! There are two options that may work for you depending on the context:
If you need more specific advice, I'd suggest sharing a minimal reproducible example of the code that leads to the error you're seeing.
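For illustration only, here is a minimal sketch of how lax.cond is usually called. It is not necessarily what the reply above had in mind. lax.cond expects the two branches as callables, followed by the operands to pass to them, rather than the results of already calling the branch functions. The names keep_worst, take_new, and keep_old are hypothetical stand-ins, since the bodies of true_fun1 and false_fun1 are not shown in this thread, and params is assumed to be a plain array.

```python
import jax
import jax.numpy as jnp
from jax import lax


def keep_worst(loss, params, max_loss, max_params):
    """Return the larger of (loss, max_loss) with its matching params.

    Hypothetical helper: the real true_fun1/false_fun1 are not shown
    in the thread, so this only illustrates the calling convention.
    """
    # Traced boolean scalar; no Python `if` is evaluated on it.
    is_update = loss > max_loss

    def take_new(operands):
        new_loss, new_params, _, _ = operands
        return new_loss, new_params

    def keep_old(operands):
        _, _, old_loss, old_params = operands
        return old_loss, old_params

    # Pass the functions themselves plus the operands. Both branches
    # must return outputs with the same structure, shapes, and dtypes.
    return lax.cond(
        is_update, take_new, keep_old, (loss, params, max_loss, max_params)
    )


keep_worst_jit = jax.jit(keep_worst)

max_loss, max_params = keep_worst_jit(
    jnp.float32(2.0), jnp.ones(3), jnp.float32(1.0), jnp.zeros(3)
)
```

Because both branches here only select between existing values, the same result could also be written with jnp.where on each output; lax.cond mainly helps when the branches do different computation.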
I use pmap in my code, and I need a conditional statement to compare two numbers. But I get this error: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/2)>
I know it is related to jit compilation because I use the pmap function. I have tried many approaches, such as static argnums, control flow, and so on, but they all failed.
I hope someone can help. Thank you!!
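A minimal sketch of the kind of code that typically triggers this error, and one common way around it. This is an assumption about the situation, since the original code is not shown. It uses jax.jit rather than pmap to keep the example small; both trace the function, so a Python "if" on a traced comparison fails in the same way. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def larger_with_if(a, b):
    # Under jit/pmap, `a > b` is an abstract tracer, so Python cannot
    # turn it into a concrete True/False for the `if` statement.
    if a > b:
        return a
    return b


@jax.jit
def larger_with_where(a, b):
    # jnp.where selects between the two values inside the traced
    # computation, so no concrete boolean is ever required.
    return jnp.where(a > b, a, b)


result = larger_with_where(jnp.float32(1.0), jnp.float32(3.0))
```

Calling jax.jit(larger_with_if) on the same inputs raises a concretization error like the one quoted above, while larger_with_where runs normally.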