|
| 1 | +import math |
| 2 | + |
| 3 | +import torch |
| 4 | +import triton |
| 5 | + |
| 6 | +from utils import QUANTILES |
| 7 | +from utils import SingleBenchmarkRunInput |
| 8 | +from utils import SingleBenchmarkRunOutput |
| 9 | +from utils import _test_memory |
| 10 | +from utils import parse_benchmark_script_args |
| 11 | +from utils import run_benchmarks |
| 12 | + |
| 13 | +from liger_kernel.transformers.fused_neighborhood_attention import LigerFusedNeighborhoodAttention |
| 14 | +from liger_kernel.utils import infer_device |
| 15 | + |
| 16 | +device = infer_device() |
| 17 | + |
| 18 | + |
| 19 | +class TorchNeighborhoodAttention(torch.nn.Module): |
| 20 | + def __init__( |
| 21 | + self, |
| 22 | + hidden_size: int, |
| 23 | + num_heads: int, |
| 24 | + kernel_size: int = 7, |
| 25 | + dilation: int = 1, |
| 26 | + bias: bool = True, |
| 27 | + dropout: float = 0.0, |
| 28 | + scale: float = None, |
| 29 | + ): |
| 30 | + super().__init__() |
| 31 | + |
| 32 | + if hidden_size % num_heads != 0: |
| 33 | + raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})") |
| 34 | + |
| 35 | + self.hidden_size = hidden_size |
| 36 | + self.num_heads = num_heads |
| 37 | + self.head_dim = hidden_size // num_heads |
| 38 | + self.kernel_size = kernel_size |
| 39 | + self.dilation = dilation |
| 40 | + self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim) |
| 41 | + |
| 42 | + self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias) |
| 43 | + self.k_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias) |
| 44 | + self.v_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias) |
| 45 | + self.out_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias) |
| 46 | + |
| 47 | + if dropout > 0.0: |
| 48 | + self.dropout = torch.nn.Dropout(dropout) |
| 49 | + else: |
| 50 | + self.dropout = None |
| 51 | + |
| 52 | + def _create_neighborhood_mask(self, seq_len: int, device: torch.device) -> torch.Tensor: |
| 53 | + mask = torch.zeros(seq_len, seq_len, device=device, dtype=torch.bool) |
| 54 | + half_kernel = self.kernel_size // 2 |
| 55 | + |
| 56 | + for i in range(seq_len): |
| 57 | + start = max(0, i - half_kernel * self.dilation) |
| 58 | + end = min(seq_len, i + half_kernel * self.dilation + 1) |
| 59 | + |
| 60 | + for j in range(start, end): |
| 61 | + if self.dilation == 1 or (j - i) % self.dilation == 0: |
| 62 | + mask[i, j] = True |
| 63 | + |
| 64 | + return mask |
| 65 | + |
| 66 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 67 | + batch_size, seq_len, hidden_size = hidden_states.shape |
| 68 | + |
| 69 | + query = self.q_proj(hidden_states) |
| 70 | + key = self.k_proj(hidden_states) |
| 71 | + value = self.v_proj(hidden_states) |
| 72 | + |
| 73 | + query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 74 | + key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 75 | + value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 76 | + |
| 77 | + scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale |
| 78 | + |
| 79 | + mask = self._create_neighborhood_mask(seq_len, hidden_states.device) |
| 80 | + scores = scores.masked_fill(~mask, float("-inf")) |
| 81 | + |
| 82 | + attn_weights = torch.softmax(scores, dim=-1) |
| 83 | + |
| 84 | + if self.dropout is not None: |
| 85 | + attn_weights = self.dropout(attn_weights) |
| 86 | + |
| 87 | + attn_output = torch.matmul(attn_weights, value) |
| 88 | + |
| 89 | + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size) |
| 90 | + |
| 91 | + output = self.out_proj(attn_output) |
| 92 | + |
| 93 | + return output |
| 94 | + |
| 95 | + |
| 96 | +def bench_speed_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: |
| 97 | + seq_len = input.x |
| 98 | + provider = input.kernel_provider |
| 99 | + mode = input.kernel_operation_mode |
| 100 | + |
| 101 | + extra_benchmark_config = input.extra_benchmark_config |
| 102 | + batch_size = extra_benchmark_config["batch_size"] |
| 103 | + hidden_size = extra_benchmark_config["hidden_size"] |
| 104 | + num_heads = extra_benchmark_config["num_heads"] |
| 105 | + kernel_size = extra_benchmark_config["kernel_size"] |
| 106 | + dilation = extra_benchmark_config["dilation"] |
| 107 | + bias = extra_benchmark_config["bias"] |
| 108 | + dtype = extra_benchmark_config["dtype"] |
| 109 | + |
| 110 | + x_shape = (batch_size, seq_len, hidden_size) |
| 111 | + |
| 112 | + liger_attn = ( |
| 113 | + LigerFusedNeighborhoodAttention( |
| 114 | + hidden_size=hidden_size, |
| 115 | + num_heads=num_heads, |
| 116 | + kernel_size=kernel_size, |
| 117 | + dilation=dilation, |
| 118 | + bias=bias, |
| 119 | + dropout=0.0, |
| 120 | + ) |
| 121 | + .to(device) |
| 122 | + .to(dtype) |
| 123 | + ) |
| 124 | + |
| 125 | + torch_attn = ( |
| 126 | + TorchNeighborhoodAttention( |
| 127 | + hidden_size=hidden_size, |
| 128 | + num_heads=num_heads, |
| 129 | + kernel_size=kernel_size, |
| 130 | + dilation=dilation, |
| 131 | + bias=bias, |
| 132 | + dropout=0.0, |
| 133 | + ) |
| 134 | + .to(device) |
| 135 | + .to(dtype) |
| 136 | + ) |
| 137 | + |
| 138 | + with torch.no_grad(): |
| 139 | + torch_attn.q_proj.weight.copy_(liger_attn.q_proj.weight) |
| 140 | + torch_attn.k_proj.weight.copy_(liger_attn.k_proj.weight) |
| 141 | + torch_attn.v_proj.weight.copy_(liger_attn.v_proj.weight) |
| 142 | + torch_attn.out_proj.weight.copy_(liger_attn.out_proj.weight) |
| 143 | + |
| 144 | + if bias: |
| 145 | + torch_attn.q_proj.bias.copy_(liger_attn.q_proj.bias) |
| 146 | + torch_attn.k_proj.bias.copy_(liger_attn.k_proj.bias) |
| 147 | + torch_attn.v_proj.bias.copy_(liger_attn.v_proj.bias) |
| 148 | + torch_attn.out_proj.bias.copy_(liger_attn.out_proj.bias) |
| 149 | + |
| 150 | + x = torch.randn(x_shape, dtype=dtype, device=device) |
| 151 | + dy = torch.randn_like(x) |
| 152 | + x.requires_grad_(True) |
| 153 | + |
| 154 | + def fwd(): |
| 155 | + if provider == "liger": |
| 156 | + return liger_attn(x) |
| 157 | + elif provider == "torch": |
| 158 | + return torch_attn(x) |
| 159 | + |
| 160 | + print(f"Starting Warmup for input size: {x_shape}") |
| 161 | + _ = fwd() |
| 162 | + if mode in ("backward", "full"): |
| 163 | + y = _ |
| 164 | + y.backward(dy, retain_graph=True) |
| 165 | + print("Done Warmup") |
| 166 | + |
| 167 | + if mode == "forward": |
| 168 | + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES) |
| 169 | + elif mode == "backward": |
| 170 | + y = fwd() |
| 171 | + ms_50, ms_20, ms_80 = triton.testing.do_bench( |
| 172 | + lambda: y.backward(dy, retain_graph=True), |
| 173 | + grad_to_none=[x], |
| 174 | + rep=100, |
| 175 | + quantiles=QUANTILES, |
| 176 | + ) |
| 177 | + elif mode == "full": |
| 178 | + |
| 179 | + def full(): |
| 180 | + y = fwd() |
| 181 | + y.backward(dy, retain_graph=True) |
| 182 | + |
| 183 | + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES) |
| 184 | + |
| 185 | + return SingleBenchmarkRunOutput( |
| 186 | + y_20=ms_20, |
| 187 | + y_50=ms_50, |
| 188 | + y_80=ms_80, |
| 189 | + ) |
| 190 | + |
| 191 | + |
| 192 | +def bench_memory_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: |
| 193 | + seq_len = input.x |
| 194 | + provider = input.kernel_provider |
| 195 | + |
| 196 | + extra_benchmark_config = input.extra_benchmark_config |
| 197 | + batch_size = extra_benchmark_config["batch_size"] |
| 198 | + hidden_size = extra_benchmark_config["hidden_size"] |
| 199 | + num_heads = extra_benchmark_config["num_heads"] |
| 200 | + kernel_size = extra_benchmark_config["kernel_size"] |
| 201 | + dilation = extra_benchmark_config["dilation"] |
| 202 | + bias = extra_benchmark_config["bias"] |
| 203 | + dtype = extra_benchmark_config["dtype"] |
| 204 | + |
| 205 | + x_shape = (batch_size, seq_len, hidden_size) |
| 206 | + |
| 207 | + liger_attn = ( |
| 208 | + LigerFusedNeighborhoodAttention( |
| 209 | + hidden_size=hidden_size, |
| 210 | + num_heads=num_heads, |
| 211 | + kernel_size=kernel_size, |
| 212 | + dilation=dilation, |
| 213 | + bias=bias, |
| 214 | + dropout=0.0, |
| 215 | + ) |
| 216 | + .to(device) |
| 217 | + .to(dtype) |
| 218 | + ) |
| 219 | + |
| 220 | + torch_attn = ( |
| 221 | + TorchNeighborhoodAttention( |
| 222 | + hidden_size=hidden_size, |
| 223 | + num_heads=num_heads, |
| 224 | + kernel_size=kernel_size, |
| 225 | + dilation=dilation, |
| 226 | + bias=bias, |
| 227 | + dropout=0.0, |
| 228 | + ) |
| 229 | + .to(device) |
| 230 | + .to(dtype) |
| 231 | + ) |
| 232 | + |
| 233 | + with torch.no_grad(): |
| 234 | + torch_attn.q_proj.weight.copy_(liger_attn.q_proj.weight) |
| 235 | + torch_attn.k_proj.weight.copy_(liger_attn.k_proj.weight) |
| 236 | + torch_attn.v_proj.weight.copy_(liger_attn.v_proj.weight) |
| 237 | + torch_attn.out_proj.weight.copy_(liger_attn.out_proj.weight) |
| 238 | + |
| 239 | + if bias: |
| 240 | + torch_attn.q_proj.bias.copy_(liger_attn.q_proj.bias) |
| 241 | + torch_attn.k_proj.bias.copy_(liger_attn.k_proj.bias) |
| 242 | + torch_attn.v_proj.bias.copy_(liger_attn.v_proj.bias) |
| 243 | + torch_attn.out_proj.bias.copy_(liger_attn.out_proj.bias) |
| 244 | + |
| 245 | + x = torch.randn(x_shape, dtype=dtype, device=device) |
| 246 | + dy = torch.randn_like(x) |
| 247 | + x.requires_grad_(True) |
| 248 | + |
| 249 | + def fwd(): |
| 250 | + if provider == "liger": |
| 251 | + return liger_attn(x) |
| 252 | + elif provider == "torch": |
| 253 | + return torch_attn(x) |
| 254 | + |
| 255 | + def full(): |
| 256 | + y = fwd() |
| 257 | + y.backward(dy, retain_graph=True) |
| 258 | + |
| 259 | + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) |
| 260 | + |
| 261 | + return SingleBenchmarkRunOutput( |
| 262 | + y_20=mem_20, |
| 263 | + y_50=mem_50, |
| 264 | + y_80=mem_80, |
| 265 | + ) |
| 266 | + |
| 267 | + |
| 268 | +if __name__ == "__main__": |
| 269 | + args = parse_benchmark_script_args() |
| 270 | + |
| 271 | + common_configs = { |
| 272 | + "kernel_name": "fused_neighborhood_attention", |
| 273 | + "x_name": "seq_len", |
| 274 | + "x_label": "sequence length", |
| 275 | + "x_values": [2**i for i in range(6, 13)], |
| 276 | + "kernel_providers": ["liger", "torch"], |
| 277 | + "extra_benchmark_configs": [ |
| 278 | + { |
| 279 | + "batch_size": 2, |
| 280 | + "hidden_size": 512, |
| 281 | + "num_heads": 8, |
| 282 | + "kernel_size": 7, |
| 283 | + "dilation": 1, |
| 284 | + "bias": True, |
| 285 | + "dtype": torch.float32, |
| 286 | + }, |
| 287 | + { |
| 288 | + "batch_size": 4, |
| 289 | + "hidden_size": 768, |
| 290 | + "num_heads": 12, |
| 291 | + "kernel_size": 7, |
| 292 | + "dilation": 1, |
| 293 | + "bias": True, |
| 294 | + "dtype": torch.float32, |
| 295 | + }, |
| 296 | + { |
| 297 | + "batch_size": 2, |
| 298 | + "hidden_size": 1024, |
| 299 | + "num_heads": 16, |
| 300 | + "kernel_size": 9, |
| 301 | + "dilation": 1, |
| 302 | + "bias": True, |
| 303 | + "dtype": torch.float32, |
| 304 | + }, |
| 305 | + { |
| 306 | + "batch_size": 2, |
| 307 | + "hidden_size": 512, |
| 308 | + "num_heads": 8, |
| 309 | + "kernel_size": 7, |
| 310 | + "dilation": 2, |
| 311 | + "bias": True, |
| 312 | + "dtype": torch.float32, |
| 313 | + }, |
| 314 | + { |
| 315 | + "batch_size": 2, |
| 316 | + "hidden_size": 512, |
| 317 | + "num_heads": 8, |
| 318 | + "kernel_size": 7, |
| 319 | + "dilation": 1, |
| 320 | + "bias": True, |
| 321 | + "dtype": torch.bfloat16, |
| 322 | + }, |
| 323 | + { |
| 324 | + "batch_size": 4, |
| 325 | + "hidden_size": 768, |
| 326 | + "num_heads": 12, |
| 327 | + "kernel_size": 7, |
| 328 | + "dilation": 1, |
| 329 | + "bias": True, |
| 330 | + "dtype": torch.bfloat16, |
| 331 | + }, |
| 332 | + { |
| 333 | + "batch_size": 2, |
| 334 | + "hidden_size": 1024, |
| 335 | + "num_heads": 16, |
| 336 | + "kernel_size": 9, |
| 337 | + "dilation": 1, |
| 338 | + "bias": True, |
| 339 | + "dtype": torch.bfloat16, |
| 340 | + }, |
| 341 | + { |
| 342 | + "batch_size": 2, |
| 343 | + "hidden_size": 512, |
| 344 | + "num_heads": 8, |
| 345 | + "kernel_size": 7, |
| 346 | + "dilation": 2, |
| 347 | + "bias": True, |
| 348 | + "dtype": torch.bfloat16, |
| 349 | + }, |
| 350 | + ], |
| 351 | + } |
| 352 | + |
| 353 | + run_benchmarks( |
| 354 | + bench_test_fn=bench_speed_fused_neighborhood_attention, |
| 355 | + kernel_operation_modes=["forward", "full", "backward"], |
| 356 | + metric_name="speed", |
| 357 | + metric_unit="ms", |
| 358 | + **common_configs, |
| 359 | + ) |
| 360 | + |
| 361 | + run_benchmarks( |
| 362 | + bench_test_fn=bench_memory_fused_neighborhood_attention, |
| 363 | + kernel_operation_modes=["full"], |
| 364 | + metric_name="memory", |
| 365 | + metric_unit="MB", |
| 366 | + **common_configs, |
| 367 | + ) |
0 commit comments