Could you paste your code in the question rather than linking to a zip file? Thanks!

So there's a lot in your code and I don't think I'll have a chance to debug it in detail, but one red flag to me is your reliance on Python loops. Particularly this one:

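import jax.numpy as jnp

def max_pos(x, num_elems):
    # Sum exp(x[i]) one element at a time in a Python loop
    total = 0.0
    for i in range(num_elems):
        total += jnp.exp(x[i])
    # Log of the sum, i.e. a log-sum-exp over x[:num_elems]
    return jnp.log(total)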

I wouldn't be surprised if accumulated floating-point roundoff error in this loop is what's causing your zero gradient. In general, if you are using JAX (or NumPy) and find yourself looping over array values, you will get better and faster results by using built-in array-oriented functions. So, for example, …
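Here is a minimal sketch of the array-oriented approach. It assumes the loop is meant to compute log(sum(exp(x[:num_elems]))). It uses `jax.scipy.special.logsumexp`, which subtracts the maximum before exponentiating. The function names `max_pos_loop` and `max_pos_vectorized` and the test input are hypothetical and chosen only for illustration. The example shows one way the loop version can break down numerically (overflow on large inputs) while the built-in stays finite. It is not a reproduction of the zero gradient from the question.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def max_pos_loop(x, num_elems):
    # Hypothetical name: the loop-based version from the question, for comparison
    total = 0.0
    for i in range(num_elems):
        total += jnp.exp(x[i])
    return jnp.log(total)


def max_pos_vectorized(x, num_elems):
    # Hypothetical name: array-oriented log-sum-exp over the first num_elems
    # entries; logsumexp subtracts the max internally, so exp() cannot overflow
    return logsumexp(x[:num_elems])


# Hypothetical input with large values
x = jnp.array([1000.0, 999.0, 998.0])

# Loop version: exp(1000.0) overflows to inf in float32,
# so the value is inf and the gradient entries are nan
print(max_pos_loop(x, 3))
print(jax.grad(max_pos_loop)(x, 3))

# Vectorized version: finite value (approximately 1000.41),
# and its gradient is a softmax whose entries sum to 1
print(max_pos_vectorized(x, 3))
print(jax.grad(max_pos_vectorized)(x, 3))
```

Besides avoiding a Python-level loop that JAX has to unroll during tracing, the built-in handles the max-shift trick for you, so you don't have to hand-roll numerically stable code.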

Replies: 1 comment, 5 replies (@mjhoover1, @jakevdp, @mjhoover1, @hawkinsp, @mjhoover1)

Answer selected by mjhoover1
Category: Q&A
3 participants