Reduce Compilation Times for Numerical Hilbert Transform #9487
-
Hello, I am working on a way to propagate distributions through neural networks (NNs) using characteristic functions (CFs). Long story short, in every layer of the NN the Hilbert transform has to be applied to the CF, and I compute it numerically by summing products of sinc and sine terms over a grid. See the code below for an example of a propagation step:
My issue is that this Hilbert transform step takes a very long time to compile, especially for high-resolution grids (large values of M and L). I am trying to use …
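For reference, here is one common sinc-quadrature formula for the Hilbert transform on a uniform grid with spacing $h$. It is an assumption about which discretization is meant, since the question's code is not reproduced here. It uses the convention $\mathcal{H}f(x) = \frac{1}{\pi}\,\mathrm{p.v.}\int f(y)/(x-y)\,dy$:

$$
\mathcal{H}f(x) \approx \sum_{k} f(kh)\,\frac{1-\cos(\pi t_k)}{\pi t_k}
= \sum_{k} f(kh)\,\operatorname{sinc}\!\left(\frac{t_k}{2}\right)\sin\!\left(\frac{\pi t_k}{2}\right),
\qquad t_k = \frac{x-kh}{h},
\qquad \operatorname{sinc}(u) = \frac{\sin(\pi u)}{\pi u}.
$$

The second form follows from $1-\cos(\pi t) = 2\sin^2(\pi t/2)$. Because `jnp.sinc` is the normalized sinc, the sine factor that matches this formula is `jnp.sin(jnp.pi * t / 2)`. The reply below writes `jnp.sin(eval_pt / 2)`, so it may be worth checking which scaling the original code intends.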
-
Essentially, any time you implement an algorithm with an explicit Python loop, you should expect compilation and execution times to be slower than necessary. When a Python loop runs inside a `jax.jit`-compiled function, it is unrolled during tracing, so the size of the compiled program (and hence the compile time) grows with the number of iterations. I think the main bottleneck here is your use of a loop within `hilbertTransform`. Instead, you should use broadcasting:

```python
import jax.numpy as jnp

def hilbertTransform(f, grid, x, hilb_grid, h):
    # Offsets of x from all quadrature nodes at once, in units of h (no Python loop)
    eval_pt = (x - hilb_grid * h) / h
    return jnp.sum(jnp.interp(hilb_grid * h, grid, f) * jnp.sinc(eval_pt / 2) * jnp.sin(eval_pt / 2))
```
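To make the difference concrete, here is a minimal sketch comparing a per-node Python loop with a version that broadcasts over both the quadrature nodes and a whole batch of evaluation points `xs`. The names `kernel`, `hilbert_loop`, `hilbert_broadcast` and the Gaussian test data are hypothetical. The kernel is copied from the reply's `hilbertTransform` unchanged, so adjust it there if your formula needs a factor of π inside the sine.

```python
import jax
import jax.numpy as jnp


def kernel(t):
    # Kernel copied from the reply: sinc(t/2) * sin(t/2), where jnp.sinc(u) = sin(pi*u)/(pi*u).
    # If your discretization needs sin(pi*t/2), change it here in one place.
    return jnp.sinc(t / 2) * jnp.sin(t / 2)


def hilbert_loop(f, grid, x, hilb_grid, h):
    # Loop version (illustration only): under jax.jit this loop would be unrolled,
    # so the traced program gets one block of ops per quadrature node.
    total = 0.0
    for k in range(hilb_grid.shape[0]):
        node = hilb_grid[k] * h
        total = total + jnp.interp(node, grid, f) * kernel((x - node) / h)
    return total


def hilbert_broadcast(f, grid, xs, hilb_grid, h):
    # Vectorized over nodes AND evaluation points: no Python loop at all.
    nodes = hilb_grid * h                                  # shape (K,)
    f_nodes = jnp.interp(nodes, grid, f)                   # f interpolated at the nodes, (K,)
    t = (xs[:, None] - nodes[None, :]) / h                 # pairwise offsets, (N, K)
    return jnp.sum(f_nodes[None, :] * kernel(t), axis=1)   # one value per x, (N,)


hilbert_fast = jax.jit(hilbert_broadcast)

# Hypothetical test data: a Gaussian sampled on a uniform grid.
grid = jnp.linspace(-10.0, 10.0, 201)
f = jnp.exp(-grid**2)
hilb_grid = jnp.arange(-30, 31)
h = 0.2
xs = jnp.linspace(-2.0, 2.0, 5)

loop_vals = jnp.array([hilbert_loop(f, grid, x, hilb_grid, h) for x in xs])
fast_vals = hilbert_fast(f, grid, xs, hilb_grid, h)
```

The `(N, K)` intermediate is what removes the loop, but it also costs memory proportional to N·K. For very large grids, one option is to reshape `xs` into chunks and map `hilbert_broadcast` over them with `jax.lax.map`, which keeps the program small while bounding the size of the intermediate.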
Also, you should delete the `block_until_ready()` calls in your functions, as they prevent asynchronous dispatch from speeding up your loops. I think your other loops could probably be replaced with broadcasting operations as well, but since …
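Here is a minimal sketch of the asynchronous-dispatch point, assuming the per-layer work is wrapped in a jitted function. `layer_step` and `propagate` are hypothetical stand-ins, not the question's code. The outer Python loop is fine here because it is not inside `jit`: it only dispatches an already-compiled function. `block_until_ready()` is called once, at the end, and only because we want a wall-clock timing.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def layer_step(cf_vals):
    # Hypothetical stand-in for one propagation step; not the asker's actual layer.
    return 0.5 * jnp.tanh(cf_vals) + 0.1


def propagate(cf_vals, n_layers):
    # No block_until_ready() inside the loop: each jitted call returns immediately
    # (asynchronous dispatch), so Python can enqueue the next step while the
    # device is still computing the previous one.
    for _ in range(n_layers):
        cf_vals = layer_step(cf_vals)
    return cf_vals


x0 = jnp.ones(100_000)

# Warm-up call so that compilation is not included in the timing below.
propagate(x0, 1).block_until_ready()

start = time.perf_counter()
# Block once, at the very end, only to get a meaningful wall-clock measurement.
out = propagate(x0, 10).block_until_ready()
elapsed = time.perf_counter() - start
print(f"10 layers: {elapsed:.4f} s")
```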