Well, I'll answer this myself. The translation rule of `device_put` is actually `lambda c, x, device=None: x` (jax/interpreters/xla.py#L1342). In other words, it is an identity that ignores the `device` argument. So once `jit` or `pmap` (which also compiles the function, like `jit`) is applied to a function, any `device_put` inside that function has no effect.
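Since the in-graph `device_put` is dropped, one way around it is to do the placement before calling the compiled function. Here is a minimal sketch, assuming a recent JAX release where arrays expose `.devices()`. The name `double` and the choice of `jax.devices()[-1]` as the target are illustrative only. A jitted function runs on the device its committed inputs live on, so placing `x` outside `jit` determines where the computation happens.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Pick a target device; on a single-device host this is simply that device.
target = jax.devices()[-1]

# Place the input *outside* the jitted function. This commits x to `target`.
x = jax.device_put(jnp.arange(4.0), target)

@jax.jit
def double(v):
    # No device_put here: in the JAX version discussed in this thread, it would
    # be lowered to an identity and its `device` argument would be ignored.
    return v * 2

y = double(x)

# The computation ran on the device holding the committed input.
print(y.devices())    # a set containing `target`
print(np.asarray(y))  # [0. 2. 4. 6.]
```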

skye Mar 19, 2021
Maintainer

Answer selected by ZYHowell
Category
Q&A