Yes, `toolz.pipe` should work correctly with JAX, including inside `jax.jit`; for example:

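```python
import jax
import jax.numpy as jnp
import toolz

# toolz.pipe(x, f, g, h) evaluates h(g(f(x))), so under jax.jit the
# traced array simply flows through each function in order.
fun = jax.jit(
    lambda x: toolz.pipe(
        x,
        jnp.sin,
        lambda y: y + 1,
        jnp.exp,
    )
)

x = jnp.arange(4)
print(fun(x))
# [2.7182817 6.305807  6.7483463 3.1302722]
```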
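Because `toolz.pipe(data, *funcs)` just calls each function on the previous result, JAX traces it like any other Python code. The sketch below assumes `toolz` and a recent JAX are installed, and uses hypothetical names `f` and `f_nested` for illustration. It checks that the piped version matches the equivalent nested call under `jax.jit`, and that `jax.grad` and `jax.vmap` compose with it as well.

```python
import jax
import jax.numpy as jnp
import toolz


def f(x):
    # Hypothetical pipeline: equivalent to jnp.exp(jnp.sin(x) + 1)
    return toolz.pipe(x, jnp.sin, lambda y: y + 1, jnp.exp)


def f_nested(x):
    # The same computation written as nested calls, for comparison
    return jnp.exp(jnp.sin(x) + 1)


x = jnp.arange(4.0)  # float input so that jax.grad is defined

y_ref = f_nested(x)
y_jit = jax.jit(f)(x)    # pipe inside jit
y_vmap = jax.vmap(f)(x)  # pipe applied per element via vmap

# Gradient of the summed pipeline; analytically exp(sin x + 1) * cos x
g = jax.grad(lambda v: f(v).sum())(x)

print(jnp.allclose(y_jit, y_ref), jnp.allclose(y_vmap, y_ref))
print(jnp.allclose(g, y_ref * jnp.cos(x)))
```

Nothing JAX-specific is needed here: `toolz.pipe` never inspects its argument, so tracers pass through it unchanged.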

If you hit any situations where it doesn't work, feel free to open an issue and we'll see if there's anything we can do on the JAX side to make it work!

Answer selected by jakubMitura14
Category: Q&A