Have you managed to solve this? Is there a way to work around it?
Hi!
I am new to JAX and have a probably noobish question:
I have a JAX function that maps a real array to a complex scalar (it will represent a wavefunction).
I was able to implement its Laplacian like this:
laplace_psi = lambda x: jnp.diag(jacobian(jacobian(psi, holomorphic=True), holomorphic=True)(x)).sum(-1)
It does not work without the holomorphic=True promises because the output is complex, so I am forced to pass complex inputs, and the result is correct only when the function actually is holomorphic.
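For reference, a self-contained version of that one-liner might look like the sketch below. It assumes `jacobian` is `jax.jacobian` and uses a hypothetical holomorphic `psi` as a stand-in (not the actual wavefunction); `jnp.trace` replaces `jnp.diag(...).sum(-1)`, which computes the same value for a square Hessian.

```python
import jax
import jax.numpy as jnp

def psi(z):
    # Hypothetical holomorphic stand-in: (1 + 2j) * sum_k z_k**2
    return (1 + 2j) * jnp.sum(z**2)

# jax.jacobian is reverse mode (jacrev); with a complex output it needs
# holomorphic=True, which in turn requires complex inputs.
hess_psi = jax.jacobian(jax.jacobian(psi, holomorphic=True), holomorphic=True)

def laplace_psi(z):
    # Trace of the complex Hessian = sum of the pure second derivatives
    return jnp.trace(hess_psi(z))

z = jnp.array([0.3, -1.2, 2.0]) + 0j  # real values, but complex dtype
print(laplace_psi(z))  # analytically 2 * n * (1 + 2j) = 6 + 12j for n = 3
```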
Is there a way to get the Laplacian of a non-holomorphic R^n->C function?
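One possible workaround, sketched under the assumption that `psi` is twice differentiable as a function of real inputs: because the Laplacian is linear, it can be computed separately for the real-valued functions Re psi and Im psi with `jax.hessian` (which then only sees real outputs, so no holomorphic flag is needed) and recombined. The test function exp(i k·x − |x|²/2), written with `jnp.conj` so its complex extension is non-holomorphic, and the vector `k` are hypothetical; they are chosen because the Laplacian has a closed form to compare against.

```python
import jax
import jax.numpy as jnp

k = jnp.array([1.0, -0.5, 2.0])  # hypothetical wave vector

def psi(x):
    # R^n -> C; x * conj(x) equals x**2 on real inputs, but it makes the
    # complex extension of psi non-holomorphic
    return jnp.exp(1j * jnp.dot(k, x) - 0.5 * jnp.sum(x * jnp.conj(x)))

def laplacian(f):
    # Laplacian of an R^n -> C function: traces of the real Hessians of
    # Re f and Im f, recombined (the Laplacian is linear)
    hess_re = jax.hessian(lambda x: jnp.real(f(x)))
    hess_im = jax.hessian(lambda x: jnp.imag(f(x)))
    return lambda x: jnp.trace(hess_re(x)) + 1j * jnp.trace(hess_im(x))

x = jnp.array([0.2, -0.7, 0.4])  # real (float32) input
print(laplacian(psi)(x))

# Closed form for comparison: psi * (|x|^2 - |k|^2 - n - 2i k.x)
n = x.shape[0]
expected = psi(x) * (jnp.sum(x**2) - jnp.sum(k**2) - n - 2j * jnp.dot(k, x))
print(expected)
```

The split costs two Hessians instead of one, but it relies only on differentiating real-valued outputs, which every JAX transformation supports.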
The crux of my predicament (it is not a bug; it is expected behavior explained in the documentation, just not the behavior I need) is illustrated in this code:
The two versions of psi produce the same result as long as x is on the real line. As C->C functions, however, they are completely different, and psi2 is not holomorphic. Therefore, under JAX's convention for gradients of complex functions, they have different derivatives.
Is there a way to treat them as R->C functions and get the same gradients?
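A sketch of the "treat them as R->C" idea using forward mode: `jax.jacfwd` accepts real inputs with complex outputs without `holomorphic=True` and differentiates only along real directions, whereas `jax.grad` and `jax.jacobian` (reverse mode) reject complex outputs unless the holomorphic promise is made. The pair `psi1`/`psi2` below is a hypothetical stand-in for the functions in the question, not a reproduction of them: they agree on the real line, but `psi2` is non-holomorphic because it depends only on `jnp.real(x)`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins: identical for real x, different as C -> C maps
def psi1(x):
    return jnp.exp(1j * x)            # holomorphic

def psi2(x):
    return jnp.exp(1j * jnp.real(x))  # not holomorphic (depends on Re x only)

x = jnp.array(0.5)  # real input

# Treated as C -> C functions (complex input plus a holomorphic=True promise),
# the results differ, because the promise is false for psi2.
g1_c = jax.grad(psi1, holomorphic=True)(x + 0j)
g2_c = jax.grad(psi2, holomorphic=True)(x + 0j)

# Treated as R -> C functions: forward mode with a real input needs no
# holomorphic flag and differentiates along the real axis only.
d1 = jax.jacfwd(psi1)(x)
d2 = jax.jacfwd(psi2)(x)
dd1 = jax.jacfwd(jax.jacfwd(psi1))(x)
dd2 = jax.jacfwd(jax.jacfwd(psi2))(x)

print(g1_c, g2_c)  # differ
print(d1, d2)      # both equal i * exp(i x)
print(dd1, dd2)    # both equal -exp(i x)
```

For a vector input, nesting `jax.jacfwd` twice and taking `jnp.trace` of the resulting matrix gives a Laplacian in the same spirit.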