So, there are ways you can do this currently by defining a custom primitive and using non-public routines that are not guaranteed to have a stable API. My hope is that we can make this cleaner and more intuitive in the future, but as a proof of concept, here is some code that works in the most recent release.

Here is an example of defining a new primitive, np_sin, which computes the element-wise sine of an input using a callback to numpy.sin. In addition, I define a lowering rule (for compatibility with jit), a batching rule (for compatibility with vmap), and a jvp rule (for compatibility with grad and other autodiff transformations):

import numpy as np
import jax.numpy as jnp

import jax
f…

Answer selected by jakevdp
Category
Q&A