JAX FA Benchmarking Script #351
base: dev
| @@ -0,0 +1,307 @@ | |||
| # This file was modified for portability to AMDGPU | |||
| # Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. | |||
| # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |||
Which NVIDIA code is this benchmark based on?
It's based on the JAX FA tests -- not sure how we handle the copyright in that case, since this is a new file.
| // print kernel name on verbose mode | ||
| ck_tile::stream_config stream_config{stream, false, ck_fused_attn_log_config}; | ||
| ck_tile::stream_config stream_config{stream, true, ck_fused_attn_log_config}; |
Is it OK to always enable kernel timing? Doesn't it carry a performance penalty?
Yes, good catch. This is fixed in the newest commit.
ipanfilo left a comment
Why does `ck_fused_attn_bwd` need modification if the PR is for the forward pass only?
Note that I have also added the BWD pass implementation to this PR.
Pinging @wangye805 @wenchenvincent in case either of you is interested in reviewing this PR as well, thanks!
benchmarks/attention/README.md
Outdated
| ## JAX Fused-Attention Benchmarking | ||
| The benchmarking process is split into two stages: *generating* the timing data, and *visualizing* the timing data. The following steps assume you are located in `TransformerEngine/benchmarks/attention` (i.e. where this README is located). First, ensure that you install requirements via `pip install -r requirements.txt`. | ||
| Note: Only forward timings are supported at this point. |
Update?
Done!
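For the *visualizing* stage described in the README, a minimal sketch of loading a timing dump might look like the following. It assumes the dump holds one average runtime per line, as the `file << average_runtime << "\n"` line in the diff suggests. The file name `fwd_timings.txt` and the helper names `load_timings` and `summarize` are hypothetical and not part of this PR, and the unit of the values is not stated in the thread.

```python
import tempfile
from pathlib import Path
from statistics import mean, median


def load_timings(path):
    """Read one float per line from a timing dump, skipping blank lines."""
    values = []
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if line:
            values.append(float(line))
    return values


def summarize(values):
    """Return basic statistics for a non-empty list of runtimes."""
    return {
        "count": len(values),
        "mean": mean(values),
        "median": median(values),
        "min": min(values),
        "max": max(values),
    }


if __name__ == "__main__":
    # Hypothetical dump file, created here only so the sketch is runnable.
    with tempfile.TemporaryDirectory() as tmp:
        dump = Path(tmp) / "fwd_timings.txt"
        dump.write_text("1.5\n2.5\n\n2.0\n")
        print(summarize(load_timings(dump)))
```

`summarize` raises on an empty list, so a real script would want to check that the generation stage actually produced timings before plotting.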
| from transformer_engine.jax import fp8_autocast | ||
| # Needed in order to dump timings properly | ||
| os.environ["XLA_FLAGS"]="--xla_gpu_graph_level=0" |
Is this because you used the timing-dump function in CK fused attention?
Yes
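One thing to note about the assignment in the diff is that it overwrites any `XLA_FLAGS` the user has already exported. A possible alternative, sketched below, appends the flag and replaces only an earlier setting of the same flag. The helper `with_xla_flag` is hypothetical and not part of this PR. The sketch also assumes the variable has to be set before JAX initializes its backend, which in practice usually means before `import jax`.

```python
import os


def with_xla_flag(existing, flag):
    """Return `existing` with `flag` appended, dropping any older value of it."""
    name = flag.split("=", 1)[0]
    parts = existing.split() if existing else []
    # Drop an earlier setting of the same flag so the new value wins.
    parts = [p for p in parts if p.split("=", 1)[0] != name]
    parts.append(flag)
    return " ".join(parts)


# Preserve user-provided flags and add the one the benchmark needs.
os.environ["XLA_FLAGS"] = with_xla_flag(
    os.environ.get("XLA_FLAGS", ""), "--xla_gpu_graph_level=0"
)

# JAX and transformer_engine.jax would be imported after this point.
```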
| file << average_runtime << "\n"; | ||
| } | ||
| void dump_bwd_timings(const char* dump_path, float average_runtime, hipStream_t stream){ | ||
| std::ofstream file; |
One option: open the file once and keep it open, so that no open/flush/close happens on every dump.
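The open-once pattern suggested here can be illustrated as follows. This is a Python sketch of the idea only. The code under review is C++, where the equivalent would be a long-lived `std::ofstream`, and the class name `TimingDump` and its `open_count` counter are hypothetical.

```python
import atexit


class TimingDump:
    """Append-only writer that opens its file on first use and keeps it open."""

    def __init__(self, path):
        self._path = path
        self._file = None
        self.open_count = 0  # illustration only: counts how often the file is opened

    def write(self, average_runtime):
        if self._file is None:
            # Default buffering: no flush per write; data is flushed on close.
            self._file = open(self._path, "a")
            self.open_count += 1
            atexit.register(self.close)
        self._file.write(f"{average_runtime}\n")

    def close(self):
        if self._file is not None:
            self._file.close()
            self._file = None
```

The trade-off is that buffered values can be lost if the process is killed before the file is closed, so a benchmark that may be interrupted might still prefer an occasional explicit flush.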