diff --git a/benchmarks/micro_benchmarks/conversion/float_conversion/__init__.py b/benchmarks/micro_benchmarks/conversion/float_conversion/__init__.py
index aa58e8d853..e7fcf62390 100644
--- a/benchmarks/micro_benchmarks/conversion/float_conversion/__init__.py
+++ b/benchmarks/micro_benchmarks/conversion/float_conversion/__init__.py
@@ -1 +1 @@
-from .float_conversion import benchmark  # type: ignore  # noqa: F401
+from .float_conversion import get_benchmarks, run_benchmarks  # type: ignore  # noqa: F401
diff --git a/benchmarks/micro_benchmarks/conversion/float_conversion/float_conversion.py b/benchmarks/micro_benchmarks/conversion/float_conversion/float_conversion.py
index 84032bd9b6..8dc38dc7ca 100644
--- a/benchmarks/micro_benchmarks/conversion/float_conversion/float_conversion.py
+++ b/benchmarks/micro_benchmarks/conversion/float_conversion/float_conversion.py
@@ -1,68 +1,111 @@
+import os
+import sys
+from functools import lru_cache
+
 import torch
 import triton
 import triton.language as tl
 
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))
+from triton_kernels_benchmark import Benchmark, do_bench, perf_report  # pylint: disable=C0413
+
+TYPES = {
+    tl.float8e4nv: torch.float8_e4m3fn, tl.float8e5: torch.float8_e5m2, tl.float16: torch.float16, tl.bfloat16:
+    torch.bfloat16, tl.float32: torch.float32
+}
+
+
+@lru_cache
+def get_kernel(name):
+
+    def kernel(
+        x_ptr,
+        y_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+        x_type: tl.constexpr,
+        y_type: tl.constexpr,
+        rnd: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x_itype = tl.int8 if x_type.itemsize == 1 else tl.int16 if x_type.itemsize == 2 else tl.int32
+        y_itype = tl.int8 if y_type.itemsize == 1 else tl.int16 if y_type.itemsize == 2 else tl.int32
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        converted = x.to(y_type, fp_downcast_rounding=rnd)
+        x = tl.cast(x, x_itype, bitcast=True)
+        y = tl.cast(converted, y_itype, bitcast=True)
+        for i in range(99):
+            x += tl.full(x.shape, i, x_itype)
+            converted = tl.cast(x, x_type, bitcast=True).to(y_type, fp_downcast_rounding=rnd)
+            y += tl.cast(converted, y_itype, bitcast=True)
+        y = tl.cast(y, y_type, bitcast=True)
+        tl.store(y_ptr + offsets, y, mask=mask)
+
+    kernel.__name__ = kernel.__qualname__ = name
+    return triton.jit(kernel)
+
+
+def get_bench(x_type, y_type):
+    assert x_type.itemsize < y_type.itemsize
+    plot_name = f'{x_type}-{y_type}'
+    line_vals = [(x_type, y_type, None), (y_type, x_type, 'rtne')]
+    line_names = [f'{x_type}->{y_type}', f'{y_type}->{x_type}-rtne']
+    if y_type == tl.float32:
+        line_vals.append((y_type, x_type, 'rtz'))
+        line_names.append(f'{y_type}->{x_type}-rtz')
+
+    @perf_report(
+        Benchmark(
+            x_names=['N'],
+            x_vals=[2**i for i in range(12, 28, 2)],
+            line_arg='args',
+            line_vals=line_vals,
+            line_names=line_names,
+            styles=[(c, s) for c in 'bgry' for s in ('-', '--', '-.', ':')],
+            ylabel=('GB/s', ),
+            plot_name=plot_name,
+            args={},
+        ))
+    def bench(N, args):
+        quantiles = [0.5, 0.2, 0.8]
+        x_type = args[0]
+        y_type = args[1]
+        if x_type.itemsize == 1:
+            x = torch.rand(N, dtype=torch.float16, device='xpu', requires_grad=True).to(TYPES[x_type])
+        else:
+            x = torch.rand(N, dtype=TYPES[x_type], device='xpu', requires_grad=True)
+        y = torch.empty_like(x, dtype=TYPES[y_type], device='xpu')
+        rnd = args[2] if x_type.itemsize > y_type.itemsize else None
+        name = f'{x_type}_to_{y_type}_conversion_kernel'
+        if rnd:
+            name = f'{rnd}_{name}'
+        kernel = get_kernel(name)
+
+        def fwd():
+            BLOCK_SIZE = 4096
+            grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+            kernel[grid](x, y, N, BLOCK_SIZE, x_type, y_type, rnd)
+            return x
+
+        _, min_ms, max_ms, mean_ms, cv = do_bench(fwd, n_warmup=10, n_repeat=10, quantiles=quantiles)
+        gbps = lambda ms: (N * x.element_size() * 1e-9) / (ms * 1e-3)
+        return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv
+
+    return bench
+
+
+def get_benchmarks():
+    return [get_bench(s, t) for s in TYPES for t in TYPES if s.itemsize < t.itemsize]
+
 
-@triton.jit
-def float_trunc_kernel(
-    x_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-    target_type: tl.constexpr,
-):
-    pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
-
-    x = tl.load(x_ptr + offsets, mask=mask)
-
-    as_target = x.to(target_type)
-    as_f32 = as_target.to(tl.float32)
-    for _ in range(100):
-        as_f32 += 1  # plus one ensures that there are no redundant conversions that can be removed
-        as_target = as_f32.to(target_type)
-        as_f32 = as_target.to(tl.float32)
-
-    tl.store(x_ptr + offsets, as_f32, mask=mask)
-
-
-def launch_conversion(x: torch.Tensor, target_type: type):
-    assert x.is_xpu
-    n_elements = x.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
-    float_trunc_kernel[grid](x, n_elements, BLOCK_SIZE=1024, target_type=target_type)
-    return x
-
-
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=['N'],
-        x_vals=[2**i for i in range(12, 28, 2)],
-        line_arg='target_type',
-        line_vals=['bfloat16', 'float16'],
-        line_names=['BF16', 'FP16'],
-        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
-        ylabel='GB/s',
-        plot_name='float-conversion',
-        args={},
-    ))
-def benchmark(N, target_type):
-    quantiles = [0.5, 0.2, 0.8]
-    inputs = torch.rand(N, dtype=torch.float32, device='xpu', requires_grad=True)
-
-    if target_type == 'bfloat16':
-        fwd = lambda: launch_conversion(inputs, tl.bfloat16)
-    elif target_type == 'float16':
-        fwd = lambda: launch_conversion(inputs, tl.float16)
-    else:
-        raise NotImplementedError(f'Type {target_type} is not supported')
-
-    ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles)
-    gbps = lambda ms: (inputs.numel() * inputs.element_size() * 1e-9) / (ms * 1e-3)
-
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+def run_benchmarks():
+    for bench in get_benchmarks():
+        bench.run(print_data=True)
 
 
 if __name__ == '__main__':
-    benchmark.run(print_data=True)
+    run_benchmarks()
diff --git a/benchmarks/micro_benchmarks/core_ops/dot_scaled.py b/benchmarks/micro_benchmarks/core_ops/dot_scaled.py
index 6bcca472d2..76e80d16f9 100644
--- a/benchmarks/micro_benchmarks/core_ops/dot_scaled.py
+++ b/benchmarks/micro_benchmarks/core_ops/dot_scaled.py
@@ -1,7 +1,13 @@
+import os
+import sys
+
 import torch
 import triton
 import triton.language as tl
 
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
+from triton_kernels_benchmark import Benchmark, do_bench, perf_report  # pylint: disable=C0413
+
 
 @triton.jit
 def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
@@ -37,8 +43,8 @@ def dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps):
 
 
 # Benchmark Performance
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
+@perf_report(
+    Benchmark(
         # argument names to use as an x-axis for the plot
         x_names=['M', 'K', 'N', 'col_a', 'col_b', 'rhs_scale', 'mxfp_type', 'normal_type'],
         x_vals=[(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type)
@@ -55,7 +61,7 @@ def dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps):
         line_names=['Triton'],
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
-        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
+        ylabel=('GB/s', ),  # label name for the y-axis
         plot_name='scaled-dot',
         # name for the plot. Used also as a file name for saving the plot.
         args={},
@@ -122,7 +128,7 @@ def make_finite(x, dtype):
 
     if provider == 'triton':
         triton_fn = lambda: dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps)
-        ms, min_ms, max_ms = triton.testing.do_bench(triton_fn, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
@@ -141,8 +147,12 @@ def size_x(m, n, ty):
         scale_size = (M * K // 32) if rhs_scale else (N * K // 32)
         return (tensor_size + scale_size + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)
 
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv
 
 
-if __name__ == '__main__':
+def run_benchmarks():
     benchmark.run(show_plots=False, print_data=True)
+
+
+if __name__ == '__main__':
+    run_benchmarks()
diff --git a/benchmarks/micro_benchmarks/run_benchmarks.py b/benchmarks/micro_benchmarks/run_benchmarks.py
index 4f3aad8004..c18103fac7 100644
--- a/benchmarks/micro_benchmarks/run_benchmarks.py
+++ b/benchmarks/micro_benchmarks/run_benchmarks.py
@@ -1,16 +1,6 @@
-import argparse
-
 from conversion import float_conversion
 from core_ops import dot_scaled
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--reports',
-        type=str,
-        default='',
-        help='directory to save reports',
-    )
-    args = parser.parse_args()
-    float_conversion.benchmark.run(print_data=True, save_path=args.reports)
-    dot_scaled.benchmark.run(print_data=True, save_path=args.reports)
+    for mod in (float_conversion, dot_scaled):
+        mod.run_benchmarks()