Branching behaviour depending on device used. #23716
Unanswered
AdrienCorenflos asked this question in Q&A
Replies: 0 comments
Hi,
What is the preferred way for a user to branch behaviour based on the kind of device on which the input is located?
To ground the question, say I am implementing the following trinket function:
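for which I know that the values `cs` and `us` are sorted. As per the documentation of `searchsorted`, `method="scan"` tends to be more performant on CPU, while `method="sort"` is often more performant on accelerator backends such as GPU and TPU.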
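Because `method` only selects the algorithm, choosing it per device should be purely a performance decision. A minimal sanity check, assuming small made-up sorted arrays `a` and `v` as stand-ins for `cs` and `us` (the `results` name is illustrative only), is that every method returns the same indices:

```python
import jax.numpy as jnp

# Illustrative sorted inputs, standing in for the sorted `cs` and `us`.
a = jnp.array([1.0, 2.0, 3.0, 5.0, 8.0])
v = jnp.array([0.5, 2.0, 4.0, 9.0])

# `method` changes how the search is computed, not what it returns,
# so each option should give identical insertion indices.
results = {
    m: jnp.searchsorted(a, v, method=m)
    for m in ("scan", "sort", "compare_all")
}

for m, idx in results.items():
    print(m, idx.tolist())  # each line ends with [0, 1, 3, 5]
```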
So I'd like to automatically use `jnp.searchsorted(..., method="scan")` on CPU and `jnp.searchsorted(..., method="sort")` otherwise. I can see (although I don't yet understand how I would use it) that, at the library's lower levels, JAX registers separate MLIR lowerings of primitives for CPU and GPU. Is this the preferred end-user interface too, or are there simpler alternatives?
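Two user-level alternatives that avoid writing MLIR lowerings can be sketched as follows; both helper names are hypothetical. The first branches in Python on `jax.default_backend()` ("cpu", "gpu" or "tpu"), which reflects the default backend rather than the device the inputs actually live on. The second uses `jax.lax.platform_dependent`, available in recent JAX releases (an assumption about your version), which picks a branch according to the platform the computation is lowered for, so it also works under `jax.jit`:

```python
import jax
import jax.numpy as jnp


def searchsorted_by_default_backend(a, v):
    # Hypothetical helper: decided once, in Python, at trace time.
    # Uses the *default* backend, not necessarily where `a` and `v` live.
    method = "scan" if jax.default_backend() == "cpu" else "sort"
    return jnp.searchsorted(a, v, method=method)


def searchsorted_platform_dependent(a, v):
    # Hypothetical helper: JAX selects the branch for the platform the
    # computation is lowered for; non-CPU platforms use `default`.
    # Both branches must return the same shape and dtype.
    return jax.lax.platform_dependent(
        a,
        v,
        cpu=lambda a, v: jnp.searchsorted(a, v, method="scan"),
        default=lambda a, v: jnp.searchsorted(a, v, method="sort"),
    )


a = jnp.array([1.0, 2.0, 3.0, 5.0, 8.0])
v = jnp.array([0.5, 2.0, 4.0, 9.0])
print(searchsorted_by_default_backend(a, v).tolist())
print(jax.jit(searchsorted_platform_dependent)(a, v).tolist())
```

The `default_backend` check is the simplest option when everything runs on the default device. If inputs are explicitly placed on a non-default device (for example, CPU arrays on a GPU machine), the `platform_dependent` variant should follow the platform actually used for compilation, which matches "branch on where the input is located" more closely.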
Thanks