
Conversation

@vkuzo (Contributor) commented Aug 6, 2025

Summary:

A lightweight logging framework for Float8Linear.

Two-step usage:

  1. set the `_debug_logging_fn` function on `Float8LinearConfig` to a user-defined function whose inputs are the fqn, the gemm name, and `a_hp`, `b_hp`, `a_fp8`, and `b_fp8`. This is generic enough to allow logging comparisons between pairs of gemm inputs, comparisons between gemm outputs, gemm performance, etc.
  2. after the model is converted to float8, call `_populate_debug_fqns` to populate the debug FQN names

If both steps are done, `_debug_logging_fn` will be called for each of the 3 gemms in every forward and backward pass.
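The callback described in step 1 can be sketched as follows (the body here is illustrative only, not part of this PR; a full example follows below):

```python
import torch

# Sketch of a user-defined _debug_logging_fn. The signature (fqn, gemm_name,
# a_hp, b_hp, a_fp8, b_fp8) comes from step 1 above; the body is a made-up
# example that compares the high-precision and float8 gemm outputs.
def debug_logging_fn(fqn, gemm_name, a_hp, b_hp, a_fp8, b_fp8):
    # mean absolute difference between the two gemm results
    err = torch.mean(torch.abs(a_hp @ b_hp - a_fp8 @ b_fp8))
    print(f"fqn={fqn} gemm={gemm_name} mean_abs_err={err.item():.4f}")
    return err
```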

example usage:

        x = torch.randn(1, 16, 16, device="cuda", dtype=torch.bfloat16)
        m = nn.Sequential(
            nn.Linear(16, 32, bias=False, device="cuda", dtype=torch.bfloat16),
            nn.Sequential(
                nn.ReLU(),
                nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.bfloat16),
            ),
        )
        config = Float8LinearConfig.from_recipe_name(recipe_name)

        @torch.no_grad()
        def mean_absolute_percentage_error(x_ref, x):
            tmp = torch.abs(x_ref - x) / torch.clamp(torch.abs(x_ref), min=1e-9)
            # trim to avoid values close to 0 from
            # significantly impacting the results
            tmp = torch.clamp(tmp, max=1e3)
            return torch.mean(tmp)

        iter_counter = 0
        iter_fqn_gemm_name_to_data = {}

        @torch._dynamo.disable
        def debug_logging_fn(fqn, gemm_name, a_hp, b_hp, a_fp8, b_fp8):
            """
            Example debugging function - this is user defined, easy to customize
            1. captures M, K, N
            2. captures MAPE for high precision vs float8 gemm
            4. leaves data on GPU, so the user can move it to CPU at their
               convenience
            """
            M, K = a_hp.shape
            K2, N = b_hp.shape
            assert K == K2
            res_hp = a_hp @ b_hp
            res_fp8 = a_fp8 @ b_fp8
            mape = mean_absolute_percentage_error(res_hp, res_fp8)
            iter_fqn_gemm_name_to_data[(iter_counter, fqn, gemm_name)] = (M, K, N), mape

        object.__setattr__(config, "_debug_logging_fn", debug_logging_fn)
        m = convert_to_float8_training(m, config=config)
        _populate_debug_fqns(m)

        # iter 0
        m = torch.compile(m)
        y = m(x)
        y.sum().backward()

        # iter 1
        iter_counter += 1
        m = torch.compile(m)
        y = m(x)
        y.sum().backward()
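After a few iterations, the data accumulated in `iter_fqn_gemm_name_to_data` can be inspected at the user's convenience. A sketch (the dictionary contents below are placeholder values, not real measurements):

```python
import torch

# Hypothetical contents of iter_fqn_gemm_name_to_data after two iterations;
# keys are (iteration, fqn, gemm_name), values are ((M, K, N), mape),
# mirroring the structure populated by debug_logging_fn in the example above.
iter_fqn_gemm_name_to_data = {
    (0, "0", "output"): ((16, 16, 32), torch.tensor(0.01)),
    (1, "0", "output"): ((16, 16, 32), torch.tensor(0.02)),
}

for (it, fqn, gemm_name), ((M, K, N), mape) in iter_fqn_gemm_name_to_data.items():
    # values were left on GPU by the logging fn; .item() moves them to CPU here
    print(f"iter={it} fqn={fqn} gemm={gemm_name} M={M} K={K} N={N} mape={mape.item():.4f}")
```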

Test Plan:

> pytest test/float8/test_base.py -s -x -k test_debug_logging
...
test/float8/test_base.py fqn: 0, gemm_name: output, sqnr: 29.125                    
fqn: 1.1, gemm_name: output, sqnr: 28.5                                                                          
fqn: 1.1, gemm_name: grad_input, sqnr: 33.5                                                                      
fqn: 1.1, gemm_name: grad_weight, sqnr: 38.5                                                                     
fqn: 0, gemm_name: grad_input, sqnr: 22.5                                                                        
fqn: 0, gemm_name: grad_weight, sqnr: 23.875              
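The `sqnr` values above come from a signal-to-quantization-noise-ratio metric. A common definition, in dB (a sketch; not necessarily the exact helper used by the test):

```python
import torch

# SQNR: ratio of signal power to quantization-error power, expressed in dB.
# Higher is better; identical tensors give +inf.
def compute_sqnr(x_ref, x):
    signal_power = torch.mean(x_ref.pow(2))
    noise_power = torch.mean((x_ref - x).pow(2))
    return 10 * torch.log10(signal_power / noise_power)
```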

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
@vkuzo (Contributor, Author) commented Aug 6, 2025

Stack from ghstack (oldest at bottom):

pytorch-bot bot commented Aug 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2701

Note: Links to docs will display an error until the docs builds have been completed.

❌ 8 New Failures

As of commit 6a7c700 with merge base 2db4c76:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vkuzo added a commit that referenced this pull request Aug 6, 2025
Summary:

A lightweight logging flag to log the SQNR between the float8 gemm
output and the bf16 gemm output.

Test Plan:

```bash

```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: f419ea2
ghstack-comment-id: 3160382784
Pull Request resolved: #2701
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 6, 2025
@vkuzo vkuzo added the topic: not user facing Use this tag if you don't want this PR to show up in release notes label Aug 6, 2025
facebook-github-bot (Contributor) commented:

@vkuzo has imported this pull request. If you are a Meta employee, you can view this in D79724877.

vkuzo added a commit that referenced this pull request Aug 6, 2025 (same message as above; ghstack-source-id: dce0fc9, ghstack-comment-id: 3160382784, Pull Request resolved: #2701)
vkuzo added a commit that referenced this pull request Aug 15, 2025 (same message as above; ghstack-source-id: da0f0f4, ghstack-comment-id: 3160382784, Pull Request resolved: #2701)
@vkuzo vkuzo changed the title [not for land] debug accuracy logging for float8 training [not for land] debug logging for float8 training Aug 15, 2025
