|
30 | 30 | stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] |
31 | 31 | start_step : 0 |
32 | 32 | end_step: 1 |
| 33 | +""", |
| 34 | + "log_fp8": """log_fp8: |
| 35 | + layers: |
| 36 | + layer_types: [linear] |
| 37 | + enabled: |
| 38 | + True |
| 39 | + transformer_engine: |
33 | 40 | LogFp8TensorStats: |
34 | 41 | enabled: True |
35 | 42 | tensors: [activation, gradient, weight] |
|
46 | 53 | FakeQuant: |
47 | 54 | enabled: True |
48 | 55 | gemms: [fprop, dgrad, wgrad] |
| 56 | + tensors: [activation, weight, gradient] |
49 | 57 | quant_format: FP8E5M2 |
50 | 58 | """, |
51 | 59 | } |
52 | 60 |
|
| 61 | +# Configs that require FP8 to be enabled |
| 62 | +fp8_required_configs = {"log_fp8"} |
| 63 | + |
53 | 64 |
|
54 | 65 | def _get_model(model_key): |
55 | 66 | if model_key == "linear": |
56 | | - return te.Linear(D, D) |
| 67 | + return te.Linear(D, D, name="layer") |
57 | 68 | if model_key == "layernorm_linear": |
58 | | - return te.LayerNormLinear(D, D) |
| 69 | + return te.LayerNormLinear(D, D, name="layer") |
59 | 70 | if model_key == "layernorm_mlp": |
60 | | - return te.LayerNormMLP(D, D, D) |
| 71 | + return te.LayerNormMLP(D, D, D, name="layer") |
61 | 72 | if model_key == "mha_attention": |
62 | | - return te.MultiheadAttention(D, H) |
| 73 | + return te.MultiheadAttention(D, H, name="layer") |
63 | 74 | if model_key == "transformer_layer": |
64 | | - return te.TransformerLayer(D, D, H) |
| 75 | + return te.TransformerLayer(D, D, H, name="layer") |
65 | 76 |
|
66 | 77 |
|
67 | 78 | def _run_forward_backward(model, fp8): |
@@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir): |
95 | 106 | def test_sanity_debug(model_key, fp8, config_key, feature_dirs): |
96 | 107 | if fp8 and not fp8_available: |
97 | 108 | pytest.skip(reason_for_no_fp8) |
| 109 | + if not fp8 and config_key in fp8_required_configs: |
| 110 | + pytest.skip(f"Config '{config_key}' requires FP8") |
98 | 111 | _run_test(model_key, fp8, configs[config_key], feature_dirs) |
0 commit comments