[benchmarks] Reworked the conversion benchmark and added more tests for up/down casts #4800

Open: wants to merge 1 commit into main

@@ -1 +1 @@
from .float_conversion import benchmark # type: ignore # noqa: F401
from .float_conversion import get_benchmarks, run_benchmarks # type: ignore # noqa: F401
@@ -1,68 +1,111 @@
import os
import sys
from functools import lru_cache

import torch
import triton
import triton.language as tl

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))
from triton_kernels_benchmark import Benchmark, do_bench, perf_report # pylint: disable=C0413

TYPES = {
    tl.float8e4nv: torch.float8_e4m3fn,
    tl.float8e5: torch.float8_e5m2,
    tl.float16: torch.float16,
    tl.bfloat16: torch.bfloat16,
    tl.float32: torch.float32,
}


@lru_cache
Contributor

Triton has its own cache mechanism. You can check the kernel JIT cache under TRITON_CACHE_DIR if it is specified. We do not need lru_cache.
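
For illustration, a minimal sketch of inspecting that cache; the default ~/.triton/cache location and the .llir suffix follow the example further down this thread, and TRITON_CACHE_DIR overrides the location when set:

import glob
import os

# List the LLVM IR artifacts Triton's JIT has written to its cache directory.
cache_dir = os.environ.get('TRITON_CACHE_DIR', os.path.expanduser('~/.triton/cache'))
for path in sorted(glob.glob(os.path.join(cache_dir, '**', '*.llir'), recursive=True)):
    print(path)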

Contributor Author

Right, but a JITFunction object is created on each call to _make_kernel(). I'm not sure whether that has a significant impact on the bench results.
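
As a minimal, purely illustrative sketch of the caching concern (make_thing is a stand-in, not the PR's code): lru_cache makes repeated calls with the same argument return the one cached object instead of constructing a new one each time:

from functools import lru_cache


@lru_cache
def make_thing(name):
    # Stand-in for building a kernel object; without the cache, every call
    # would construct a brand-new instance.
    return object()


assert make_thing('a') is make_thing('a')      # same cached object
assert make_thing('a') is not make_thing('b')  # different key, different object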

def get_kernel(name):

    def kernel(
        x_ptr,
        y_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
        x_type: tl.constexpr,
        y_type: tl.constexpr,
        rnd: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x_itype = tl.int8 if x_type.itemsize == 1 else tl.int16 if x_type.itemsize == 2 else tl.int32
        y_itype = tl.int8 if y_type.itemsize == 1 else tl.int16 if y_type.itemsize == 2 else tl.int32

        x = tl.load(x_ptr + offsets, mask=mask)
        converted = x.to(y_type, fp_downcast_rounding=rnd)
        x = tl.cast(x, x_itype, bitcast=True)
        y = tl.cast(converted, y_itype, bitcast=True)
        for i in range(99):
            x += tl.full(x.shape, i, x_itype)
            converted = tl.cast(x, x_type, bitcast=True).to(y_type, fp_downcast_rounding=rnd)
            y += tl.cast(converted, y_itype, bitcast=True)
        y = tl.cast(y, y_type, bitcast=True)
        tl.store(y_ptr + offsets, y, mask=mask)

    kernel.__name__ = kernel.__qualname__ = name
Contributor

Can you explain why we want to change the kernel name here?

Contributor Author

It makes it much easier to find and compare the kernel IRs in the cache. For example, to compare the rtne and rtz IRs:

$ find ~/.triton/cache/ -name '*fp32_to_fp16*.llir'
/home/jovyan/.triton/cache/PXX2VZY5SPACCMUKHEIEDQEEB2CN2A562JHDMIVPLHTS6F6LSRAA/rtne_fp32_to_fp16_conversion_kernel.llir
/home/jovyan/.triton/cache/CYCEIIP4OQCRWTMZHZHLPXO63ZR5J7ZEPN4KCM72NVOITFO5FMRA/rtz_fp32_to_fp16_conversion_kernel.llir

Contributor

I think we can keep this for the development stage, since from the production perspective there is no performance or visualization improvement from this change. But it makes the code structure inconsistent with the other benchmarks, which reduces readability.
Without this change we also would not need lru_cache or _make_kernel, and could keep using triton.jit as a decorator.

Contributor Author

Using the benchmarks for development and testing is very convenient. This bench produces 20 kernels, and it's not easy to pick out the required kernel in the cache.

    return triton.jit(kernel)
Contributor

I would suggest using triton.jit as a decorator, as the other benchmarks do.

Contributor Author

The decorator does not allow changing the name.
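
A small illustrative sketch of the two styles being contrasted (the kernel bodies here are placeholder copy kernels, not the PR's conversion kernel): with the decorator, the JITFunction keeps the function's original name, while wrapping a renamed plain function lets each specialization show up under its own name in the cache:

import triton
import triton.language as tl


# Decorator style: the compiled kernel is filed under the original function name.
@triton.jit
def conversion_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)


# Rename-then-wrap style used in this PR: the cached IR files carry the
# per-specialization name (e.g. rtz_fp32_to_fp16_conversion_kernel.llir).
def _kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)


_kernel.__name__ = _kernel.__qualname__ = 'rtz_fp32_to_fp16_conversion_kernel'
renamed_kernel = triton.jit(_kernel)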



def get_bench(x_type, y_type):
    assert x_type.itemsize < y_type.itemsize
    plot_name = f'{x_type}-{y_type}'
    line_vals = [(x_type, y_type, None), (y_type, x_type, 'rtne')]
    line_names = [f'{x_type}->{y_type}', f'{y_type}->{x_type}-rtne']
    if y_type == tl.float32:
        line_vals.append((y_type, x_type, 'rtz'))
        line_names.append(f'{y_type}->{x_type}-rtz')

    @perf_report(
        Benchmark(
            x_names=['N'],
            x_vals=[2**i for i in range(12, 28, 2)],
            line_arg='args',
            line_vals=line_vals,
            line_names=line_names,
            styles=[(c, s) for c in 'bgry' for s in ('-', '--', '-.', ':')],
            ylabel=('GB/s', ),
            plot_name=plot_name,
            args={},
        ))
    def bench(N, args):
        quantiles = [0.5, 0.2, 0.8]
        x_type = args[0]
        y_type = args[1]
        if x_type.itemsize == 1:
            x = torch.rand(N, dtype=torch.float16, device='xpu', requires_grad=True).to(TYPES[x_type])
        else:
            x = torch.rand(N, dtype=TYPES[x_type], device='xpu', requires_grad=True)
        y = torch.empty_like(x, dtype=TYPES[y_type], device='xpu')
        rnd = args[2] if x_type.itemsize > y_type.itemsize else None
        name = f'{x_type}_to_{y_type}_conversion_kernel'
        if rnd:
            name = f'{rnd}_{name}'
        kernel = get_kernel(name)

        def fwd():
            BLOCK_SIZE = 4096
            grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
            kernel[grid](x, y, N, BLOCK_SIZE, x_type, y_type, rnd)
            return x

        _, min_ms, max_ms, mean_ms, cv = do_bench(fwd, n_warmup=10, n_repeat=10, quantiles=quantiles)
        gbps = lambda ms: (N * x.element_size() * 1e-9) / (ms * 1e-3)
        return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv

    return bench


def get_benchmarks():
    return [get_bench(s, t) for s in TYPES for t in TYPES if s.itemsize < t.itemsize]


@triton.jit
def float_trunc_kernel(
    x_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    target_type: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    as_target = x.to(target_type)
    as_f32 = as_target.to(tl.float32)
    for _ in range(100):
        as_f32 += 1  # plus one ensures that there are no redundant conversions that can be removed
        as_target = as_f32.to(target_type)
        as_f32 = as_target.to(tl.float32)

    tl.store(x_ptr + offsets, as_f32, mask=mask)


def launch_conversion(x: torch.Tensor, target_type: type):
    assert x.is_xpu
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    float_trunc_kernel[grid](x, n_elements, BLOCK_SIZE=1024, target_type=target_type)
    return x


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[2**i for i in range(12, 28, 2)],
        line_arg='target_type',
        line_vals=['bfloat16', 'float16'],
        line_names=['BF16', 'FP16'],
        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
        ylabel='GB/s',
        plot_name='float-conversion',
        args={},
    ))
def benchmark(N, target_type):
    quantiles = [0.5, 0.2, 0.8]
    inputs = torch.rand(N, dtype=torch.float32, device='xpu', requires_grad=True)

    if target_type == 'bfloat16':
        fwd = lambda: launch_conversion(inputs, tl.bfloat16)
    elif target_type == 'float16':
        fwd = lambda: launch_conversion(inputs, tl.float16)
    else:
        raise NotImplementedError(f'Type {target_type} is not supported')

    ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles)
    gbps = lambda ms: (inputs.numel() * inputs.element_size() * 1e-9) / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)
def run_benchmarks():
    for bench in get_benchmarks():
        bench.run(print_data=True)


if __name__ == '__main__':
    benchmark.run(print_data=True)
    run_benchmarks()
22 changes: 16 additions & 6 deletions benchmarks/micro_benchmarks/core_ops/dot_scaled.py
@@ -1,7 +1,13 @@
import os
import sys

import torch
import triton
import triton.language as tl

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from triton_kernels_benchmark import Benchmark, do_bench, perf_report # pylint: disable=C0413


@triton.jit
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
@@ -37,8 +43,8 @@ def dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps):


# Benchmark Performance
@triton.testing.perf_report(
    triton.testing.Benchmark(
@perf_report(
    Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['M', 'K', 'N', 'col_a', 'col_b', 'rhs_scale', 'mxfp_type', 'normal_type'],
        x_vals=[(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type)
@@ -55,7 +61,7 @@ def dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps):
        line_names=['Triton'],
        # line styles
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
        ylabel=('GB/s', ),  # label name for the y-axis
        plot_name='scaled-dot',
        # name for the plot. Used also as a file name for saving the plot.
        args={},
@@ -122,7 +128,7 @@ def make_finite(x, dtype):
    if provider == 'triton':
        triton_fn = lambda: dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps)

        ms, min_ms, max_ms = triton.testing.do_bench(triton_fn, quantiles=quantiles)
        _, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
    else:
        raise NotImplementedError(f'Unsupported provider {provider}')

@@ -141,8 +147,12 @@ def size_x(m, n, ty):
        scale_size = (M * K // 32) if rhs_scale else (N * K // 32)
        return (tensor_size + scale_size + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)
    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv


if __name__ == '__main__':
def run_benchmarks():
    benchmark.run(show_plots=False, print_data=True)


if __name__ == '__main__':
    run_benchmarks()
14 changes: 2 additions & 12 deletions benchmarks/micro_benchmarks/run_benchmarks.py
@@ -1,16 +1,6 @@
import argparse

from conversion import float_conversion
from core_ops import dot_scaled

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--reports',
        type=str,
        default='',
        help='directory to save reports',
    )
    args = parser.parse_args()
    float_conversion.benchmark.run(print_data=True, save_path=args.reports)
    dot_scaled.benchmark.run(print_data=True, save_path=args.reports)
    for mod in (float_conversion, dot_scaled):
        mod.run_benchmarks()