
fix: disable tensor cores for FP32 backward GEMMs — NaN corruption (ALB-076) #170

@noahgift

Description


Problem

CUBLAS_GEMM_DEFAULT_TENSOR_OP (algorithm 99) produces all-NaN output for transposed backward GEMMs (Trans/NoTrans, NoTrans/Trans) when input gradient magnitudes reach ~1e5. This occurs around block 18 of a 24-layer backward pass, where gradient magnitudes grow from ~1e-5 to ~1e5.

Forward GEMMs (NoTrans/NoTrans) are unaffected.

Five Whys

  1. Why NaN weights? → optimizer reads NaN gradients
  2. Why NaN gradients? → cuBLAS backward_a/b output ALL NaN
  3. Why NaN output from valid inputs? → tensor core GEMM algorithm
  4. Why only backward? → backward uses Trans flag, forward doesn't
  5. Why only after ~5 blocks? → gradient magnitudes reach ~1e5

Fix

Switch from tensor core math to SIMD math:

  • CUBLAS_TF32_TENSOR_OP_MATH → CUBLAS_DEFAULT_MATH
  • CUBLAS_COMPUTE_32F_FAST_TF32 → CUBLAS_COMPUTE_32F
  • CUBLAS_GEMM_DEFAULT_TENSOR_OP → CUBLAS_GEMM_DEFAULT
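A call-site sketch of the three substitutions, assuming a `cublasGemmEx`-based path (the PR's actual call sites are not shown here; handle, buffers, and dimensions are assumed already set up, so this fragment needs the CUDA toolkit and is not a standalone program):

```c
/* Sketch only -- assumes a cublasGemmEx-based call path. */
cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);  /* was CUBLAS_TF32_TENSOR_OP_MATH */

cublasGemmEx(handle,
             CUBLAS_OP_T, CUBLAS_OP_N,           /* backward GEMM: Trans/NoTrans */
             m, n, k,
             &alpha,
             A, CUDA_R_32F, lda,
             B, CUDA_R_32F, ldb,
             &beta,
             C, CUDA_R_32F, ldc,
             CUBLAS_COMPUTE_32F,                 /* was CUBLAS_COMPUTE_32F_FAST_TF32 */
             CUBLAS_GEMM_DEFAULT);               /* was CUBLAS_GEMM_DEFAULT_TENSOR_OP */
```

Forward NoTrans/NoTrans GEMMs could in principle keep the tensor-op path, since only the transposed variants are affected, but disabling it uniformly is the conservative fix.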

Performance

cuBLAS SIMD is still 6-14x faster than hand-written PTX:

  • PTX baseline: 890 tok/s, 2.6% MFU
  • cuBLAS SIMD: 5,216 tok/s, 15.1% MFU
