I have a simple MLP with one input and one output and would like to calculate the first and second order derivatives of the individual outputs with respect to the respective input. Basically I'd like to differentiate the following function twice and evaluate the derivatives at `x_physics`:

```python
def physics_pred(x_physics, params):
    y_physics = net.apply(params, x_physics)
    return y_physics
```

I tried using `jax.grad`, but without success. In the end I'd like to use these derivatives in a physics-informed loss:

```python
def loss_physics(params: hk.Params, x_data: jnp.array, y_data: jnp.array, x_physics: jnp.array):
    y_pred_data = net.apply(params, x_data)
    data_loss = jnp.mean((y_pred_data - y_data)**2)

    # Function to differentiate
    def physics_pred(x_physics, params):
        y_physics = net.apply(params, x_physics)
        return y_physics

    y_pred_physics = physics_pred(x_physics, params)
    df_dx = ...   # TODO: calculate first order derivatives
    df_dx2 = ...  # TODO: calculate second order derivatives
    residual = df_dx2 + mu * df_dx + k * y_pred_physics
    physics_loss = (1e-4) * jnp.mean(residual**2)
    return data_loss + physics_loss
```
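One way the two placeholder lines could be filled in — a sketch rather than a definitive implementation, assuming `x_physics` has shape `(N, 1)` and `net.apply` maps each row independently — is to build a scalar-to-scalar view of the network and push `jax.grad` through it per sample with `jax.vmap`:

```python
# Hypothetical sketch: per-sample first and second derivatives.
# Assumes x_physics has shape (N, 1) and each output row depends
# only on the corresponding input row.
u_scalar = lambda x: net.apply(params, x.reshape(1, 1)).squeeze()

df_dx = jax.vmap(jax.grad(u_scalar))(x_physics.squeeze())             # shape (N,)
df_dx2 = jax.vmap(jax.grad(jax.grad(u_scalar)))(x_physics.squeeze())  # shape (N,)
```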
Edit: Here is a minimal working example of what I'd like to achieve:

```python
import jax.numpy as jnp
from jax import grad, vmap

xs = jnp.linspace(0, 1, 10)

def y(x):
    return x**2

dx = vmap(grad(y))(xs)         # first derivative at each x
dx2 = vmap(grad(grad(y)))(xs)  # second derivative at each x
```

In this example the function `y` maps a scalar to a scalar, so `grad` can be applied directly; my network's output, however, is an array, so the same pattern doesn't apply as-is.

Edit 3: This solution seems to work but I am not sure if it is the most efficient one:

```python
def loss_physics(params: hk.Params, x_data: jnp.array, y_data: jnp.array, x_physics: jnp.array):
    y_pred_data = net.apply(params, x_data)
    data_loss = jnp.mean((y_pred_data - y_data)**2)

    # The solution to the differential equation is represented by our network
    u = lambda x: net.apply(params, x)

    # Calculate first and second derivatives of the network
    u_dx = lambda x: jax.grad(lambda x: jnp.sum(u(x)))(x)
    u_dx2 = lambda x: jax.grad(lambda x: jnp.sum(u_dx(x)))(x)

    # Compute physics loss
    y_pred_physics = net.apply(params, x_physics)
    residual = u_dx2(x_physics) + mu * u_dx(x_physics) + k * y_pred_physics
    physics_loss = (1e-4) * jnp.mean(residual**2)
    return data_loss + physics_loss
```
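The `jnp.sum` trick works here because the network maps each input element to the corresponding output element, so the Jacobian is diagonal: the gradient of `jnp.sum(u(x))` with respect to `x_j` is the sum over `i` of `du_i/dx_j`, which collapses to `du_j/dx_j`. A quick sanity check on the toy example above:

```python
f = lambda x: x**2

# Per-sample derivative via vmap(grad(...)):
d1 = vmap(grad(f))(xs)                  # 2 * xs

# Same values via the sum-then-grad trick, valid because f is element-wise:
d2 = grad(lambda x: jnp.sum(f(x)))(xs)  # also 2 * xs

assert jnp.allclose(d1, d2)
```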
Replies: 1 comment 3 replies
Can you add more information about the shape of the inputs and outputs of `physics_pred`? It's also not clear to me what kind of derivative you're interested in: do you want the element-wise `[dy/dx0, dy/dx1, ...]`, or if `y` is a vector do you want the matrix of derivatives `{dyi/dxj}`? Or, given that `y` might be a vector of the same length as `x`, do you want element-wise derivatives `[dy0/dx0, dy1/dx1, ...]`?

It would be most useful if you could edit your question to add a minimal reproducible example, such that we could run your code and see the same outputs that you're seeing, rather than just guessing at what your function might do. Note that "minimal" here does not imply giving us your entire neural net; it might be sufficient for the sake of the question to replace the network with a simpler stand-in function.
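To illustrate the distinction between these derivative types (a hypothetical sketch, with `f` standing in for the network):

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x)  # an element-wise vector -> vector function
x = jnp.arange(3.0)

# Full matrix of derivatives {dyi/dxj}: a (3, 3) matrix, diagonal here.
J = jax.jacobian(f)(x)

# Element-wise derivatives [dy0/dx0, dy1/dx1, ...]: the diagonal of J,
# or computed directly by mapping grad over scalar inputs.
d = jax.vmap(jax.grad(jnp.sin))(x)

assert jnp.allclose(jnp.diag(J), d)
```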