|
| 1 | +""" |
| 2 | +Gemm benchmark (tensor of pointer) |
| 3 | +============================ |
| 4 | +
|
| 5 | +This benchmark is come from the Triton tutorial 03-matrix-multiplication.py (commit: 3f4fdd1) |
| 6 | +To compare the performance to XeTLA kernel. |
| 7 | +
|
| 8 | +""" |
| 9 | +import os |
| 10 | + |
| 11 | +import torch |
| 12 | +import triton |
| 13 | +import triton.language as tl |
| 14 | + |
| 15 | +import triton_kernels_benchmark as benchmark_suit |
| 16 | +from triton_kernels_benchmark import xetla_kernel |
| 17 | + |
| 18 | +TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1' |
| 19 | +TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1' |
| 20 | +use_xetla = not (TRANSPOSE_A or TRANSPOSE_B) |
| 21 | +SMALL_GRF = os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' |
| 22 | + |
| 23 | + |
| 24 | +@triton.autotune( |
| 25 | + configs=[ |
| 26 | + triton.Config( |
| 27 | + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 28 | + num_stages=s, num_warps=32) for s in [1, 2, 3] |
| 29 | + ] + [ |
| 30 | + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m}, |
| 31 | + num_stages=s, num_warps=w) |
| 32 | + for s in [2, 3, 4] |
| 33 | + for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)]) |
| 34 | + ] + [ |
| 35 | + triton.Config( |
| 36 | + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 37 | + num_stages=s, num_warps=32) for s in [2] |
| 38 | + ] + [ |
| 39 | + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m}, |
| 40 | + num_stages=s, num_warps=w) |
| 41 | + for s in [2, 3] |
| 42 | + for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)]) |
| 43 | + ], |
| 44 | + key=['M', 'N', 'K'], |
| 45 | +) |
| 46 | +@triton.jit |
| 47 | +def matmul_kernel( |
| 48 | + # Pointers to matrices |
| 49 | + a_ptr, b_ptr, c_ptr, |
| 50 | + # Matrix dimensions |
| 51 | + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, |
| 52 | + # Stride variables |
| 53 | + stride_am: tl.constexpr, stride_ak: tl.constexpr, # |
| 54 | + stride_bk: tl.constexpr, stride_bn: tl.constexpr, # |
| 55 | + stride_cm: tl.constexpr, stride_cn: tl.constexpr, |
| 56 | + # Meta-parameters |
| 57 | + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): |
| 58 | + pid = tl.program_id(axis=0) |
| 59 | + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) |
| 60 | + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 61 | + num_pid_in_group = GROUP_SIZE_M * num_pid_n |
| 62 | + group_id = pid // num_pid_in_group |
| 63 | + first_pid_m = group_id * GROUP_SIZE_M |
| 64 | + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) |
| 65 | + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) |
| 66 | + pid_n = (pid % num_pid_in_group) // group_size_m |
| 67 | + |
| 68 | + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M |
| 69 | + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N |
| 70 | + offs_k = tl.arange(0, BLOCK_SIZE_K) |
| 71 | + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) |
| 72 | + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) |
| 73 | + |
| 74 | + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) |
| 75 | + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): |
| 76 | + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) |
| 77 | + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) |
| 78 | + accumulator = tl.dot(a, b, accumulator) |
| 79 | + a_ptrs += BLOCK_SIZE_K * stride_ak |
| 80 | + b_ptrs += BLOCK_SIZE_K * stride_bk |
| 81 | + |
| 82 | + c = accumulator.to(tl.float32) |
| 83 | + |
| 84 | + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) |
| 85 | + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) |
| 86 | + c_ptrs = c_ptr + stride_cm * \ |
| 87 | + offs_cm[:, None] + stride_cn * offs_cn[None, :] |
| 88 | + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) |
| 89 | + tl.store(c_ptrs, c, mask=c_mask) |
| 90 | + |
| 91 | + |
| 92 | +# pylint: disable=unused-argument |
| 93 | +@triton.autotune( |
| 94 | + configs=[ |
| 95 | + triton.Config( |
| 96 | + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 97 | + num_stages=s, num_warps=32) for s in [2, 3] |
| 98 | + ] + [ |
| 99 | + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m}, |
| 100 | + num_stages=s, num_warps=w) |
| 101 | + for s in [2] |
| 102 | + for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)]) |
| 103 | + ] + [ |
| 104 | + triton.Config( |
| 105 | + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 106 | + num_stages=s, num_warps=32) for s in [2, 3] |
| 107 | + ] + [ |
| 108 | + triton.Config( |
| 109 | + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 110 | + num_stages=s, num_warps=32) for s in [2] |
| 111 | + ] + [ |
| 112 | + triton.Config( |
| 113 | + {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'}, |
| 114 | + num_stages=s, num_warps=32) for s in [2] |
| 115 | + ] + [ |
| 116 | + triton.Config( |
| 117 | + {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'}, |
| 118 | + num_stages=s, num_warps=4) for s in [2] |
| 119 | + ], |
| 120 | + key=['M', 'N', 'K'], |
| 121 | +) |
| 122 | +@triton.jit |
| 123 | +def matmul_kernel_batched( |
| 124 | + # Pointers to matrices |
| 125 | + a_ptr, b_ptr, c_ptr, |
| 126 | + # Matrix dimensions |
| 127 | + B: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, |
| 128 | + # Stride variables |
| 129 | + stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr, # |
| 130 | + stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, # |
| 131 | + stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, |
| 132 | + # Meta-parameters |
| 133 | + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): |
| 134 | + bid = tl.program_id(axis=1) |
| 135 | + pid = tl.program_id(axis=0) |
| 136 | + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) |
| 137 | + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 138 | + num_pid_in_group = GROUP_SIZE_M * num_pid_n |
| 139 | + group_id = pid // num_pid_in_group |
| 140 | + first_pid_m = group_id * GROUP_SIZE_M |
| 141 | + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) |
| 142 | + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) |
| 143 | + pid_n = (pid % num_pid_in_group) // group_size_m |
| 144 | + |
| 145 | + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M |
| 146 | + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N |
| 147 | + offs_k = tl.arange(0, BLOCK_SIZE_K) |
| 148 | + offset_a = bid.to(tl.int64) * stride_az |
| 149 | + offset_b = bid.to(tl.int64) * stride_bz |
| 150 | + a_ptrs = a_ptr + offset_a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) |
| 151 | + b_ptrs = b_ptr + offset_b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) |
| 152 | + |
| 153 | + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) |
| 154 | + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): |
| 155 | + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) |
| 156 | + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) |
| 157 | + accumulator = tl.dot(a, b, accumulator) |
| 158 | + a_ptrs += BLOCK_SIZE_K * stride_ak |
| 159 | + b_ptrs += BLOCK_SIZE_K * stride_bk |
| 160 | + c = accumulator.to(tl.float32) |
| 161 | + |
| 162 | + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) |
| 163 | + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) |
| 164 | + offset_c = bid.to(tl.int64) * stride_cz |
| 165 | + c_ptrs = c_ptr + offset_c + stride_cm * \ |
| 166 | + offs_cm[:, None] + stride_cn * offs_cn[None, :] |
| 167 | + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) |
| 168 | + tl.store(c_ptrs, c, mask=c_mask) |
| 169 | + |
| 170 | + |
| 171 | +# We can now create a convenience wrapper function that only takes two input tensors, |
| 172 | +# and (1) checks any shape constraint; (2) launches the above kernel. |
| 173 | +def matmul(a, b, c, transpose_a=False, transpose_b=False): |
| 174 | + a_major, a_minor = -2, -1 |
| 175 | + if transpose_a: |
| 176 | + a_major, a_minor = a_minor, a_major |
| 177 | + b_minor, b_major = -2, -1 |
| 178 | + if transpose_b: |
| 179 | + b_major, b_minor = b_minor, b_major |
| 180 | + |
| 181 | + assert a.shape[a_minor] == b.shape[b_minor], 'Incompatible dimensions' |
| 182 | + assert a.is_contiguous(), 'Matrix A must be contiguous' |
| 183 | + assert b.is_contiguous(), 'Matrix B must be contiguous' |
| 184 | + M, N, K = a.shape[a_major], b.shape[b_major], a.shape[a_minor] |
| 185 | + # Check constraints. |
| 186 | + if len(a.shape) == 3 and len(b.shape) == 3: |
| 187 | + assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension' |
| 188 | + B = a.shape[0] |
| 189 | + # 1D launch kernel where each block gets its own program. |
| 190 | + |
| 191 | + grid = lambda META: ( |
| 192 | + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), |
| 193 | + B, |
| 194 | + ) |
| 195 | + matmul_kernel_batched[grid]( |
| 196 | + a, b, c, # |
| 197 | + B, M, N, K, # |
| 198 | + a.stride(0), a.stride(a_major), a.stride(a_minor), # |
| 199 | + b.stride(0), b.stride(b_minor), b.stride(b_major), # |
| 200 | + c.stride(0), c.stride(1), c.stride(2)) |
| 201 | + elif len(a.shape) == 2 and len(b.shape) == 2: |
| 202 | + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) |
| 203 | + matmul_kernel[grid]( |
| 204 | + a, b, c, # |
| 205 | + M, N, K, # |
| 206 | + a.stride(a_major), a.stride(a_minor), # |
| 207 | + b.stride(b_minor), b.stride(b_major), # |
| 208 | + c.stride(0), c.stride(1)) |
| 209 | + else: |
| 210 | + assert False, 'Input matrixs dimensions mismatch' |
| 211 | + return c |
| 212 | + |
| 213 | + |
| 214 | +def get_shapes(B, M, N, K, transpose_a, transpose_b): |
| 215 | + a_shape = (M, K) |
| 216 | + if transpose_a: |
| 217 | + a_shape = (K, M) |
| 218 | + |
| 219 | + b_shape = (K, N) |
| 220 | + if transpose_b: |
| 221 | + b_shape = (N, K) |
| 222 | + |
| 223 | + if B != 1: |
| 224 | + a_shape = (B, *a_shape) |
| 225 | + b_shape = (B, *b_shape) |
| 226 | + return a_shape, b_shape |
| 227 | + |
| 228 | + |
| 229 | +X_VALS = [[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [ |
| 230 | + [1, 1, 13824, 5120], |
| 231 | + [1, 4, 12288, 4096], |
| 232 | + [1, 512, 8192, 8192], |
| 233 | + [1, 512, 8192, 32768], |
| 234 | + [1, 512, 32768, 8192], |
| 235 | + [1, 1024, 8192, 16384], |
| 236 | + [1, 1024, 8192, 28672], |
| 237 | + [1, 3072, 3072, 4096], # FIXME: Remove this case when gemm_streamk_benchmark can get better performance |
| 238 | + [1, 4096, 8192, 16384], |
| 239 | + [1, 8192, 1024, 16384], |
| 240 | + [1, 8192, 4096, 16384], |
| 241 | + [1, 16384, 1024, 8192], |
| 242 | + [1, 16384, 4096, 8192], |
| 243 | + [1, 16384, 8192, 1024], |
| 244 | + [1, 16384, 8192, 4096], |
| 245 | + [4, 32768, 128, 4096], |
| 246 | + [4, 32768, 4096, 128], |
| 247 | + [32, 4096, 128, 4096], |
| 248 | + [4096, 8, 128, 16384], |
| 249 | + [4096, 8, 16384, 128], |
| 250 | +] |
| 251 | + |
| 252 | +DEVICE_NAME = torch.xpu.get_device_name() |
| 253 | +DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory |
| 254 | + |
| 255 | + |
| 256 | +def is_enough_memory(x_val): |
| 257 | + # x_val: (B, M, N, K) |
| 258 | + B, M, N, K = x_val |
| 259 | + # a: (B, M, K) bfloat16 |
| 260 | + # b: (B, N, K) bfloat16 |
| 261 | + # c: (B, M, N) float32 |
| 262 | + # pytorch reference: (B, M, N) float32 |
| 263 | + required_memory = B * M * K * 2 + B * N * K * 2 + 2 * B * M * N * 4 |
| 264 | + enough_memory = required_memory < DEVICE_TOTAL_MEMORY |
| 265 | + if not enough_memory: |
| 266 | + print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}") |
| 267 | + return enough_memory |
| 268 | + |
| 269 | + |
| 270 | +X_VALS = [x_val for x_val in X_VALS if is_enough_memory(x_val)] |
| 271 | + |
| 272 | + |
| 273 | +# Benchmark Performance |
| 274 | +@benchmark_suit.perf_report( |
| 275 | + benchmark_suit.Benchmark( |
| 276 | + # argument names to use as an x-axis for the plot |
| 277 | + x_names=['B', 'M', 'N', 'K'], |
| 278 | + # different possible values for `x_name` |
| 279 | + x_vals=X_VALS, |
| 280 | + line_arg='provider', |
| 281 | + # argument name whose value corresponds to a different line in the plot |
| 282 | + # possible values for `line_arg`` |
| 283 | + line_vals=['triton', 'onednn'] + (['xetla'] if use_xetla else []), |
| 284 | + # label name for the lines |
| 285 | + line_names=['Triton', 'OneDNN'] + (['XeTLA'] if use_xetla else []), |
| 286 | + # line styles |
| 287 | + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], |
| 288 | + ylabel=['GB/s', 'TFlops'], # label name for the y-axis |
| 289 | + plot_name='matmul-tensor-of-ptr-performance', |
| 290 | + # name for the plot. Used also as a file name for saving the plot. |
| 291 | + args={}, |
| 292 | + )) |
| 293 | +def benchmark(B, M, N, K, provider): |
| 294 | + a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B) |
| 295 | + |
| 296 | + torch.manual_seed(0) |
| 297 | + a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16) |
| 298 | + b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16) |
| 299 | + |
| 300 | + quantiles = [0.5, 0.0, 1.0] |
| 301 | + |
| 302 | + torch_a = a |
| 303 | + if TRANSPOSE_A: |
| 304 | + torch_a = torch.transpose(torch_a, -2, -1) |
| 305 | + |
| 306 | + torch_b = b |
| 307 | + if TRANSPOSE_B: |
| 308 | + torch_b = torch.transpose(torch_b, -2, -1) |
| 309 | + |
| 310 | + if provider == 'onednn': |
| 311 | + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, |
| 312 | + n_repeat=10, quantiles=quantiles) |
| 313 | + elif provider == 'triton': |
| 314 | + assert len(a.shape) == len(b.shape), 'Incompatible sizes' |
| 315 | + if len(a.shape) == 3: |
| 316 | + c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32) |
| 317 | + else: |
| 318 | + assert len(a.shape) == 2, 'Expecting shape of length 2' |
| 319 | + c = torch.zeros((M, N), device='xpu', dtype=torch.float32) |
| 320 | + triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B) |
| 321 | + torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32) |
| 322 | + rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 |
| 323 | + benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch') |
| 324 | + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, |
| 325 | + quantiles=quantiles) |
| 326 | + elif provider == 'xetla': |
| 327 | + if B == 1: |
| 328 | + c = torch.zeros((M, N), device='xpu', dtype=torch.float32) |
| 329 | + cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32) |
| 330 | + else: |
| 331 | + c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32) |
| 332 | + cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32) |
| 333 | + name = f'gemm_shape_{B}_{M}_{K}_{N}' |
| 334 | + # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get |
| 335 | + # better performance. |
| 336 | + if (B, M, N, K) == (1, 3072, 3072, 4096): |
| 337 | + name = 'gemm_streamk_shape_3072_4096_3072' |
| 338 | + func = getattr(xetla_kernel, name) |
| 339 | + |
| 340 | + def xetla_func_with_acc_allocation(): |
| 341 | + # allocating `acc` matrix on every function call, to be as similar as |
| 342 | + # possible to the triton kernel, which also does this on every call. |
| 343 | + if B == 1: |
| 344 | + acc = torch.zeros((M, N), device='xpu', dtype=torch.float32) |
| 345 | + else: |
| 346 | + acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32) |
| 347 | + return func(a, b, c, acc, cnt) |
| 348 | + |
| 349 | + xetla_fn = xetla_func_with_acc_allocation |
| 350 | + torch_fn = lambda: torch.matmul(a, b).to(torch.float32) |
| 351 | + |
| 352 | + # benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch') |
| 353 | + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, |
| 354 | + quantiles=quantiles) |
| 355 | + else: |
| 356 | + raise NotImplementedError(f'Unsupported provider {provider}') |
| 357 | + |
| 358 | + tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3) |
| 359 | + gbps = lambda ms: B * (2 * (M * K + K * N) + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3) |
| 360 | + |
| 361 | + return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv |
| 362 | + |
| 363 | + |
| 364 | +if __name__ == '__main__': |
| 365 | + benchmark.run(show_plots=False, print_data=True) |
0 commit comments