|
2 | 2 | import sys |
3 | 3 |
|
4 | 4 | import torch |
5 | | -import triton |
6 | 5 |
|
7 | | -from utils import QUANTILES |
| 6 | +from benchmark_model_configs import compute_hidden_size_sweep_config |
| 7 | +from benchmark_model_configs import estimate_kernel_peak_memory |
| 8 | +from benchmark_model_configs import get_benchmark_model_config |
8 | 9 | from utils import SingleBenchmarkRunInput |
9 | 10 | from utils import SingleBenchmarkRunOutput |
10 | | -from utils import _test_memory |
11 | 11 | from utils import parse_benchmark_script_args |
12 | 12 | from utils import run_benchmarks |
| 13 | +from utils import run_memory_benchmark |
| 14 | +from utils import run_speed_benchmark |
13 | 15 |
|
14 | 16 | from liger_kernel.utils import infer_device |
15 | 17 |
|
|
18 | 20 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) |
19 | 21 |
|
20 | 22 |
|
def _setup_dyt(input: SingleBenchmarkRunInput):
    """Build the benchmark input tensor and the DyT layer for the requested provider.

    Reads ``BT``, ``dtype`` and ``beta`` from ``input.extra_benchmark_config`` and
    ``hidden_size`` from ``input.x``, then returns ``(x, layer)`` where ``x`` is a
    ``(BT, hidden_size)`` tensor with ``requires_grad=True`` and ``layer`` is the
    DyT implementation selected by ``input.kernel_provider``
    ("liger", "torch", or "torch_compile").

    Raises:
        ValueError: if the provider is not one of the three supported names.
    """
    from test.transformers.test_dyt import LigerDyT
    from test.transformers.test_dyt import TorchDyT

    cfg = input.extra_benchmark_config
    hidden_size = input.x
    provider = input.kernel_provider

    x = torch.randn(
        cfg["BT"],
        hidden_size,
        device=device,
        dtype=cfg["dtype"],
        requires_grad=True,
    )

    if provider == "liger":
        layer = LigerDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device)
    else:
        if provider not in ("torch", "torch_compile"):
            raise ValueError(f"Invalid provider: {provider} for DyT")
        # Both torch variants share the same eager module; "torch_compile"
        # additionally wraps it with torch.compile.
        layer = TorchDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device)
        if provider == "torch_compile":
            layer = torch.compile(layer)
    return x, layer
def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark DyT latency for the provider/mode described by *input*."""
    x, layer = _setup_dyt(input)

    def forward():
        return layer(x)

    # run_speed_benchmark drives forward/backward/full timing; [x] lists the
    # tensors whose grads are cleared between reps.
    return run_speed_benchmark(forward, input.kernel_operation_mode, [x])
73 | 45 |
|
74 | 46 |
|
def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark DyT peak memory for the provider/mode described by *input*."""
    x, layer = _setup_dyt(input)

    def forward():
        return layer(x)

    return run_memory_benchmark(forward, input.kernel_operation_mode)
113 | 51 |
|
| 52 | +BT = 4096 |
114 | 53 |
|
115 | 54 | if __name__ == "__main__": |
116 | 55 | args = parse_benchmark_script_args() |
| 56 | + model = get_benchmark_model_config(args.model) |
117 | 57 |
|
118 | 58 | for beta in [False, True]: |
| 59 | + |
| 60 | + def _probe(): |
| 61 | + probe_input = SingleBenchmarkRunInput( |
| 62 | + x=model.hidden_size, |
| 63 | + kernel_provider="torch", |
| 64 | + extra_benchmark_config={"BT": BT, "dtype": model.dtype, "beta": beta}, |
| 65 | + ) |
| 66 | + x, layer = _setup_dyt(probe_input) |
| 67 | + return layer(x) |
| 68 | + |
| 69 | + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) |
| 70 | + sweep_config = compute_hidden_size_sweep_config(model, peak_bytes, bt=BT) |
| 71 | + x_values = [1024 * i for i in range(1, 17) if 1024 * i <= sweep_config.max_hidden_size] or [model.hidden_size] |
| 72 | + |
119 | 73 | common_configs = { |
120 | 74 | "kernel_name": f"dyt_beta={beta}", |
121 | 75 | "x_name": "hidden_size", |
122 | 76 | "x_label": "hidden_size", |
123 | | - "x_values": [1024 * i for i in range(1, 17)], |
| 77 | + "x_values": x_values, |
124 | 78 | "kernel_providers": ["liger", "torch", "torch_compile"], |
125 | | - "extra_benchmark_configs": [{"BT": 4096, "dtype": torch.bfloat16, "beta": beta}], |
| 79 | + "extra_benchmark_configs": [{"BT": sweep_config.bt, "dtype": model.dtype, "beta": beta}], |
126 | 80 | "overwrite": args.overwrite, |
127 | 81 | } |
128 | 82 |
|
129 | 83 | run_benchmarks( |
130 | 84 | bench_test_fn=bench_speed_dyt, |
131 | | - kernel_operation_modes=["forward", "backward", "full"], |
| 85 | + kernel_operation_modes=["full", "forward", "backward"], |
132 | 86 | metric_name="speed", |
133 | 87 | metric_unit="ms", |
134 | 88 | **common_configs, |
135 | 89 | ) |
136 | 90 | run_benchmarks( |
137 | 91 | bench_test_fn=bench_memory_dyt, |
138 | | - kernel_operation_modes=["full"], |
| 92 | + kernel_operation_modes=["full", "forward", "backward"], |
139 | 93 | metric_name="memory", |
140 | 94 | metric_unit="MB", |
141 | 95 | **common_configs, |
|
0 commit comments