JAX FA Benchmarking Script #351
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
ipanfilo left a comment
Why does ck_fused_attn_bwd need modification if the PR is for the fwd pass only?
Note that I have added the BWD pass implementation as well to this PR.
Pinging @wangye805 @wenchenvincent in case either of you are interested in reviewing this PR as well, thanks!
Have you compared the kernel time measured from CK FA API vs from rocprof?
@Micky774 Could you follow up? Let's try to close this PR soon.
@wenchenvincent just confirmed the timings reported match those recorded via |