
Commit 4fc1bd3

add unit and gpu test for sparse attention
Signed-off-by: Kai Xu <[email protected]>
1 parent 3e5837c

File tree

11 files changed: +1901 -0 lines changed

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common utilities for sparse attention testing."""

import torch
import torch.nn as nn

import modelopt.torch.opt as mto
import modelopt.torch.sparsity.attention_sparsity as sparse_attn
from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule


# Test models for sparse attention
class SimpleAttentionModel(nn.Module):
    """Simple attention model for testing."""

    def __init__(self, hidden_size=256, num_heads=8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True
        )
        self.fc = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        attn_output, _ = self.attention(x, x, x, need_weights=False)
        return self.fc(attn_output)

    @classmethod
    def get_input(cls, hidden_size=256, seq_len=10, batch_size=2):
        """Get input tensor for testing."""
        return torch.randn(batch_size, seq_len, hidden_size)


class SimpleTransformerEncoderLayer(nn.Module):
    """Simple TransformerEncoderLayer wrapper for testing."""

    def __init__(self, d_model=128, nhead=4, dim_feedforward=256):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )

    def forward(self, x):
        return self.layer(x)

    @classmethod
    def get_input(cls, d_model=128, seq_len=20, batch_size=2):
        """Get input tensor for testing."""
        return torch.randn(batch_size, seq_len, d_model)


class SimpleTransformerEncoder(nn.Module):
    """Simple TransformerEncoder wrapper for testing."""

    def __init__(self, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers,
        )

    def forward(self, x):
        return self.encoder(x)

    @classmethod
    def get_input(cls, d_model=128, seq_len=10, batch_size=2):
        """Get input tensor for testing."""
        return torch.randn(batch_size, seq_len, d_model)
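

# Example (illustrative sketch, not part of this commit): each test model pairs
# with a `get_input` classmethod whose defaults mirror the constructor's, so a
# test can build matching calibration batches without hard-coding tensor shapes:
#
#     model = SimpleTransformerEncoder(d_model=128, nhead=4, num_layers=2)
#     calib_data = [SimpleTransformerEncoder.get_input(d_model=128) for _ in range(2)]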


# Test configurations
FLASH_SOFTMAX_SKIP_DEFAULT_CFG = {
    "method": "flash_softmax_skip",
    "sparse_cfg": {"*attention*": {"threshold": 1e-4, "br": 128, "bc": 128, "enable": True}},
}

FLASH_SOFTMAX_SKIP_PHASE_AWARE_CFG = {
    "method": "flash_softmax_skip",
    "sparse_cfg": {
        "*attention*": {
            "threshold": {"prefill": 1e-3, "decode": 1e-5},
            "br": 128,
            "bc": 128,
            "enable": True,
        }
    },
}

FLASH_SOFTMAX_SKIP_STATS_CFG = {
    "method": "flash_softmax_skip",
    "collect_stats": True,
    "sparse_cfg": {
        "*attention*": {
            "threshold": 1e-4,
            "br": 128,
            "bc": 128,
            "collect_stats": True,
            "enable": True,
        }
    },
}

FLASH_SOFTMAX_SKIP_CALIBRATION_CFG = {
    "method": "flash_softmax_skip",
    "sparse_cfg": {
        "*attention*": {
            "br": 128,
            "bc": 128,
            "enable": True,
            "calibration": {
                "target_sparse_ratio": 0.5,
                "samples": 6,
                "max_seqlen": 1024,
            },
        }
    },
}


def get_test_configs():
    """Get test configurations for parameterized tests.

    Note: Calibration config excluded (requires GPU and real tokenizers).
    """
    return [FLASH_SOFTMAX_SKIP_DEFAULT_CFG, FLASH_SOFTMAX_SKIP_PHASE_AWARE_CFG]
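

# Example (illustrative sketch, assuming pytest): get_test_configs() is meant to
# feed a parameterized test that exercises sparsify_model_and_forward below:
#
#     import pytest
#
#     @pytest.mark.parametrize("config", get_test_configs())
#     def test_sparse_attention_forward(config):
#         model = SimpleAttentionModel()
#         calib_data = [model.get_input() for _ in range(2)]
#         sparsify_model_and_forward(model, config, calib_data)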


def sparsify_model_and_forward(model, config, calib_data):
    """Apply sparse attention and run forward passes.

    Args:
        model: Model to sparsify.
        config: Sparse attention configuration.
        calib_data: List of calibration data tensors.

    Returns:
        The sparsified model.
    """

    def forward_loop(model):
        for batch in calib_data:
            model(batch)

    # Apply sparse attention
    model = sparse_attn.sparsify(model, config, forward_loop=forward_loop)

    # Verify sparse attention modules were inserted
    assert any(isinstance(m, SparseAttentionModule) for m in model.modules()), (
        "No sparse attention modules found"
    )

    # Test forward passes
    model.eval()
    with torch.no_grad():
        for batch in calib_data:
            output = model(batch)
            # Check for None first; torch.isnan would raise on a None output.
            assert output is not None, "Output is None"
            assert not torch.isnan(output).any(), "NaN in output"

    return model


def save_restore_test(model_cls, device, sparse_config):
    """Test save and restore of sparse attention state.

    Args:
        model_cls: Model class to test.
        device: Device to run on ('cpu' or 'cuda').
        sparse_config: Sparse attention configuration.
    """
    # Create and sparsify reference model
    model_sparse = model_cls().to(device)
    calib_data = [model_sparse.get_input().to(device) for _ in range(2)]

    sparsify_model_and_forward(model_sparse, sparse_config, calib_data)

    # Save state
    state_dict = mto.modelopt_state(model_sparse)

    # Restore to new model
    model_restored = model_cls().to(device)
    mto.restore_from_modelopt_state(model_restored, state_dict)
    model_restored.load_state_dict(model_sparse.state_dict())

    # Verify outputs match
    test_input = calib_data[0]
    model_sparse.eval()
    model_restored.eval()

    with torch.no_grad():
        output_sparse = model_sparse(test_input)
        output_restored = model_restored(test_input)

    assert torch.allclose(output_sparse, output_restored, atol=1e-6), (
        "Restored model output doesn't match original"
    )
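
For reference, a minimal sketch of how a GPU test might drive the save/restore helper. The test name, pytest markers, and CUDA skip condition are illustrative assumptions, not taken from this commit:

import pytest
import torch

# Assumes the helpers above (get_test_configs, save_restore_test, and the
# test models) are imported from the utilities module added in this commit.

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("config", get_test_configs())
@pytest.mark.parametrize(
    "model_cls", [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder]
)
def test_save_restore_gpu(model_cls, config):
    save_restore_test(model_cls, "cuda", config)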
