[CI] Better warmup for flex attention on B580 #4906
Conversation
That's interesting; do you have any ideas why?
I've rechecked just in case, and discovered that profiling during warmup is not actually required. However, with warmup combined with profiling I need 20-80 warmup steps in 98% of cases, while warmup without profiling needs about 150-200 steps. Originally I only tried up to 100 warmup steps. Maybe some optimizations are tied to the total time spent in this kernel (either to compute the optimizations or to decide whether they are worth applying), so warmup + profiling + cache rewrite simply gives more time for the warmup to take effect.
Looks like flex attention results are stable now. I'll investigate proper warmup for other kernels in another issue: #4911. @whitneywhtsang Will you take a look?
is_bmg = any(name in torch.xpu.get_device_name().lower() for name in ('b570', 'b580'))
if is_bmg:
    benchmark_suit.do_prewarmup(triton_fn)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=200 if is_bmg else 10, n_repeat=10,
Do we need both prewarmup and increasing n_warmup?
I think so; the first warmup across all shapes takes a lot of time, and just setting n_warmup to 200 is not always enough.
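The two-stage pattern discussed above, a dedicated prewarmup pass before the timed benchmark loop with its own warmup count, can be sketched generically. This is a minimal illustration in plain Python (no Triton/XPU dependency); the real `benchmark_suit.do_prewarmup`/`do_bench` helpers likely differ in signature and in how they synchronize the device, so treat the names and return values here as assumptions:

```python
import statistics
import time


def do_bench(fn, n_warmup=10, n_repeat=10):
    """Hypothetical sketch of a warmup-then-time benchmark helper.

    Warmup runs let JIT compilation / autotuning settle before any
    measurement is taken; only the subsequent repeats are timed.
    """
    for _ in range(n_warmup):
        fn()

    times_ms = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1e3)

    mean = statistics.mean(times_ms)
    # Coefficient of variation: a small cv indicates stable timings,
    # i.e. the warmup was long enough for results to settle.
    cv = statistics.stdev(times_ms) / mean if len(times_ms) > 1 and mean > 0 else 0.0
    return min(times_ms), max(times_ms), mean, cv
```

A separate prewarmup step before calling such a helper makes sense when warmup is expensive per shape: the prewarmup amortizes the one-time compilation cost, while `n_warmup` handles per-run stabilization.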
device=DEVICE)
# Need more warmups on B580 due to torch.compile

is_bmg = any(name in torch.xpu.get_device_name().lower() for name in ('b570', 'b580'))
I think we can keep it simple and increase the warmup count across all platforms; there's no need to check whether it is BMG.
Done
Please also update the PR description; it is now out of date.
We agree to land this as is to unblock flex attn BMG measurement, and further improve the warmup mechanism in a separate PR.
Flex attention requires more warmup steps on B580.
This PR adds:
Should resolve #4852.
A better warmup mechanism should follow the investigation in #4911.