
@mattteochen (Collaborator) commented Oct 1, 2025

What does this PR do?

Fixes #2569 #2570 #2571 #2572 #2573 #2574 #2575

To stabilise numerical accuracy in internal CI, which runs on Ampere+ GPUs, TF32 is disabled.
Computation chains can otherwise produce absolute errors on the order of 1e-3, exceeding PyTorch's default assertion tolerances.
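
For reference, TF32 can also be turned off globally with the long-standing backend flags; this is a minimal sketch and not necessarily the exact mechanism used for the CI change (the script below uses the newer fp32_precision API instead):

import torch

# Force full-precision IEEE float32 math instead of TF32 for matmuls and
# cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False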

A quick comparison of the output logits on NanoGPT:

from collections.abc import Sequence
import torch

# Use the NanoGPT reference used in tests
import thunder.tests.nanogpt_model as nanogpt_model

def set_tf32(enabled: bool):
    if not torch.cuda.is_available():
        return
    # 'ieee' forces full-precision float32 matmuls (TF32 off); 'tf32' re-enables TF32.
    torch.backends.cuda.matmul.fp32_precision = 'ieee' if not enabled else 'tf32'

def seed_all(seed: int = 0):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

@torch.no_grad()
def capture_leaf_outputs(model: torch.nn.Module, x: torch.Tensor) -> dict[str, torch.Tensor]:
    outputs: dict[str, torch.Tensor] = {}
    hooks = []

    def is_leaf(m: torch.nn.Module) -> bool:
        return len(list(m.children())) == 0

    def make_hook(name: str):
        def hook(_, __, out):
            if torch.is_tensor(out):
                outputs[name] = out.detach().float().cpu()
        return hook

    for name, mod in model.named_modules():
        if name == "" or not is_leaf(mod):
            continue
        hooks.append(mod.register_forward_hook(make_hook(name)))

    _ = model(x)

    for h in hooks:
        h.remove()

    return outputs

def max_abs_rel(a, b) -> tuple[float, float]:
    if isinstance(a, Sequence) and isinstance(b, Sequence):
        if len(a) != len(b):
            raise ValueError("Tuples/lists must have the same length")
        max_abs = 0.0
        max_rel = 0.0
        for ai, bi in zip(a, b):
            abs_i, rel_i = max_abs_rel(ai, bi)
            max_abs = max(max_abs, abs_i)
            max_rel = max(max_rel, rel_i)
        return max_abs, max_rel
    elif torch.is_tensor(a) and torch.is_tensor(b):
        d = (a - b).abs()
        max_abs = d.max().item() if d.numel() > 0 else 0.0
        denom = b.abs().max().item() if b.numel() > 0 else 0.0
        max_rel = (d.max().item() / max(denom, 1e-12)) if d.numel() > 0 else 0.0
        return max_abs, max_rel
    elif a is None and b is None:
        return 0.0, 0.0
    else:
        raise TypeError(f"Inputs must be tensors or tuples/lists of tensors, got {a.__class__} and {b.__class__}")

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # Build a small-ish, deterministic NanoGPT
    # Dropout=0 to avoid RNG differences
    config = nanogpt_model.GPTConfig(dropout=0, block_size=512, n_layer=6, n_head=6, n_embd=768)
    model = nanogpt_model.GPT(config).to(device=device, dtype=dtype)

    # Same input for both runs
    seed_all(0)
    x = torch.randint(low=0, high=255, size=(4, 64), device=device, dtype=torch.int64)

    # Pass 1: TF32 OFF (reference)
    set_tf32(False)
    outs_off = capture_leaf_outputs(model, x)
    logits_off = model(x)

    # Pass 2: TF32 ON (comparison)
    set_tf32(True)
    outs_on = capture_leaf_outputs(model, x)
    logits_on = model(x)

    # Compare layer-by-layer for common leaves
    rows = []
    common = [n for n in outs_off.keys() if n in outs_on]
    for name in common:
        a, b = outs_off[name], outs_on[name]
        if a.shape == b.shape:
            max_abs, max_rel = max_abs_rel(a, b)
            rows.append((name, a.shape, max_abs, max_rel))

    # Sort by max_abs descending
    rows.sort(key=lambda t: t[2], reverse=True)

    print(f"Device: {device}, TF32 available: {torch.cuda.is_available()}")
    print("Top 15 layer diffs (name, shape, max_abs, max_rel):")
    for name, shape, ma, mr in rows[:15]:
        print(f"  {name:40s} {tuple(shape)!s:20s} max_abs={ma:.6e}  max_rel={mr:.6e}")

    # Find first layer exceeding a significance threshold
    # Thresholds match torch.testing.assert_close's default float32 tolerances.
    sig_atol, sig_rtol = 1e-5, 1.3e-6
    first_significant = next(
        ((n, s, ma, mr) for (n, s, ma, mr) in rows if ma > sig_atol or mr > sig_rtol),
        None,
    )
    if first_significant:
        n, s, ma, mr = first_significant
        print(f"\nFirst significant divergence (> atol={sig_atol} or rtol={sig_rtol}):")
        print(f"  {n} {tuple(s)}  max_abs={ma:.6e}  max_rel={mr:.6e}")
    else:
        print("\nNo significant layer divergence under the chosen thresholds.")

    # Logits comparison
    la, lr = max_abs_rel(logits_off, logits_on)
    print(f"\nLogits diff: max_abs={la:.6e}  max_rel={lr:.6e}")

if __name__ == "__main__":
    main()

Output:

Device: cuda, TF32 available: True
Top 15 layer diffs (name, shape, max_abs, max_rel):
  transformer.h.5.ln_1                     (4, 64, 768)         max_abs=3.436685e-03  max_rel=7.723264e-04
  transformer.ln_f                         (4, 64, 768)         max_abs=3.271997e-03  max_rel=6.862586e-04
  transformer.h.3.ln_2                     (4, 64, 768)         max_abs=3.168583e-03  max_rel=7.227786e-04
  transformer.h.3.ln_1                     (4, 64, 768)         max_abs=3.159404e-03  max_rel=7.111780e-04
  transformer.h.5.ln_2                     (4, 64, 768)         max_abs=3.093570e-03  max_rel=7.103587e-04
  transformer.h.4.ln_1                     (4, 64, 768)         max_abs=2.937198e-03  max_rel=6.224071e-04
  transformer.h.4.ln_2                     (4, 64, 768)         max_abs=2.864003e-03  max_rel=6.365660e-04
  transformer.h.2.ln_1                     (4, 64, 768)         max_abs=2.733231e-03  max_rel=6.180884e-04
  transformer.h.2.ln_2                     (4, 64, 768)         max_abs=2.680779e-03  max_rel=5.860344e-04
  transformer.h.1.ln_1                     (4, 64, 768)         max_abs=2.627254e-03  max_rel=6.116366e-04
  lm_head                                  (4, 64, 50304)       max_abs=2.329409e-03  max_rel=7.654070e-04
  transformer.h.1.ln_2                     (4, 64, 768)         max_abs=2.221942e-03  max_rel=4.589552e-04
  transformer.h.5.attn.c_attn              (4, 64, 2304)        max_abs=2.156943e-03  max_rel=8.313069e-04
  transformer.h.4.mlp.c_fc                 (4, 64, 3072)        max_abs=2.060950e-03  max_rel=8.104145e-04
  transformer.h.2.mlp.c_fc                 (4, 64, 3072)        max_abs=2.057329e-03  max_rel=7.192114e-04

First significant divergence (> atol=1e-05 or rtol=1.3e-06):
  transformer.h.5.ln_1 (4, 64, 768)  max_abs=3.436685e-03  max_rel=7.723264e-04

Logits diff: max_abs=2.329409e-03  max_rel=7.654070e-04

Container: pjnl-20251001
Device: A100x2 / RTX6000

@crcrpar (Collaborator) left a comment


an alternative is to set NVIDIA_TF32_OVERRIDE=0 in CI envs.
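
For context, NVIDIA_TF32_OVERRIDE is read by the CUDA libraries (cuBLAS/cuDNN) when they initialize, so it has to be in the environment before they load; a minimal sketch of a hypothetical test-setup snippet, not the approach this PR takes:

import os

# NVIDIA_TF32_OVERRIDE=0 disables TF32 at the CUDA library level, independent
# of framework flags. It must be set before the libraries initialize, e.g.
# before importing torch in the test process (or exported in the CI job env).
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import torch  # noqa: E402  (deliberately imported after setting the env var)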

@t-vi (Collaborator) left a comment


@t-vi enabled auto-merge (squash) October 2, 2025 10:24
@t-vi (Collaborator) commented Oct 7, 2025

Tests passed manually, merging.

@t-vi disabled auto-merge October 7, 2025 13:37
@t-vi merged commit e30133a into Lightning-AI:main Oct 7, 2025
58 of 60 checks passed

Successfully merging this pull request may close these issues.

test_jit_general.py::test_litgpt_variants[cuda-codellama2-like] fails