How to implement in-place logic in a CUDA custom call ? #19261
Unanswered
Dong-Jiahuan asked this question in Q&A
Replies: 0 comments
I want to register a CUDA function with JAX that modifies, in place, an operand passed in from Python.
For example, in Python I call custom_call('my_fun', operands=[BUFFER], ...). In C++, my_fun then receives (cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len).
I get BUFFER from buffers[0], and buffers[1] is the output.
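A minimal, host-only sketch of the legacy calling convention described above: `buffers` holds the operand pointers first, in declaration order, followed by the result pointers. To keep it compilable without the CUDA toolkit, `cudaStream_t` is stubbed as `void*` and the kernel launch is replaced by a CPU loop. `MyFunDescriptor` and the "multiply by 2" body are hypothetical, standing in for whatever you pass through `opaque` and whatever your kernel computes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Host-only stand-in so this compiles without the CUDA toolkit; in the real
// custom call this is the CUDA stream XLA hands to the target.
using cudaStream_t = void*;

// Hypothetical descriptor passed through `opaque`: number of float elements.
struct MyFunDescriptor {
  std::size_t n;
};

// Legacy custom-call signature: buffers[0..num_operands) are the operands,
// followed by the results, in the order they were declared in Python.
void my_fun(cudaStream_t stream, void** buffers, const char* opaque,
            std::size_t opaque_len) {
  (void)stream;  // a real implementation would launch its kernel on this stream
  assert(opaque_len == sizeof(MyFunDescriptor));
  MyFunDescriptor d;
  std::memcpy(&d, opaque, sizeof(d));

  const float* in = static_cast<const float*>(buffers[0]);  // operand 0 (BUFFER)
  float* out = static_cast<float*>(buffers[1]);              // result 0

  // Stand-in for a CUDA kernel: the result is written into the output slot,
  // and the operand is only read.
  for (std::size_t i = 0; i < d.n; ++i) out[i] = in[i] * 2.0f;
}
```

The key point is that the function writes its answer into the memory behind the result pointer; it does not choose which buffer is the result.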
My CUDA function modifies BUFFER and then sets BUFFER as the last element of buffers (I think this is where the problem lies).
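As far as I understand the legacy convention, `buffers` is just a pointer table the runtime builds for each call. Overwriting an entry such as `buffers[1] = buffers[0]` only changes that table; the caller never reads it back, so the result it hands to downstream ops is the memory originally behind `buffers[1]`, which was never written. Writing into `buffers[0]` is also risky, since XLA generally assumes operands are not modified and may still use that memory elsewhere. The host-only sketch below models this with a hypothetical `call_like_xla` helper (not a real XLA API), a stubbed `cudaStream_t`, and a fixed size `kN` instead of `opaque`.

```cpp
#include <cassert>
#include <cstddef>

using cudaStream_t = void*;  // host-only stand-in for the CUDA type

constexpr std::size_t kN = 4;  // hypothetical fixed size, instead of opaque

// Broken: modifies the operand in place, then repoints the result slot.
// The caller owns the result allocation and never reads the pointer table
// back, so the result memory is left unwritten.
void my_fun_broken(cudaStream_t, void** buffers, const char*, std::size_t) {
  float* buf = static_cast<float*>(buffers[0]);
  for (std::size_t i = 0; i < kN; ++i) buf[i] += 1.0f;
  buffers[1] = buffers[0];  // only changes the caller's temporary table
}

// Working: read the operand, compute into the result buffer the caller allocated.
void my_fun_fixed(cudaStream_t, void** buffers, const char*, std::size_t) {
  const float* in = static_cast<const float*>(buffers[0]);
  float* out = static_cast<float*>(buffers[1]);
  for (std::size_t i = 0; i < kN; ++i) out[i] = in[i] + 1.0f;
}

// Hypothetical model of the caller: build the pointer table, call the target,
// then use only the memory behind `result`.
using Target = void (*)(cudaStream_t, void**, const char*, std::size_t);
void call_like_xla(Target target, float* operand, float* result) {
  void* buffers[2] = {operand, result};
  target(nullptr, buffers, nullptr, 0);
  // `buffers` goes out of scope here; any pointer written into it is ignored.
}
```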
This is a simplified example that I have not tested. In my real case, BUFFER is one of many operands, and I still want to update BUFFER in place.
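One way to get genuine in-place behaviour, assuming your JAX version exposes it, is to ask XLA to alias a result with the operand: `jax.interpreters.mlir.custom_call` accepts an `operand_output_aliases` mapping from operand index to result index (e.g. `{1: 0}` if BUFFER is operand 1); check the signature in your installed version. As far as I know, XLA then passes the same memory for both slots, inserting a copy of the operand first if it cannot reuse it (donating the argument with `donate_argnums` on `jax.jit` may avoid that copy). The kernel should therefore read BUFFER, write through the result pointer, and tolerate the two pointers being equal. The host-only sketch below assumes a hypothetical layout of two operands (`x`, BUFFER) and one result aliased to BUFFER, with a stubbed `cudaStream_t` and a fixed size `kN`.

```cpp
#include <cassert>
#include <cstddef>

using cudaStream_t = void*;  // host-only stand-in for the CUDA type

// Hypothetical layout: operand 0 = x, operand 1 = BUFFER, result 0 aliased
// to operand 1, so buffers = {x, BUFFER, result}, and buffers[2] == buffers[1]
// when the alias is honoured.
constexpr std::size_t kNumOperands = 2;
constexpr std::size_t kN = 3;

void accumulate_in_place(cudaStream_t, void** buffers, const char*, std::size_t) {
  const float* x = static_cast<const float*>(buffers[0]);
  const float* buf_in = static_cast<const float*>(buffers[1]);
  float* buf_out = static_cast<float*>(buffers[kNumOperands + 0]);
  // Element i is read and written only at index i (in CUDA: one thread per
  // index), so this is correct whether or not buf_out and buf_in coincide.
  for (std::size_t i = 0; i < kN; ++i) buf_out[i] = buf_in[i] + x[i];
}
```

Because the update is purely element-wise, the same kernel works both when the alias is applied and when the result lives in separate memory.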
After registering my real CUDA function with JAX and calling it, I got an error: