[ROCm] Implement RNN support #25755
@dfm and @superbobry could you please take a look?
0b07837 to 36d037e
superbobry
left a comment
@dfm want to have a look as well?
dfm
left a comment
This looks good overall - thanks! My main high-level comment is that it would be useful to move as much of the #ifdef JAX_GPU_HIP logic as possible into vendor.h rather than keeping it in rnn_kernels.cc directly. It's OK to have some there, but the more we can move, the better. Can you look into redefining some of the macros in vendor.h to consolidate the logic there?
jax/experimental/rnn.py
Outdated
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_fwd_lowering, platform='cuda')
mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_fwd_lowering, platform='rocm')
Since gpu_rnn is in jaxlib, these changes will cause problems with version skew. JAX always needs to work with the most recent stable release of jaxlib. Perhaps you could protect this using hasattr(gpu_rnn, "miopen_rnn_fwd_lowering")?
jax/experimental/rnn.py
Outdated
mlir.register_lowering(
    rnn_bwd_p, gpu_rnn.miopen_rnn_bwd_lowering, platform='rocm')
Similarly, this needs to be protected against old versions of jaxlib.
2e86003 to 18cc2d2
@dfm I'm still not sure why this error won't go away. I have protections in place. Could it be related to how you test this in your internal CI? It seems like you are getting jaxlib from upstream, and that is why the related tests fail?
@dfm thanks. I see what you mean. However, MIOpen APIs are quite different from cuDNN. For example, I checked to see how many of
Yes! We require that Also: It looks like this has introduced some build issues for the CUDA CI. Can you take a look at those too?
a909942 to dfd1a65
dfd1a65 to fe68eb8
@dfm I just fixed the patch. Could you please approve? Thanks!


Created from: ROCm#171