[benchmarks] Reworked the conversion benchmark and added more tests for up/down casts #4800
First file (the package `__init__.py` that re-exports the benchmark entry points):

```diff
@@ -1 +1 @@
-from .float_conversion import benchmark  # type: ignore  # noqa: F401
+from .float_conversion import get_benchmarks, run_benchmarks  # type: ignore  # noqa: F401
```
Second file (the reworked float conversion benchmark module):

```diff
@@ -1,68 +1,111 @@
+import os
+import sys
+from functools import lru_cache

 import torch
 import triton
 import triton.language as tl

+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))
+from triton_kernels_benchmark import Benchmark, do_bench, perf_report  # pylint: disable=C0413

+TYPES = {
+    tl.float8e4nv: torch.float8_e4m3fn, tl.float8e5: torch.float8_e5m2, tl.float16: torch.float16, tl.bfloat16:
+    torch.bfloat16, tl.float32: torch.float32
+}


+@lru_cache
+def get_kernel(name):

+    def kernel(
+        x_ptr,
+        y_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+        x_type: tl.constexpr,
+        y_type: tl.constexpr,
+        rnd: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x_itype = tl.int8 if x_type.itemsize == 1 else tl.int16 if x_type.itemsize == 2 else tl.int32
+        y_itype = tl.int8 if y_type.itemsize == 1 else tl.int16 if y_type.itemsize == 2 else tl.int32

+        x = tl.load(x_ptr + offsets, mask=mask)
+        converted = x.to(y_type, fp_downcast_rounding=rnd)
+        x = tl.cast(x, x_itype, bitcast=True)
+        y = tl.cast(converted, y_itype, bitcast=True)
+        for i in range(99):
+            x += tl.full(x.shape, i, x_itype)
+            converted = tl.cast(x, x_type, bitcast=True).to(y_type, fp_downcast_rounding=rnd)
+            y += tl.cast(converted, y_itype, bitcast=True)
+        y = tl.cast(y, y_type, bitcast=True)
+        tl.store(y_ptr + offsets, y, mask=mask)

+    kernel.__name__ = kernel.__qualname__ = name
```
> **Comment:** Can you explain why we want to change the kernel name here?
>
> **Reply:** It makes it much easier to find and compare the kernel IRs in the cache. For example, to compare the rtne and rtz IRs:
>
> ```
> $ find ~/.triton/cache/ -name '*fp32_to_fp16*.llir'
> /home/jovyan/.triton/cache/PXX2VZY5SPACCMUKHEIEDQEEB2CN2A562JHDMIVPLHTS6F6LSRAA/rtne_fp32_to_fp16_conversion_kernel.llir
> /home/jovyan/.triton/cache/CYCEIIP4OQCRWTMZHZHLPXO63ZR5J7ZEPN4KCM72NVOITFO5FMRA/rtz_fp32_to_fp16_conversion_kernel.llir
> ```
>
> **Comment:** I think we can keep this for the development stage only: from a production perspective there is no performance or visualization improvement from this change, and it makes the code structure inconsistent with the other benchmarks, which reduces readability.
>
> **Reply:** It's very convenient to use the benchmarks for development and testing. This bench produces 20 kernels, and it's not easy to pick out the required kernel in the cache.
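As a reading aid, the "20 kernels" figure follows from the itemsize pairing and naming logic in `get_bench`/`get_benchmarks` further down in this diff. The standalone sketch below (plain Python, no Triton required) reproduces that count; the assumption is only that the short type names match the `str()` form of the corresponding `tl` dtypes, e.g. `str(tl.float32) == 'fp32'`, as seen in the cache paths above.

```python
# Mirrors the itemsize pairing in get_benchmarks() and the naming in bench()
# to enumerate the kernels this suite generates.
ITEMSIZE = {'fp8e4nv': 1, 'fp8e5': 1, 'fp16': 2, 'bf16': 2, 'fp32': 4}

names = []
for small, s_size in ITEMSIZE.items():
    for big, b_size in ITEMSIZE.items():
        if s_size >= b_size:
            continue  # get_bench() only accepts a strictly smaller source type
        names.append(f'{small}_to_{big}_conversion_kernel')       # upcast, no rounding mode
        names.append(f'rtne_{big}_to_{small}_conversion_kernel')  # downcast, round-to-nearest-even
        if big == 'fp32':
            names.append(f'rtz_{big}_to_{small}_conversion_kernel')  # extra round-toward-zero case

print(len(names))  # 20
```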
```diff

+    return triton.jit(kernel)
```
> **Comment:** I would suggest using a decorator instead.
>
> **Reply:** The decorator does not allow changing the name.
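For context, a minimal sketch of the rename-then-wrap pattern under discussion, assuming only that `triton.jit(...)` picks up whatever `__name__` the plain function carries at that moment, which is what the diff relies on:

```python
import triton
import triton.language as tl


def _kernel(x_ptr, BLOCK_SIZE: tl.constexpr):
    # Trivial body; only the naming pattern matters here.
    offsets = tl.arange(0, BLOCK_SIZE)
    tl.store(x_ptr + offsets, tl.load(x_ptr + offsets))


# Rename the plain Python function first, then wrap it, so the cached
# artifacts carry the descriptive name. With the @triton.jit decorator the
# name would already be fixed at decoration time.
_kernel.__name__ = _kernel.__qualname__ = 'rtz_fp32_to_fp16_conversion_kernel'
renamed_kernel = triton.jit(_kernel)
```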
```diff


+def get_bench(x_type, y_type):
+    assert x_type.itemsize < y_type.itemsize
+    plot_name = f'{x_type}-{y_type}'
+    line_vals = [(x_type, y_type, None), (y_type, x_type, 'rtne')]
+    line_names = [f'{x_type}->{y_type}', f'{y_type}->{x_type}-rtne']
+    if y_type == tl.float32:
+        line_vals.append((y_type, x_type, 'rtz'))
+        line_names.append(f'{y_type}->{x_type}-rtz')

+    @perf_report(
+        Benchmark(
+            x_names=['N'],
+            x_vals=[2**i for i in range(12, 28, 2)],
+            line_arg='args',
+            line_vals=line_vals,
+            line_names=line_names,
+            styles=[(c, s) for c in 'bgry' for s in ('-', '--', '-.', ':')],
+            ylabel=('GB/s', ),
+            plot_name=plot_name,
+            args={},
+        ))
+    def bench(N, args):
+        quantiles = [0.5, 0.2, 0.8]
+        x_type = args[0]
+        y_type = args[1]
+        if x_type.itemsize == 1:
+            x = torch.rand(N, dtype=torch.float16, device='xpu', requires_grad=True).to(TYPES[x_type])
+        else:
+            x = torch.rand(N, dtype=TYPES[x_type], device='xpu', requires_grad=True)
+        y = torch.empty_like(x, dtype=TYPES[y_type], device='xpu')
+        rnd = args[2] if x_type.itemsize > y_type.itemsize else None
+        name = f'{x_type}_to_{y_type}_conversion_kernel'
+        if rnd:
+            name = f'{rnd}_{name}'
+        kernel = get_kernel(name)

+        def fwd():
+            BLOCK_SIZE = 4096
+            grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+            kernel[grid](x, y, N, BLOCK_SIZE, x_type, y_type, rnd)
+            return x

+        _, min_ms, max_ms, mean_ms, cv = do_bench(fwd, n_warmup=10, n_repeat=10, quantiles=quantiles)
+        gbps = lambda ms: (N * x.element_size() * 1e-9) / (ms * 1e-3)
+        return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv

+    return bench


+def get_benchmarks():
+    return [get_bench(s, t) for s in TYPES for t in TYPES if s.itemsize < t.itemsize]


-@triton.jit
-def float_trunc_kernel(
-    x_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-    target_type: tl.constexpr,
-):
-    pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements

-    x = tl.load(x_ptr + offsets, mask=mask)

-    as_target = x.to(target_type)
-    as_f32 = as_target.to(tl.float32)
-    for _ in range(100):
-        as_f32 += 1  # plus one ensures that there are no redundant conversions that can be removed
-        as_target = as_f32.to(target_type)
-        as_f32 = as_target.to(tl.float32)

-    tl.store(x_ptr + offsets, as_f32, mask=mask)


-def launch_conversion(x: torch.Tensor, target_type: type):
-    assert x.is_xpu
-    n_elements = x.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
-    float_trunc_kernel[grid](x, n_elements, BLOCK_SIZE=1024, target_type=target_type)
-    return x


-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=['N'],
-        x_vals=[2**i for i in range(12, 28, 2)],
-        line_arg='target_type',
-        line_vals=['bfloat16', 'float16'],
-        line_names=['BF16', 'FP16'],
-        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
-        ylabel='GB/s',
-        plot_name='float-conversion',
-        args={},
-    ))
-def benchmark(N, target_type):
-    quantiles = [0.5, 0.2, 0.8]
-    inputs = torch.rand(N, dtype=torch.float32, device='xpu', requires_grad=True)

-    if target_type == 'bfloat16':
-        fwd = lambda: launch_conversion(inputs, tl.bfloat16)
-    elif target_type == 'float16':
-        fwd = lambda: launch_conversion(inputs, tl.float16)
-    else:
-        raise NotImplementedError(f'Type {target_type} is not supported')

-    ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles)
-    gbps = lambda ms: (inputs.numel() * inputs.element_size() * 1e-9) / (ms * 1e-3)

-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+def run_benchmarks():
+    for bench in get_benchmarks():
+        bench.run(print_data=True)


 if __name__ == '__main__':
-    benchmark.run(print_data=True)
+    run_benchmarks()
```
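As a reading aid, the bandwidth that `bench` reports counts only the bytes of the input tensor `x` that the kernel loads. Below is a standalone restatement of that formula with an illustrative, made-up timing:

```python
def gbps(n_elements: int, element_size: int, ms: float) -> float:
    # Same formula as the gbps lambda in bench(): bytes of the input tensor
    # divided by the elapsed time, expressed in GB/s.
    return (n_elements * element_size * 1e-9) / (ms * 1e-3)


# The largest size in the sweep is 2**26 elements; for an fp32 input
# (4 bytes per element) a hypothetical 0.5 ms run reports ~536.9 GB/s.
print(round(gbps(2**26, 4, 0.5), 1))  # 536.9
```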
Third file (the benchmark runner script):

```diff
@@ -1,16 +1,6 @@
-import argparse
-
 from conversion import float_conversion
 from core_ops import dot_scaled

 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--reports',
-        type=str,
-        default='',
-        help='directory to save reports',
-    )
-    args = parser.parse_args()
-    float_conversion.benchmark.run(print_data=True, save_path=args.reports)
-    dot_scaled.benchmark.run(print_data=True, save_path=args.reports)
+    for mod in (float_conversion, dot_scaled):
+        mod.run_benchmarks()
```
> **Comment:** Triton has its own cache mechanism; you can check the kernel JIT cache under `TRITON_CACHE_DIR` if it is specified. We do not need to use `lru_cache`.
>
> **Reply:** Right, but the `JITFunction` object is created on each call to `_make_kernel()`. I'm not sure whether it has a significant impact on the bench results.
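To make the trade-off concrete, here is a minimal, Triton-free sketch of what the `@lru_cache` on `get_kernel` buys: the on-disk Triton cache avoids recompilation either way, but memoization also avoids constructing a fresh wrapper object on every call. `build_kernel` below is a hypothetical stand-in for `get_kernel`.

```python
from functools import lru_cache


@lru_cache
def build_kernel(name):
    # Stand-in for get_kernel(): in the benchmark the inner function would be
    # renamed and then wrapped with triton.jit before being returned.
    def kernel():
        pass

    kernel.__name__ = kernel.__qualname__ = name
    return kernel


# Memoized: repeated lookups by the same name return the same object,
# while different names still produce distinct kernels.
assert build_kernel('rtne_fp32_to_fp16_conversion_kernel') is build_kernel('rtne_fp32_to_fp16_conversion_kernel')
assert build_kernel('rtne_fp32_to_fp16_conversion_kernel') is not build_kernel('rtz_fp32_to_fp16_conversion_kernel')
```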