Commit 02182f8

Address review feedback
Signed-off-by: Kai Xu <[email protected]>
1 parent 1910fc6 commit 02182f8

8 files changed: +6 −224 lines

examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 2 additions & 2 deletions
@@ -295,8 +295,8 @@ def main(args):
         "--backend",
         type=str,
         default="pytorch",
-        choices=["pytorch", "triton"],
-        help="Backend to use for sparse attention computation (default: pytorch)",
+        choices=["pytorch"],
+        help="Backend for sparse attention (default: pytorch). More backends coming soon.",
     )
 
     # Sequence length arguments
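A quick illustration of what the narrowed choice means for callers of hf_sa.py; the parser below is a standalone sketch for illustration, not the example script itself:

import argparse

# Standalone sketch of the narrowed option: only the "pytorch" backend is
# accepted now, so "--backend triton" fails argument parsing instead of
# reaching the example code.
parser = argparse.ArgumentParser(description="illustrative parser only")
parser.add_argument(
    "--backend",
    type=str,
    default="pytorch",
    choices=["pytorch"],
    help="Backend for sparse attention (default: pytorch). More backends coming soon.",
)

args = parser.parse_args(["--backend", "pytorch"])
print(args.backend)  # -> pytorch
# parser.parse_args(["--backend", "triton"])  # would now exit: invalid choice: 'triton'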

modelopt/torch/sparsity/attention_sparsity/conversion.py

Lines changed: 0 additions & 61 deletions
@@ -299,64 +299,3 @@ def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Cal
 
         if matched:
             module.enable()
-
-
-def print_sparse_attention_summary(model: nn.Module):
-    """Print summary of sparse attention modules in the model.
-
-    Similar to mtq.print_quant_summary for API consistency.
-
-    Args:
-        model: Model with sparse attention applied
-
-    Prints:
-        - Total sparse attention modules
-        - Enabled vs disabled count
-        - Method distribution
-        - Configuration summary by module
-
-    Example:
-        >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn
-        >>> model = sparse_attn.sparsify(model, config)
-        >>> sparse_attn.print_sparse_attention_summary(model)
-    """
-    sparse_modules = []
-    for name, module in model.named_modules():
-        if isinstance(module, SparseAttentionModule):
-            sparse_modules.append((name, module))
-
-    if not sparse_modules:
-        print("No sparse attention modules found in model")
-        return
-
-    enabled_count = sum(1 for _, m in sparse_modules if m.is_enabled)
-    disabled_count = len(sparse_modules) - enabled_count
-
-    # Count methods
-    method_counts = {}
-    for _, module in sparse_modules:
-        method = getattr(module, "_method", "unknown")
-        method_counts[method] = method_counts.get(method, 0) + 1
-
-    print(f"Total sparse attention modules: {len(sparse_modules)}")
-    print(f"Enabled: {enabled_count}")
-    print(f"Disabled: {disabled_count}")
-
-    if method_counts:
-        print("\nMethods:")
-        for method, count in sorted(method_counts.items()):
-            print(f"{method}: {count}")
-
-    for name, module in sparse_modules:
-        method = getattr(module, "_method", "unknown")
-        threshold = getattr(module, "_threshold", "N/A")
-
-        # Format threshold nicely
-        if isinstance(threshold, dict):
-            threshold_str = str(threshold)
-        elif isinstance(threshold, float):
-            threshold_str = f"{threshold:.2e}"
-        else:
-            threshold_str = str(threshold)
-
-        print(f"{name}: Method: {method}, Threshold: {threshold_str}")

modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py

Lines changed: 2 additions & 8 deletions
@@ -15,17 +15,13 @@
 
 """Dynamic sparse attention registration for HuggingFace models."""
 
-import logging
-
 import torch.nn as nn
 import transformers
 
 from modelopt.torch.opt.dynamic import DynamicModule
 
 from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry
 
-logger = logging.getLogger(__name__)
-
 
 class _GenericSparseAttention(SparseAttentionModule):
     """Generic sparse attention that works with any HF attention module.
@@ -94,12 +90,10 @@ def register_sparse_attention_on_the_fly(model: nn.Module) -> bool:
             SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention)
             attention_types.add(module_type)
             registered_count += 1
-            logger.info(f"Registered {type_name} for sparse attention optimization")
+            print(f"Registered {type_name} for sparse attention optimization")
 
     if registered_count > 0:
-        logger.info(
-            f"Dynamically registered {registered_count} attention module types for sparsity"
-        )
+        print(f"Dynamically registered {registered_count} attention module types for sparsity")
 
     return registered_count > 0
 
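For context, a hedged usage sketch of the plugin after the logging change; the tiny LLaMA model below is purely illustrative, and the printed lines correspond to the print calls introduced above:

import transformers

from modelopt.torch.sparsity.attention_sparsity.plugins.huggingface import (
    register_sparse_attention_on_the_fly,
)

# Illustrative tiny model; any HF model whose attention classes are not yet in
# SparseAttentionRegistry would exercise the same code path.
config = transformers.LlamaConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4
)
model = transformers.LlamaForCausalLM(config)

# Prints one line per newly registered attention class, e.g.
#   Registered LlamaAttention for sparse attention optimization
#   Dynamically registered 1 attention module types for sparsity
registered_any = register_sparse_attention_on_the_fly(model)
print(registered_any)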

tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py

Lines changed: 0 additions & 2 deletions
@@ -17,7 +17,6 @@
 
 import pytest
 from _test_utils.examples.run_command import extend_cmd_parts, run_example_command
-from _test_utils.torch.misc import minimum_gpu
 
 
 def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", **kwargs):
@@ -42,7 +41,6 @@ def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax",
     run_example_command(cmd_parts, "llm_sparsity/attention_sparsity")
 
 
-@minimum_gpu(1)
 @pytest.mark.parametrize("method", ["skip_softmax"])
 def test_attention_sparsity(tiny_llama_path, tmp_path, method):
     """Test sparse attention with TinyLlama."""

tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py

Lines changed: 0 additions & 3 deletions
@@ -29,9 +29,6 @@
 
 import modelopt.torch.sparsity.attention_sparsity as sparse_attn
 
-# Skip all tests if GPU is not available
-pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
-
 
 class TestAttentionSparsityGPU:
     """GPU tests for attention sparsity."""

tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py

Lines changed: 2 additions & 5 deletions
@@ -24,9 +24,6 @@
 from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
 from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule
 
-# Skip all tests if GPU is not available
-pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
-
 
 @pytest.fixture(scope="module")
 def tiny_llama_dir(tmp_path_factory):
@@ -35,8 +32,8 @@ def tiny_llama_dir(tmp_path_factory):
         tmp_path_factory.mktemp("tiny_llama"),
         with_tokenizer=True,
         num_hidden_layers=2,  # Minimal layers for fast testing
-        hidden_size=512,
-        intermediate_size=1024,
+        hidden_size=32,
+        intermediate_size=64,
     )
 
 
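For a sense of how small the shrunken fixture model is, a rough standalone equivalent built with transformers; the head counts and vocab size are assumptions, since the fixture's own helper is not part of this diff:

import transformers

# Rough equivalent of the shrunken fixture (2 layers, hidden_size=32,
# intermediate_size=64); head counts and vocab size are assumed values.
tiny_config = transformers.LlamaConfig(
    num_hidden_layers=2,
    hidden_size=32,
    intermediate_size=64,
    num_attention_heads=4,
    num_key_value_heads=4,
    vocab_size=32000,
)
tiny_model = transformers.LlamaForCausalLM(tiny_config)
# Roughly two million parameters under these assumptions, dominated by the
# embeddings, so the integration test loads and runs quickly.
print(sum(p.numel() for p in tiny_model.parameters()))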

tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py

Lines changed: 0 additions & 129 deletions
This file was deleted.

tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py

Lines changed: 0 additions & 14 deletions
@@ -31,7 +31,6 @@
 from modelopt.torch.sparsity.attention_sparsity.conversion import (
     disable_sparse_attention,
     enable_sparse_attention,
-    print_sparse_attention_summary,
 )
 from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule
 
@@ -171,19 +170,6 @@ def test_disable_enable_functions(self):
             if isinstance(module, SparseAttentionModule):
                 assert module.is_enabled
 
-    def test_print_sparse_attention_summary(self, capsys):
-        """Test print_sparse_attention_summary function."""
-        model = SimpleAttentionModel()
-        model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG)
-
-        # Print summary
-        print_sparse_attention_summary(model)
-
-        # Capture output
-        captured = capsys.readouterr()
-        assert "Total sparse attention modules:" in captured.out
-        assert "Enabled:" in captured.out
-
     def test_restore_sparse_attention_model(self):
         """Test save/restore via modelopt_state."""
         # Create and sparsify original model
