
In general, jax.numpy and jax.scipy are built on top of jax.lax functions. They are essentially syntactic sugar over the jax.lax primitives, which are usually a bit harder to use and have stricter semantics (for example, jax.lax does no implicit type promotion).
Also, they mimic the NumPy/SciPy APIs, which are more widely known and generally easier to use.

If you inspect their source code, both jax.scipy.signal.convolve and jax.scipy.signal.convolve2d call into jax.lax.conv_general_dilated: https://jax.readthedocs.io/en/latest/_modules/jax/_src/scipy/signal.html#_convolve_nd
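As a quick sanity check of that relationship, the sketch below computes the same 2D convolution twice: once with jax.scipy.signal.convolve2d and once by calling jax.lax.conv_general_dilated directly. The input values and sizes are arbitrary. Two things differ at the lax level: the primitive expects explicit batch and channel dimensions (NCHW input and OIHW kernel by default), and it computes a cross-correlation, so the kernel is flipped here to match convolve2d's true convolution.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.signal import convolve2d

x = jnp.arange(25.0).reshape(5, 5)        # 5x5 input image
k = jnp.array([[1.0, 2.0], [3.0, 4.0]])   # 2x2 kernel

# High-level, SciPy-style call.
out_scipy = convolve2d(x, k, mode="valid", precision=lax.Precision.HIGHEST)

# Same computation with the lax primitive:
# - add batch and channel axes: input (N=1, C=1, H, W), kernel (O=1, I=1, H, W)
# - flip the kernel, because conv_general_dilated is a cross-correlation
# - mode="valid" corresponds to padding="VALID"
out_lax = lax.conv_general_dilated(
    x[None, None],
    jnp.flip(k)[None, None],
    window_strides=(1, 1),
    padding="VALID",
    precision=lax.Precision.HIGHEST,
)[0, 0]  # drop the batch and channel axes again
```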

Flax, Haiku and Objax are frameworks that help you build and use neural networks, but they don't really provide any new low-level numerical functionality.
They are all built on top of bare JAX functions (in particular, jax.lax.conv_…
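To make that concrete without depending on any particular framework, here is a hypothetical, framework-free convolution layer: the "layer" is just a parameter dict plus a function, and the actual computation is a single jax.lax.conv_general_dilated call. The names init_conv_params and conv_layer, the NCHW layout, and the 0.1 init scale are illustrative assumptions, not how Flax, Haiku or Objax actually implement their layers.

```python
import jax
import jax.numpy as jnp
from jax import lax

def init_conv_params(key, in_channels, out_channels, kernel_size):
    """Hypothetical initializer: OIHW kernel with small random weights, zero bias."""
    kernel = 0.1 * jax.random.normal(
        key, (out_channels, in_channels, kernel_size, kernel_size))
    bias = jnp.zeros((out_channels,))
    return {"kernel": kernel, "bias": bias}

@jax.jit
def conv_layer(params, x):
    """Hypothetical conv layer: NCHW input, 'SAME' padding, stride 1."""
    y = lax.conv_general_dilated(
        x, params["kernel"], window_strides=(1, 1), padding="SAME")
    # Broadcast the per-channel bias over batch, height and width.
    return y + params["bias"][None, :, None, None]

params = init_conv_params(jax.random.PRNGKey(0), in_channels=3,
                          out_channels=8, kernel_size=3)
x = jnp.ones((2, 3, 16, 16))   # batch of 2 RGB 16x16 images
y = conv_layer(params, x)      # shape (2, 8, 16, 16)
```

Frameworks mostly add conveniences around this pattern (parameter initialization, naming, state handling); the numerical work is still delegated to jax.lax.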

Answer selected by EngineerKhan
Category
Q&A