-
Notifications
You must be signed in to change notification settings - Fork 76
Use iteration count instead of time for parameters warmup and rep of do_bench* functions for benchmarks
#2256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b1d2a0b
339b709
5ebbd01
b1cc599
0ad146f
bbf0557
81fec9a
42e653a
8f81c13
5d08d3a
b2d3398
fe806b1
bf49b0d
7493632
524f81d
4d40864
b0d91ce
6809b9a
e1c4f9f
a1fd0f9
f16b149
565d87c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,18 +36,18 @@ def _summarize_statistics(times, quantiles, return_mode): | |
| return getattr(torch, return_mode)(times).item() | ||
|
|
||
|
|
||
| def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu", | ||
| def do_bench_ipex(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu", | ||
| sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument | ||
| """ | ||
| Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with | ||
| the 20-th and 80-th performance percentile. | ||
|
|
||
| :param fn: Function to benchmark | ||
| :type fn: Callable | ||
| :param warmup: Warmup time (in ms) | ||
| :type warmup: int | ||
| :param rep: Repetition time (in ms) | ||
| :type rep: int | ||
| :param n_warmup: Number of repetitions for warmup | ||
| :type n_warmup: int | ||
| :param n_repeat: Number of repetitions to collect measurements | ||
| :type n_repeat: int | ||
| :param grad_to_none: Reset the gradient of the provided tensor to None | ||
| :type grad_to_none: torch.tensor, optional | ||
| :param quantiles: Performance percentile to return in addition to the median. | ||
|
|
@@ -69,20 +69,6 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, ret | |
| cache_size = 256 * 1024 * 1024 | ||
| cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device) | ||
|
|
||
| # Estimate the runtime of the function | ||
| start_event = torch.xpu.Event(enable_timing=True) | ||
| end_event = torch.xpu.Event(enable_timing=True) | ||
| start_event.record() | ||
| for _ in range(5): | ||
| cache.zero_() | ||
| fn() | ||
| end_event.record() | ||
| synchronize() | ||
| estimate_ms = start_event.elapsed_time(end_event) / 5 | ||
|
|
||
| # compute number of warmup and repeat | ||
| n_warmup = max(1, int(warmup / estimate_ms)) | ||
| n_repeat = max(1, int(rep / estimate_ms)) | ||
| # Warm-up | ||
| for _ in range(n_warmup): | ||
| fn() | ||
|
|
@@ -121,18 +107,18 @@ def extract_kernels(funcs): | |
| return _summarize_statistics(times, quantiles, return_mode) | ||
|
|
||
|
|
||
| def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu", | ||
| kernel_name=None): # pylint: disable=unused-argument | ||
| def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", | ||
| device="xpu", kernel_name=None): # pylint: disable=unused-argument | ||
| """ | ||
| Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with | ||
| the 20-th and 80-th performance percentile. | ||
|
|
||
| :param fn: Function to benchmark | ||
| :type fn: Callable | ||
| :param warmup: Warmup time (in ms) | ||
| :type warmup: int | ||
| :param rep: Repetition time (in ms) | ||
| :type rep: int | ||
| :param n_warmup: Number of repetitions for warmup | ||
| :type n_warmup: int | ||
| :param n_repeat: Number of repetitions to collect measurements | ||
| :type n_repeat: int | ||
| :param grad_to_none: Reset the gradient of the provided tensor to None | ||
| :type grad_to_none: torch.tensor, optional | ||
| :param quantiles: Performance percentile to return in addition to the median. | ||
|
|
@@ -142,24 +128,49 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N | |
| import torch | ||
| from triton.testing import do_bench as triton_do_bench | ||
|
|
||
| times = triton_do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none, return_mode="all", | ||
| # We maintain a buffer of 256 MB that we clear | ||
| # before each kernel call to make sure that the L2 | ||
| # doesn't contain any input data before the run | ||
| cache_size = 256 * 1024 * 1024 | ||
| cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device) | ||
|
|
||
| # Estimate the runtime of the function | ||
| start_event = torch.xpu.Event(enable_timing=True) | ||
| end_event = torch.xpu.Event(enable_timing=True) | ||
| start_event.record() | ||
| for _ in range(5): | ||
| cache.zero_() | ||
| fn() | ||
| end_event.record() | ||
| synchronize() | ||
| estimate_ms = start_event.elapsed_time(end_event) / 5 | ||
|
|
||
| # The cache is also maintained in `triton_do_bench` function, | ||
| # there is no need to duplicate the amount of memory used. | ||
| del cache | ||
|
|
||
| # compute warmup and repeat times | ||
| warmup_time = n_warmup * estimate_ms | ||
| rep_time = n_repeat * estimate_ms | ||
|
Comment on lines
+152
to
+154
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I translate the parameters into the time-based values that upstream (`triton_do_bench`) expects, using the estimated runtime of a single iteration. |
||
|
|
||
| times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all", | ||
| device_type=device) | ||
| times = torch.tensor(times, dtype=torch.float) | ||
| return _summarize_statistics(times, quantiles, return_mode) | ||
|
|
||
|
|
||
| def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", | ||
| device="xpu", sync_submitting=True, kernel_name=None): | ||
| def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, | ||
| return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None): | ||
| """ | ||
| Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with | ||
| the 20-th and 80-th performance percentile. | ||
|
|
||
| :param fn: Function to benchmark | ||
| :type fn: Callable | ||
| :param warmup: Warmup time (in ms) | ||
| :type warmup: int | ||
| :param rep: Repetition time (in ms) | ||
| :type rep: int | ||
| :param n_warmup: Number of repetitions for warmup | ||
| :type n_warmup: int | ||
| :param n_repeat: Number of repetitions to collect measurements | ||
| :type n_repeat: int | ||
| :param grad_to_none: Reset the gradient of the provided tensor to None | ||
| :type grad_to_none: torch.tensor, optional | ||
| :param quantiles: Performance percentile to return in addition to the median. | ||
|
|
@@ -179,20 +190,6 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None | |
| cache_size = 256 * 1024 * 1024 | ||
| cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device) | ||
|
|
||
| # Estimate the runtime of the function | ||
| start_event = torch.xpu.Event(enable_timing=True) | ||
| end_event = torch.xpu.Event(enable_timing=True) | ||
| start_event.record() | ||
| for _ in range(5): | ||
| cache.zero_() | ||
| fn() | ||
| end_event.record() | ||
| synchronize() | ||
| estimate_ms = start_event.elapsed_time(end_event) / 5 | ||
|
|
||
| # compute number of warmup and repeat | ||
| n_warmup = max(1, int(warmup / estimate_ms)) | ||
| n_repeat = max(1, int(rep / estimate_ms)) | ||
|
Comment on lines
-193
to
-195
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is no point in deriving the number of iterations from the estimated time of a single iteration, since the required number of iterations is now supplied directly by the user. |
||
| # Warm-up | ||
| for _ in range(n_warmup): | ||
| fn() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no point in deriving the number of iterations from the estimated time of a single iteration, since the required number of iterations is now supplied directly by the user.