I think the issue is your use of static_argnums with forwards_backwards. This re-compiles the function whenever the static input changes, which I think will happen every time you update your instance of jax_MLP.
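
For illustration (a toy example, not the code from the question), this is how static_argnums keys jit's compilation cache on the value of the static argument, so every new value means a fresh trace and compile:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=0)
def scale(n, x):
  # `n` is a compile-time constant: the compiled code is specialized
  # to its value, and the value itself is part of jit's cache key
  return n * x

x = jnp.arange(3.0)
scale(2, x)  # traces and compiles
scale(2, x)  # same static value -> cache hit
scale(3, x)  # new static value -> traces and compiles again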

What I would recommend is following the approach used by haiku and similar packages, where the parameters are passed around explicitly in the optimization loop instead of being stored on a mutable model object. This follows the JAX philosophy of keeping functions functional (pure).

Concretely, what I would recommend is changing update to:

def update(params, grads, lr):
  # Unpack the current parameter values and their gradients
  a, b, w = params
  da, db, dw = grads[0]
  # Plain SGD step: return a new parameter tuple rather than
  # mutating a model object
  a -= lr * da
  b -= lr * db
  w -= lr * dw
  return a, b, w

and un-jitting forwards_backwards (or at least dropping its static_argnums), so that updating the parameters no longer forces a re-compile.
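
A minimal sketch of how the pieces could fit together with the update function above, assuming the parameters of jax_MLP are the tuple (a, b, w); the loss function, shapes, and learning rate below are made up for illustration, and here forwards_backwards is jitted without static_argnums (leaving it un-jitted also works):

import jax
import jax.numpy as jnp

# Hypothetical loss over the explicit parameter tuple (a, b, w);
# the real forward pass would come from jax_MLP
def loss_fn(params, x, y):
  a, b, w = params
  pred = a * jnp.tanh(x @ w + b)
  return jnp.mean((pred - y) ** 2)

# No static_argnums: params are ordinary traced arrays, so updating
# them does not trigger a re-compile. argnums=(0,) makes grads[0] the
# gradient w.r.t. params, matching the update function above.
forwards_backwards = jax.jit(jax.value_and_grad(loss_fn, argnums=(0,)))

params = (jnp.ones(()), jnp.zeros(4), jnp.ones((3, 4)))
x, y = jnp.ones((8, 3)), jnp.zeros((8, 4))
for _ in range(100):
  loss, grads = forwards_backwards(params, x, y)
  params = update(params, grads, lr=1e-2)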

Answer selected by LawrenceMMStewart