|
| 1 | +from dataclasses import dataclass |
| 2 | + |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import pandas as pd |
| 5 | + |
| 6 | +plt.rcParams["figure.figsize"] = [12, 6] |
| 7 | +plt.rcParams["figure.dpi"] = 600 |
| 8 | +plt.rcParams["font.family"] = "JetBrains Mono" |
| 9 | +plt.rcParams["font.weight"] = "bold" |
| 10 | +plt.rcParams["axes.titleweight"] = "bold" |
| 11 | +plt.rcParams["axes.labelweight"] = "bold" |
| 12 | + |
| 13 | + |
| 14 | +@dataclass |
| 15 | +class KernelInformation: |
| 16 | + name: str |
| 17 | + memory_bound: bool |
| 18 | + compute_bound: bool |
| 19 | + perf_report_path: str |
| 20 | + independent_variable: str |
| 21 | + |
| 22 | + |
| 23 | +@dataclass |
| 24 | +class CategoryInformation: |
| 25 | + kernels: tuple |
| 26 | + y_label: str |
| 27 | + |
| 28 | + |
| 29 | +kernels = ( |
| 30 | + KernelInformation("add", True, False, "vector-addition-performance.csv", "Length"), |
| 31 | + KernelInformation( |
| 32 | + "softmax", True, False, "softmax-performance.csv", "Number of Columns" |
| 33 | + ), |
| 34 | + KernelInformation( |
| 35 | + "rms_norm", True, False, "rms-norm-performance.csv", "Number of Columns" |
| 36 | + ), |
| 37 | + KernelInformation( |
| 38 | + "matmul", False, True, "matrix-multiplication-performance.csv", "Sizes" |
| 39 | + ), |
| 40 | + KernelInformation( |
| 41 | + "conv2d", False, True, "2d-convolution-performance.csv", "Batch Size" |
| 42 | + ), |
| 43 | + KernelInformation( |
| 44 | + "attention", False, True, "attention-performance.csv", "Sequence Length" |
| 45 | + ), |
| 46 | +) |
| 47 | + |
| 48 | +providers = ("Triton", "NineToothed") |
| 49 | + |
| 50 | +categories = ( |
| 51 | + CategoryInformation( |
| 52 | + tuple(kernel for kernel in kernels if kernel.memory_bound), "GB/s" |
| 53 | + ), |
| 54 | + CategoryInformation( |
| 55 | + tuple(kernel for kernel in kernels if kernel.compute_bound), "TFLOPS" |
| 56 | + ), |
| 57 | +) |
| 58 | + |
| 59 | +num_rows = len(categories) |
| 60 | +num_cols = max(len(category.kernels) for category in categories) |
| 61 | + |
| 62 | +fig, axs = plt.subplots(num_rows, num_cols) |
| 63 | + |
| 64 | +for row, category in enumerate(categories): |
| 65 | + axs[row, 0].set_ylabel(category.y_label) |
| 66 | + |
| 67 | + for col, kernel in enumerate(category.kernels): |
| 68 | + df = pd.read_csv(kernel.perf_report_path) |
| 69 | + ax = axs[row, col] |
| 70 | + |
| 71 | + x = df.iloc[:, 0] |
| 72 | + |
| 73 | + for provider in providers: |
| 74 | + y = df[provider] |
| 75 | + |
| 76 | + ax.plot(x, y, label=provider) |
| 77 | + |
| 78 | + ax.set_title(kernel.name) |
| 79 | + ax.set_xlabel(kernel.independent_variable) |
| 80 | + ax.set_xscale("log", base=2) |
| 81 | + |
| 82 | +fig.legend(providers, loc="upper center", ncols=len(providers)) |
| 83 | +fig.tight_layout() |
| 84 | +fig.subplots_adjust(top=0.9) |
| 85 | + |
| 86 | +plt.show() |
| 87 | +plt.savefig("performance-comparison.png") |
0 commit comments