|
| 1 | +import pytest |
| 2 | +import torch.nn as nn |
| 3 | + |
| 4 | +from torchao.float8 import _auto_filter_for_recipe |
| 5 | +from torchao.float8.float8_linear_utils import ( |
| 6 | + _auto_filter_for_rowwise, |
| 7 | + _auto_filter_for_tensorwise, |
| 8 | +) |
| 9 | + |
| 10 | + |
| 11 | +@pytest.mark.parametrize( |
| 12 | + "recipe_type,module_dims,fqn,filter_fqns,expected", |
| 13 | + [ |
| 14 | + # Tensorwise tests |
| 15 | + ("tensorwise", (8192, 2048), "valid.layer", [], True), |
| 16 | + # FQN matches filter |
| 17 | + ("tensorwise", (8192, 2048), "skip_layer.linear", ["skip_layer"], False), |
| 18 | + # Threshold fail |
| 19 | + ("tensorwise", (4096, 1024), "valid.layer", [], False), |
| 20 | + # Rowwise tests |
| 21 | + ("rowwise", (4096, 8192), "valid.layer", [], True), |
| 22 | + ("rowwise", (4096, 8192), "skip_layer.linear", ["skip_layer"], False), |
| 23 | + # Combined threshold fail |
| 24 | + ( |
| 25 | + "rowwise", |
| 26 | + (2048, 4096), |
| 27 | + "valid.layer", |
| 28 | + [], |
| 29 | + False, |
| 30 | + ), |
| 31 | + ], |
| 32 | +) |
| 33 | +def test_end_to_end_filtering(recipe_type, module_dims, fqn, filter_fqns, expected): |
| 34 | + """Test complete filtering workflow for both recipe types.""" |
| 35 | + in_features, out_features = module_dims |
| 36 | + |
| 37 | + # Get the filter function |
| 38 | + filter_func = _auto_filter_for_recipe(recipe_type, filter_fqns) |
| 39 | + |
| 40 | + # Create test module |
| 41 | + test_module = nn.Linear(in_features, out_features) |
| 42 | + |
| 43 | + # Test filtering |
| 44 | + result = filter_func(test_module, fqn) |
| 45 | + assert result is expected |
| 46 | + |
| 47 | + |
| 48 | +def test_exact_boundary_dimensions_rowwise(): |
| 49 | + """Test exact boundary dimensions for rowwise filtering.""" |
| 50 | + # Test exact thresholds |
| 51 | + module_n_2048 = nn.Linear(4096, 2048) # N exactly 2048 |
| 52 | + assert _auto_filter_for_rowwise(module_n_2048, "layer", []) is False |
| 53 | + |
| 54 | + module_k_1024 = nn.Linear(1024, 4112) # K exactly 1024 |
| 55 | + assert _auto_filter_for_rowwise(module_k_1024, "layer", []) is False |
| 56 | + |
| 57 | + |
| 58 | +def test_exact_boundary_dimensions_tensorwise(): |
| 59 | + """Test exact boundary dimensions for tensorwise filtering.""" |
| 60 | + # Test exact combined threshold |
| 61 | + module_boundary = nn.Linear(4096, 1024) # K=4096, N=1024 |
| 62 | + assert _auto_filter_for_tensorwise(module_boundary, "layer", []) is False |
| 63 | + |
| 64 | + |
| 65 | +def test_partial_fqn_matching(): |
| 66 | + """Test partial FQN matching behavior.""" |
| 67 | + filter_fqns = ["embed", "norm"] |
| 68 | + large_module = nn.Linear(8192, 4096) |
| 69 | + |
| 70 | + # (fqn, expected result from filter func) |
| 71 | + test_cases = [ |
| 72 | + ("model.embeddings.linear", False), # Contains "embed" |
| 73 | + ("layer.norm.weight", False), # Contains "norm" |
| 74 | + ("model.transformer.layer", True), # Doesn't contain either |
| 75 | + ("embedding_layer", False), # Contains "embed" as substring |
| 76 | + ] |
| 77 | + |
| 78 | + for fqn, expected_result in test_cases: |
| 79 | + result_tensorwise = _auto_filter_for_tensorwise(large_module, fqn, filter_fqns) |
| 80 | + result_rowwise = _auto_filter_for_rowwise(large_module, fqn, filter_fqns) |
| 81 | + assert result_tensorwise is expected_result, ( |
| 82 | + f"Tensorwise result mismatch: fqn={fqn}, expected={expected_result}, actual={result_tensorwise}" |
| 83 | + ) |
| 84 | + assert result_rowwise is expected_result, ( |
| 85 | + f"Rowwise result mismatch: fqn={fqn}, expected={expected_result}, actual={result_rowwise}" |
| 86 | + ) |
0 commit comments