Open
Description
Is your feature request related to a problem? Please describe.
cuDNN attention does not work under `jax.vmap` because of this assertion:
https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions/attention.py#L563
This limits its usability for pipeline parallelism, which requires vmapping the staged layers over the stage buffer.
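
For context, here is a minimal sketch of the pattern I mean. `stage_fn`, the shapes, and the buffer layout are all hypothetical stand-ins, not TransformerEngine code; a real stage would contain the attention layer, which is where the vmap currently fails.

```python
import jax
import jax.numpy as jnp

NUM_STAGES, MICROBATCH, SEQ, HIDDEN = 4, 2, 8, 16

def stage_fn(w, x):
    # Hypothetical stand-in for one pipeline stage.
    # A real stage would call attention here.
    return jnp.tanh(x @ w)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
# One set of weights per stage, stacked along a leading stage axis.
stage_weights = jax.random.normal(key_w, (NUM_STAGES, HIDDEN, HIDDEN))
# The stage buffer holds the microbatch each stage is currently processing.
stage_buffer = jax.random.normal(key_x, (NUM_STAGES, MICROBATCH, SEQ, HIDDEN))

# All stages advance in a single call by mapping over the stage axis.
stage_outputs = jax.vmap(stage_fn)(stage_weights, stage_buffer)
```

Every operation inside `stage_fn` therefore needs a batching rule, including the fused attention call.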
Describe the solution you'd like
Properly implemented vmap (batching) rules for cuDNN attention.
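
One possible shape for such a rule, as a rough sketch only: fold the vmapped axis into the existing batch dimension, run the kernel once, and unfold the result. This assumes the kernel accepts `[batch, seq, heads, head_dim]` inputs and that folding is valid for the configuration in use (no masks, bias, or sequence-length metadata are handled here). `reference_attention` is a plain `jnp` stand-in for the fused kernel, and `jax.custom_batching.custom_vmap` is used only for illustration; the actual fix would presumably register a batching rule on the primitive itself. Only the forward pass is covered.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

def reference_attention(q, k, v):
    # Stand-in for the fused kernel; expects [batch, seq, heads, head_dim].
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

@custom_vmap
def attention(q, k, v):
    return reference_attention(q, k, v)

@attention.def_vmap
def attention_vmap_rule(axis_size, in_batched, q, k, v):
    q_batched, k_batched, v_batched = in_batched

    def fold(x, batched):
        # Batched operands carry the mapped axis in front;
        # unbatched ones are broadcast so every operand has it.
        if not batched:
            x = jnp.broadcast_to(x, (axis_size, *x.shape))
        # Merge [vmap_axis, batch, ...] into [vmap_axis * batch, ...].
        return x.reshape((-1, *x.shape[2:]))

    # A real rule would invoke the fused kernel here.
    out = reference_attention(
        fold(q, q_batched), fold(k, k_batched), fold(v, v_batched)
    )
    # Split the merged batch dimension back into [vmap_axis, batch, ...].
    out = out.reshape((axis_size, -1, *out.shape[1:]))
    return out, True

keys = jax.random.split(jax.random.PRNGKey(0), 3)
# Shape: [stages, batch, seq, heads, head_dim]
q, k, v = (jax.random.normal(key, (3, 2, 5, 4, 8)) for key in keys)
out = jax.vmap(attention)(q, k, v)
```

Broadcasting unbatched operands is the simplest choice but costs extra memory; a real implementation might avoid the copy.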