Commit f7f6f27

Add ScaledDot to micro benchmark (#3604)
Per #3538 (review), split the micro benchmark part out of #3538. The kernel is extracted from the UT. The reason for adding it to the micro benchmark is that this Op's implementation has changed significantly recently, and we have performance concerns because many layout conversions are used in lowering.

---------

Co-authored-by: Whitney Tsang <[email protected]>
1 parent b3f9c6e commit f7f6f27
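For context, tl.dot_scaled multiplies two matrices whose operands may carry E8M0 block scales: one uint8 exponent (bias 127) per 32 elements along K. The snippet below is a hedged PyTorch reference of that semantics, not code from this commit; it assumes bf16 operands, a scale on the LHS only, and the helper name scaled_dot_reference is made up for illustration.

import torch

def scaled_dot_reference(a, a_scale, b):
    # Hypothetical reference (not from this commit): expand each per-block E8M0
    # exponent to its 32 K-elements, turn it into 2**(e - 127), apply it to the
    # LHS, then do an ordinary matmul.
    scale = torch.exp2(a_scale.to(torch.float32) - 127).repeat_interleave(32, dim=1)
    return (a.to(torch.float32) * scale @ b.to(torch.float32)).to(torch.bfloat16)

M = N = K = 128
a = torch.randn(M, K, dtype=torch.bfloat16)
b = torch.randn(K, N, dtype=torch.bfloat16)
a_scale = torch.full((M, K // 32), 127, dtype=torch.uint8)  # exponent 127 -> scale of 1.0
print(scaled_dot_reference(a, a_scale, b).shape)  # torch.Size([128, 128])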

File tree: 3 files changed, +157 -0 lines changed


benchmarks/micro_benchmarks/core_ops/__init__.py

Whitespace-only changes.
benchmarks/micro_benchmarks/core_ops/dot_scaled.py

Lines changed: 155 additions & 0 deletions

@@ -0,0 +1,155 @@
import torch
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit


@triton.jit
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr,
                     type_b: tl.constexpr):
    DIV_FACTOR_A: tl.constexpr = 2 if type_a == 'e2m1' else 1
    DIV_FACTOR_B: tl.constexpr = 2 if type_b == 'e2m1' else 1
    PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A
    PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B
    a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, PACKED_BLOCK_K_A)[None, :] * stride_a1
    b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, BLOCK_N)[None, :] * stride_b1

    a = tl.load(a_ptr)
    b = tl.load(b_ptr)
    SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32
    if a_scale is not None:
        scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :]
        a_scale = tl.load(scale_a_ptr)
    if b_scale is not None:
        scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :]
        b_scale = tl.load(scale_b_ptr)
    c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b)
    out_ptr = out + \
        tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + \
        tl.arange(0, BLOCK_N)[None, :]
    tl.store(out_ptr, c.to(tl.bfloat16))


def dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps):
    kernel_kwargs = {'num_warps': num_warps}
    dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b,
                            **kernel_kwargs)


# Benchmark Performance
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['M', 'K', 'N', 'col_a', 'col_b', 'rhs_scale', 'mxfp_type', 'normal_type'],
        x_vals=[(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type)
                for M, N, K in [(128, 128, 128)]
                for col_a, col_b in [(True, True), (False, False)]
                for rhs_scale in [True, False]
                for mxfp_type in ['e2m1', 'e4m3']
                for normal_type in ['bf16']],
        line_arg='provider',
        # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg``
        line_vals=['triton'],
        # label name for the lines
        line_names=['Triton'],
        # line styles
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
        plot_name='scaled-dot',
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, provider):

    device = 'xpu'
    num_warps = 4
    quantiles = [0.5, 0.0, 1.0]

    comp_dtype = torch.float16 if normal_type == 'fp16' else torch.bfloat16
    # The max exponent we use to initialize data in the x/y and associated scale tensor to avoid
    # overflow when scaling.
    comp_dtype_max_exp = 6 if normal_type == 'fp16' else 15

    torch.manual_seed(0)

    def make_arg(shape, ty, col_major=False):
        if col_major:
            shape = shape[:-2] + (shape[-1], shape[-2])
        if ty in ['fp16', 'bf16']:
            ret = torch.randn(shape, dtype=comp_dtype, device=device)
            # Clamp to avoid relative error issues
            ret.clamp_(-2**comp_dtype_max_exp, 2**comp_dtype_max_exp - 1)
        else:
            ret = torch.randint(256, shape, dtype=torch.uint8, device=device)
        if col_major:
            ret = ret.mT
        return ret

    type_a = normal_type if rhs_scale else mxfp_type
    type_b = mxfp_type if rhs_scale else normal_type

    DIV_FACTOR_A = 2 if type_a == 'e2m1' else 1
    DIV_FACTOR_B = 2 if type_b == 'e2m1' else 1
    x = make_arg((M, K // DIV_FACTOR_A), type_a, col_major=col_a)
    y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b)

    min_scale, max_scale = (0, 142) if comp_dtype == torch.bfloat16 else (124, 131)
    scale_x = torch.randint(min_scale, max_scale + 1, (M, K // 32), dtype=torch.uint8, device=device)
    scale_y = torch.randint(min_scale, max_scale + 1, (N, K // 32), dtype=torch.uint8, device=device)

    def make_finite(x, dtype):
        # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and
        # Fp8E5M2_to_Bf16 doesn't preserve NaNs (fixme)
        if dtype not in ('e5m2', 'e4m3'):
            return x
        if dtype == 'e5m2' and comp_dtype == torch.float16:
            x = x & 0xB
        mask = 0x7C if dtype == 'e5m2' else 0x7F
        finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask
        x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x)
        x.copy_(x_finite)
        return x

    x = make_finite(x, type_a)
    y = make_finite(y, type_b)
    z = x.new_empty((M, N), dtype=comp_dtype)
    if rhs_scale:
        scale_x = None
    else:
        scale_y = None

    if provider == 'triton':
        triton_fn = lambda: dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps)

        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                 quantiles=quantiles)
    else:
        raise NotImplementedError(f'Unsupported provider {provider}')

    def tflops(ms):
        scale_ops = N * K if rhs_scale else M * K
        return (2 * M * N * K + scale_ops) * (1e-12) / (ms * 1e-3)

    def gbps(ms):

        def size_x(m, n, ty):
            if ty in ['e2m1']:
                return m * n // 2
            if ty in ['e4m3', 'e5m2']:
                return m * n
            if ty in ['fp16', 'bf16']:
                return m * n * 2
            raise NotImplementedError(f'Unsupported type {ty} for scaledot operand')

        tensor_size = size_x(M, K, type_a) + size_x(K, N, type_b)
        scale_size = (M * K // 32) if rhs_scale else (N * K // 32)
        return (tensor_size + scale_size + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)

    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv


if __name__ == '__main__':
    benchmark.run(show_plots=False, print_data=True)
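
As a quick sanity check on the tflops() and gbps() helpers above, the following hand-evaluated sketch (not part of the commit) re-derives the operation and byte counts those helpers plug in for the default shape M = N = K = 128, assuming a bf16 LHS, an e2m1 RHS and rhs_scale=True:

M = N = K = 128
ops = 2 * M * N * K + N * K       # matmul multiply-adds (x2) plus one scaling op per RHS element: 4,210,688
bytes_moved = (2 * M * K          # bf16 LHS: 2 bytes per element
               + K * N // 2       # e2m1 RHS: two fp4 values packed per byte
               + M * K // 32      # uint8 scales, one per 32 elements along K, as counted in gbps()
               + 4 * M * N)       # output term, 4 bytes per element as in gbps()
print(ops, bytes_moved)           # 4210688 107008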

benchmarks/micro_benchmarks/run_benchmarks.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 import argparse
 
 from conversion import float_conversion
+from core_ops import dot_scaled
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
@@ -12,3 +13,4 @@
     )
     args = parser.parse_args()
     float_conversion.benchmark.run(print_data=True, save_path=args.reports)
+    dot_scaled.benchmark.run(print_data=True, save_path=args.reports)
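
For reference, the new benchmark can also be run on its own without going through run_benchmarks.py, via the __main__ guard in dot_scaled.py. A minimal sketch, assuming the current directory is benchmarks/micro_benchmarks and triton_kernels_benchmark is installed:

# Hedged usage example (not part of the commit): run only the scaled-dot
# benchmark and print the GB/s / TFlops table, mirroring dot_scaled.py's
# __main__ block.
from core_ops import dot_scaled

dot_scaled.benchmark.run(print_data=True, show_plots=False)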
