@Egor-Krivov Egor-Krivov commented Aug 15, 2025

Flex attention requires more warmup steps on B580.

This PR:

  1. Adds a pre-warmup step for flex attention that is called once per run, so only the first shape config pays the extra cost. Experiments show that the first config requires more warmup.
  2. Makes GPU synchronization consistent between warmup and benchmarking.
  3. Adds iterations.

Should resolve #4852

Better warmup handling should follow the investigation in #4911.
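The once-per-run pre-warmup in point 1 could be sketched roughly like this (a hypothetical simplification; the actual `benchmark_suit.do_prewarmup` implementation may differ, and the `key`/`steps`/`synchronize` parameters are illustrative):

```python
_PREWARMED = set()

def do_prewarmup(fn, key="default", steps=10, synchronize=None):
    """Run extra warmup iterations once per benchmark run (per key).

    Only the first shape config pays this cost: subsequent calls
    with the same key are no-ops.
    """
    if key in _PREWARMED:
        return
    for _ in range(steps):
        fn()
    if synchronize is not None:
        synchronize()  # e.g. torch.xpu.synchronize on Intel GPUs
    _PREWARMED.add(key)
```

Keeping the "already warmed" state module-global means the expensive first-config warmup happens once, while later shape configs fall through to the regular `n_warmup` loop.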

@Egor-Krivov Egor-Krivov marked this pull request as ready for review August 18, 2025 11:09
@whitneywhtsang (Contributor) commented:

> Also, based on experiments, warmup has to contain profiling, otherwise the first iterations with profiling will be slower.

That's interesting, do you have any ideas why?

@Egor-Krivov (Author) commented:

> > Also, based on experiments, warmup has to contain profiling, otherwise first iterations with profiling will be slower.
>
> That's interesting, do you have some ideas on why?

I've rechecked just in case, and discovered that profiling during warmup is not actually required. However, with warmup combined with profiling I need 20-80 warmup steps in 98% of cases; without profiling I need about 150-200 warmup steps. Originally I had only tried up to 100 warmup steps.

Maybe some optimizations are tied to the total time spent in this kernel (either to compute the optimizations or to decide whether they are worth applying), so warmup + profiling + cache rewrite simply leaves more time for warmup.
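The "consistent GPU synchronization" change from the PR description can be illustrated with a minimal host-side sketch (a hypothetical simplification, not the actual `benchmark_suit.do_bench`; on Intel GPUs `synchronize` would be `torch.xpu.synchronize`):

```python
import time

def bench(fn, n_warmup, n_repeat, synchronize=lambda: None):
    # Warmup performs the same per-iteration synchronization as the
    # timed loop below, so the first timed iterations run under
    # identical conditions.
    for _ in range(n_warmup):
        fn()
        synchronize()
    times_ms = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        fn()
        synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return min(times_ms), max(times_ms), sum(times_ms) / len(times_ms)
```

If warmup skipped the synchronization (or used a different one), the first timed iterations would measure a different execution pattern than the warmed-up steady state.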

@Egor-Krivov (Author) commented:

Looks like flex attention results are stable now:
https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/17073456205
https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/17066951767

I'll investigate proper warmup for other kernels in a separate issue: #4911

@whitneywhtsang Will you take a look?
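The stability claimed above can be judged by the coefficient of variation (`cv`) that `do_bench` reports; a minimal sketch of what that metric captures (hypothetical helper; the actual computation inside `benchmark_suit` may differ):

```python
import statistics

def coefficient_of_variation(times_ms):
    """CV = stddev / mean; lower means more stable measurements."""
    mean = statistics.mean(times_ms)
    return statistics.stdev(times_ms) / mean

# A flat series has low CV; a noisy one (e.g. slow first iterations
# from insufficient warmup) has a much higher CV.
stable = [1.00, 1.02, 0.99, 1.01]
noisy = [5.0, 1.0, 1.0, 1.0]
```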

Review thread on the changed code:

```python
is_bmg = any(name in torch.xpu.get_device_name().lower() for name in ('b570', 'b580'))
if is_bmg:
    benchmark_suit.do_prewarmup(triton_fn)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=200 if is_bmg else 10, n_repeat=10,
                                                      device=DEVICE)
```
@whitneywhtsang (Contributor):

Do we need both prewarmup and increasing n_warmup?

@Egor-Krivov (Author):

I think so. The first warmup across all shapes takes a long time; just setting n_warmup to 200 is not always enough.

Review thread on the changed code:

```python
                                                      device=DEVICE)
# Need more warmups on B580 due to the torch.compile

is_bmg = any(name in torch.xpu.get_device_name().lower() for name in ('b570', 'b580'))
```
@whitneywhtsang (Contributor):

I think we can keep it simple and increase it across platforms; no need to check whether it is BMG.

@Egor-Krivov (Author):

Done

@whitneywhtsang (Contributor) left a review:

Please also change PR description, it is now out of date.

We agree to land this as is to unblock flex attn BMG measurement, and further improve the warmup mechanism in a separate PR.

@Egor-Krivov Egor-Krivov merged commit 3714e9b into main Aug 22, 2025
16 of 18 checks passed
@Egor-Krivov Egor-Krivov deleted the egor/flex_attn_fluc branch August 22, 2025 10:04
Successfully merging this pull request may close these issues.

[FlexAttention] Investigate performance fluctuation on BMG