Commit a53675d
[Operator] Fused Neighborhood Attention (#732)
## Summary

Adds a fused neighborhood attention operator (#733).

## Testing Done

Tested the attention layer and attention module implementations for FusedNeighborhoodAttention.

- Hardware Type: 3090 & H100 SXM5
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
1 parent 4d11fee commit a53675d
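
For context, below is a minimal usage sketch of the new `LigerFusedNeighborhoodAttention` module. It is not part of this commit: the constructor arguments (`hidden_size`, `num_heads`, `kernel_size`, `dilation`, `bias`, `dropout`) and the `(batch, seq_len, hidden_size)` input/output shape are taken from the benchmark script in this diff, while the concrete sizes, dtype, and device handling are illustrative assumptions.

```python
# Minimal usage sketch (not part of this commit). Constructor arguments mirror
# those exercised by the benchmark script below; other details are assumptions.
import torch

from liger_kernel.transformers.fused_neighborhood_attention import LigerFusedNeighborhoodAttention
from liger_kernel.utils import infer_device

device = infer_device()

attn = (
    LigerFusedNeighborhoodAttention(
        hidden_size=512,   # divisible by num_heads, as the reference implementation requires
        num_heads=8,
        kernel_size=7,     # size of the local attention neighborhood
        dilation=1,
        bias=True,
        dropout=0.0,
    )
    .to(device)
    .to(torch.bfloat16)
)

x = torch.randn(2, 256, 512, device=device, dtype=torch.bfloat16, requires_grad=True)
out = attn(x)                         # output keeps the (batch, seq_len, hidden_size) shape
out.backward(torch.randn_like(out))   # backward is exercised by the benchmark's "backward"/"full" modes
```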

File tree

6 files changed: +2671 -0 lines changed


benchmark/data/all_benchmark_data.csv

Lines changed: 448 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 367 additions & 0 deletions (new benchmark script for the fused neighborhood attention operator; filename not rendered on this page)
@@ -0,0 +1,367 @@
```python
import math

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.fused_neighborhood_attention import LigerFusedNeighborhoodAttention
from liger_kernel.utils import infer_device

device = infer_device()


class TorchNeighborhoodAttention(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        kernel_size: int = 7,
        dilation: int = 1,
        bias: bool = True,
        dropout: float = 0.0,
        scale: float = None,
    ):
        super().__init__()

        if hidden_size % num_heads != 0:
            raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim)

        self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias)
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias)
        self.out_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias)

        if dropout > 0.0:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = None

    def _create_neighborhood_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        mask = torch.zeros(seq_len, seq_len, device=device, dtype=torch.bool)
        half_kernel = self.kernel_size // 2

        for i in range(seq_len):
            start = max(0, i - half_kernel * self.dilation)
            end = min(seq_len, i + half_kernel * self.dilation + 1)

            for j in range(start, end):
                if self.dilation == 1 or (j - i) % self.dilation == 0:
                    mask[i, j] = True

        return mask

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, hidden_size = hidden_states.shape

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale

        mask = self._create_neighborhood_mask(seq_len, hidden_states.device)
        scores = scores.masked_fill(~mask, float("-inf"))

        attn_weights = torch.softmax(scores, dim=-1)

        if self.dropout is not None:
            attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)

        output = self.out_proj(attn_output)

        return output


def bench_speed_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    seq_len = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    batch_size = extra_benchmark_config["batch_size"]
    hidden_size = extra_benchmark_config["hidden_size"]
    num_heads = extra_benchmark_config["num_heads"]
    kernel_size = extra_benchmark_config["kernel_size"]
    dilation = extra_benchmark_config["dilation"]
    bias = extra_benchmark_config["bias"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (batch_size, seq_len, hidden_size)

    liger_attn = (
        LigerFusedNeighborhoodAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
            dropout=0.0,
        )
        .to(device)
        .to(dtype)
    )

    torch_attn = (
        TorchNeighborhoodAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
            dropout=0.0,
        )
        .to(device)
        .to(dtype)
    )

    with torch.no_grad():
        torch_attn.q_proj.weight.copy_(liger_attn.q_proj.weight)
        torch_attn.k_proj.weight.copy_(liger_attn.k_proj.weight)
        torch_attn.v_proj.weight.copy_(liger_attn.v_proj.weight)
        torch_attn.out_proj.weight.copy_(liger_attn.out_proj.weight)

        if bias:
            torch_attn.q_proj.bias.copy_(liger_attn.q_proj.bias)
            torch_attn.k_proj.bias.copy_(liger_attn.k_proj.bias)
            torch_attn.v_proj.bias.copy_(liger_attn.v_proj.bias)
            torch_attn.out_proj.bias.copy_(liger_attn.out_proj.bias)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        if provider == "liger":
            return liger_attn(x)
        elif provider == "torch":
            return torch_attn(x)

    print(f"Starting Warmup for input size: {x_shape}")
    _ = fwd()
    if mode in ("backward", "full"):
        y = _
        y.backward(dy, retain_graph=True)
    print("Done Warmup")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES)
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            grad_to_none=[x],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    seq_len = input.x
    provider = input.kernel_provider

    extra_benchmark_config = input.extra_benchmark_config
    batch_size = extra_benchmark_config["batch_size"]
    hidden_size = extra_benchmark_config["hidden_size"]
    num_heads = extra_benchmark_config["num_heads"]
    kernel_size = extra_benchmark_config["kernel_size"]
    dilation = extra_benchmark_config["dilation"]
    bias = extra_benchmark_config["bias"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (batch_size, seq_len, hidden_size)

    liger_attn = (
        LigerFusedNeighborhoodAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
            dropout=0.0,
        )
        .to(device)
        .to(dtype)
    )

    torch_attn = (
        TorchNeighborhoodAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
            dropout=0.0,
        )
        .to(device)
        .to(dtype)
    )

    with torch.no_grad():
        torch_attn.q_proj.weight.copy_(liger_attn.q_proj.weight)
        torch_attn.k_proj.weight.copy_(liger_attn.k_proj.weight)
        torch_attn.v_proj.weight.copy_(liger_attn.v_proj.weight)
        torch_attn.out_proj.weight.copy_(liger_attn.out_proj.weight)

        if bias:
            torch_attn.q_proj.bias.copy_(liger_attn.q_proj.bias)
            torch_attn.k_proj.bias.copy_(liger_attn.k_proj.bias)
            torch_attn.v_proj.bias.copy_(liger_attn.v_proj.bias)
            torch_attn.out_proj.bias.copy_(liger_attn.out_proj.bias)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        if provider == "liger":
            return liger_attn(x)
        elif provider == "torch":
            return torch_attn(x)

    def full():
        y = fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "fused_neighborhood_attention",
        "x_name": "seq_len",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(6, 13)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [
            {
                "batch_size": 2,
                "hidden_size": 512,
                "num_heads": 8,
                "kernel_size": 7,
                "dilation": 1,
                "bias": True,
                "dtype": torch.float32,
            },
            {
                "batch_size": 4,
                "hidden_size": 768,
                "num_heads": 12,
                "kernel_size": 7,
                "dilation": 1,
                "bias": True,
                "dtype": torch.float32,
            },
            {
                "batch_size": 2,
                "hidden_size": 1024,
                "num_heads": 16,
                "kernel_size": 9,
                "dilation": 1,
                "bias": True,
                "dtype": torch.float32,
            },
            {
                "batch_size": 2,
                "hidden_size": 512,
                "num_heads": 8,
                "kernel_size": 7,
                "dilation": 2,
                "bias": True,
                "dtype": torch.float32,
            },
            {
                "batch_size": 2,
                "hidden_size": 512,
                "num_heads": 8,
                "kernel_size": 7,
                "dilation": 1,
                "bias": True,
                "dtype": torch.bfloat16,
            },
            {
                "batch_size": 4,
                "hidden_size": 768,
                "num_heads": 12,
                "kernel_size": 7,
                "dilation": 1,
                "bias": True,
                "dtype": torch.bfloat16,
            },
            {
                "batch_size": 2,
                "hidden_size": 1024,
                "num_heads": 16,
                "kernel_size": 9,
                "dilation": 1,
                "bias": True,
                "dtype": torch.bfloat16,
            },
            {
                "batch_size": 2,
                "hidden_size": 512,
                "num_heads": 8,
                "kernel_size": 7,
                "dilation": 2,
                "bias": True,
                "dtype": torch.bfloat16,
            },
        ],
    }

    run_benchmarks(
        bench_test_fn=bench_speed_fused_neighborhood_attention,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )

    run_benchmarks(
        bench_test_fn=bench_memory_fused_neighborhood_attention,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
```
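
For readers who want to sanity-check the fused kernel against the PyTorch reference outside the benchmark harness, here is a small cross-check sketch. It reuses only what the script above establishes (the two classes, their shared constructor arguments, and the weight-copy pattern); it assumes `TorchNeighborhoodAttention` is in scope (e.g. run from within the benchmark script), and the tolerances are illustrative assumptions, not the values used by `make test`.

```python
# Hypothetical numerical cross-check, mirroring the weight-copy pattern above.
# Assumes TorchNeighborhoodAttention (defined in the benchmark script) is in scope.
import torch

from liger_kernel.transformers.fused_neighborhood_attention import LigerFusedNeighborhoodAttention
from liger_kernel.utils import infer_device

device = infer_device()
cfg = dict(hidden_size=512, num_heads=8, kernel_size=7, dilation=1, bias=True, dropout=0.0)

liger_attn = LigerFusedNeighborhoodAttention(**cfg).to(device)
torch_attn = TorchNeighborhoodAttention(**cfg).to(device)

# Copy the projection weights and biases so both modules compute the same function.
with torch.no_grad():
    for name in ("q_proj", "k_proj", "v_proj", "out_proj"):
        getattr(torch_attn, name).weight.copy_(getattr(liger_attn, name).weight)
        getattr(torch_attn, name).bias.copy_(getattr(liger_attn, name).bias)

x = torch.randn(2, 128, 512, device=device)
torch.testing.assert_close(liger_attn(x), torch_attn(x), rtol=1e-3, atol=1e-3)  # tolerances are illustrative
```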
