Sorry to be dense. Could you give a general mathematical formula for your problem?

```python
import jax
import jax.numpy as jnp

w = jnp.ones((50, 1))

def f(x):
    assert x.shape == (1,)
    return w @ x          # shape (50,)

xs = jnp.ones((1000, 1))

# per-example Jacobian of f, computed with forward mode and batched with vmap
grads = jax.vmap(jax.jacfwd(f))(xs)
assert grads.shape == (1000, 50, 1)
```
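(jax.jacfwd is the natural choice here: with a one-dimensional input, forward mode produces the whole Jacobian in a single pass, whereas reverse mode's cost grows with the 50-dimensional output.)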
Hi all, apologies if this isn't the best place to ask this; if it isn't, I'd really appreciate a pointer to a better platform for this question. For my work I mainly look at function approximation and PDE approximation problems using neural networks, and I need to compute quantities like (for simplicity, consider a network with a single hidden layer, where W_i and b_i are the i-th components of the weight and bias vectors)

d/dx sigma(x * W_i + b_i)

for all i, evaluated at every entry of some array x. I could never find a satisfactory or efficient way to do this in TensorFlow, and I know this isn't something most users need to do.
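For concreteness, a minimal sketch of that single-hidden-layer quantity in JAX (the width, activation, and parameter values below are placeholders, not taken from the original post):

```python
import jax
import jax.numpy as jnp

width = 50                                   # hypothetical layer width
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (width,))         # weights for a 1-D input
b = jnp.zeros(width)

def hidden(x):
    # activations sigma(x * W_i + b_i) for a scalar input x
    return jnp.tanh(x * W + b)               # shape (width,)

xs = jnp.linspace(-1.0, 1.0, 1000)           # array of evaluation points
# d/dx sigma(x * W_i + b_i) for every i and every x in xs
dh_dx = jax.vmap(jax.jacfwd(hidden))(xs)     # shape (1000, width)
```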
When there's only a single hidden layer, it's easy enough to do this manually, without automatic differentiation. Things get pretty messy with more than one layer, though, so I was hoping to leverage JAX to do it for me. I've included a hopefully minimal example below with just two hidden layers, one-dimensional input, and the same number of neurons in each layer. The main issue is that this ends up being quite slow, especially if the array x is large and/or the width of each layer is relatively large, and I need to do this computation more or less every epoch. Is there a more efficient way to do this computation with JAX?
Here's a quick vectorized version using vmap, if it helps anyone.
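That vectorized snippet isn't reproduced in this excerpt. Purely as a sketch of what a vmap-based version of the two-hidden-layer setup described above could look like (the widths, activation, and parameter names here are placeholders, not the poster's actual code):

```python
import jax
import jax.numpy as jnp

width = 50                                    # hypothetical width of both hidden layers
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
W1 = jax.random.normal(k1, (width,))          # 1-D input -> first hidden layer
b1 = jnp.zeros(width)
W2 = jax.random.normal(k2, (width, width))    # first hidden layer -> second hidden layer
b2 = jnp.zeros(width)

def activations(x):
    # all hidden-unit activations for a scalar input x
    h1 = jnp.tanh(x * W1 + b1)                # shape (width,)
    h2 = jnp.tanh(h1 @ W2 + b2)               # shape (width,)
    return h1, h2

xs = jnp.linspace(-1.0, 1.0, 1000)            # evaluation points
# d/dx of every hidden unit in both layers, for every x in xs
dh1_dx, dh2_dx = jax.vmap(jax.jacfwd(activations))(xs)
# dh1_dx.shape == (1000, width), dh2_dx.shape == (1000, width)
```

Because the input is one-dimensional, jax.jacfwd recovers all of these derivatives from a single forward-mode pass per point, and jax.vmap batches that over the whole array of inputs; wrapping the computation in jax.jit would typically help further when it is repeated every epoch.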