Skip to content

Commit 396a567

Browse files
[float8] add tests for float8 _auto_filter_for_recipe (#2450)
* add test for float8 _auto_filter_for_recipe * address comments * add to test_everything.sh
1 parent 589a93a commit 396a567

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

test/float8/test_auto_filter.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import pytest
2+
import torch.nn as nn
3+
4+
from torchao.float8 import _auto_filter_for_recipe
5+
from torchao.float8.float8_linear_utils import (
6+
_auto_filter_for_rowwise,
7+
_auto_filter_for_tensorwise,
8+
)
9+
10+
11+
@pytest.mark.parametrize(
12+
"recipe_type,module_dims,fqn,filter_fqns,expected",
13+
[
14+
# Tensorwise tests
15+
("tensorwise", (8192, 2048), "valid.layer", [], True),
16+
# FQN matches filter
17+
("tensorwise", (8192, 2048), "skip_layer.linear", ["skip_layer"], False),
18+
# Threshold fail
19+
("tensorwise", (4096, 1024), "valid.layer", [], False),
20+
# Rowwise tests
21+
("rowwise", (4096, 8192), "valid.layer", [], True),
22+
("rowwise", (4096, 8192), "skip_layer.linear", ["skip_layer"], False),
23+
# Combined threshold fail
24+
(
25+
"rowwise",
26+
(2048, 4096),
27+
"valid.layer",
28+
[],
29+
False,
30+
),
31+
],
32+
)
33+
def test_end_to_end_filtering(recipe_type, module_dims, fqn, filter_fqns, expected):
34+
"""Test complete filtering workflow for both recipe types."""
35+
in_features, out_features = module_dims
36+
37+
# Get the filter function
38+
filter_func = _auto_filter_for_recipe(recipe_type, filter_fqns)
39+
40+
# Create test module
41+
test_module = nn.Linear(in_features, out_features)
42+
43+
# Test filtering
44+
result = filter_func(test_module, fqn)
45+
assert result is expected
46+
47+
48+
def test_exact_boundary_dimensions_rowwise():
49+
"""Test exact boundary dimensions for rowwise filtering."""
50+
# Test exact thresholds
51+
module_n_2048 = nn.Linear(4096, 2048) # N exactly 2048
52+
assert _auto_filter_for_rowwise(module_n_2048, "layer", []) is False
53+
54+
module_k_1024 = nn.Linear(1024, 4112) # K exactly 1024
55+
assert _auto_filter_for_rowwise(module_k_1024, "layer", []) is False
56+
57+
58+
def test_exact_boundary_dimensions_tensorwise():
59+
"""Test exact boundary dimensions for tensorwise filtering."""
60+
# Test exact combined threshold
61+
module_boundary = nn.Linear(4096, 1024) # K=4096, N=1024
62+
assert _auto_filter_for_tensorwise(module_boundary, "layer", []) is False
63+
64+
65+
def test_partial_fqn_matching():
66+
"""Test partial FQN matching behavior."""
67+
filter_fqns = ["embed", "norm"]
68+
large_module = nn.Linear(8192, 4096)
69+
70+
# (fqn, expected result from filter func)
71+
test_cases = [
72+
("model.embeddings.linear", False), # Contains "embed"
73+
("layer.norm.weight", False), # Contains "norm"
74+
("model.transformer.layer", True), # Doesn't contain either
75+
("embedding_layer", False), # Contains "embed" as substring
76+
]
77+
78+
for fqn, expected_result in test_cases:
79+
result_tensorwise = _auto_filter_for_tensorwise(large_module, fqn, filter_fqns)
80+
result_rowwise = _auto_filter_for_rowwise(large_module, fqn, filter_fqns)
81+
assert result_tensorwise is expected_result, (
82+
f"Tensorwise result mismatch: fqn={fqn}, expected={expected_result}, actual={result_tensorwise}"
83+
)
84+
assert result_rowwise is expected_result, (
85+
f"Rowwise result mismatch: fqn={fqn}, expected={expected_result}, actual={result_rowwise}"
86+
)

test/float8/test_everything.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ IS_ROCM=$(rocm-smi --version || true)
1212
pytest test/float8/test_base.py
1313
pytest test/float8/test_compile.py
1414
pytest test/float8/test_numerics_integration.py
15+
pytest test/float8/test_auto_filter.py
1516

1617
# These tests do not work on ROCm yet
1718
if [ -z "$IS_ROCM" ]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
torchrun --nproc_per_node=2 -m pytest test/prototype/moe_training/test_tp.py

0 commit comments

Comments
 (0)