
feat: add float8_e4m3fnuz and float8_e5m2fnuz dtype support for AMD GPUs#711

Open
nathanrchn wants to merge 1 commit into huggingface:main from nathanrchn:feat/float8-fnuz-dtypes

Conversation


@nathanrchn nathanrchn commented Feb 28, 2026

Add support for the FNUZ (finite, unsigned zero) float8 variants used by PyTorch. These dtypes differ from the existing float8_e4m3fn and float8_e5m2 types by having no negative zero representation.

What does this PR do?

Summary

  • Add F8_E4M3FNUZ and F8_E5M2FNUZ variants to the Rust Dtype enum
  • Add dtype mappings in Python bindings (lib.rs, view.rs, torch.py)
  • Add round-trip test for both fnuz float8 dtypes
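The enum change described above can be sketched roughly as follows. This is an illustrative fragment, not the actual safetensors source: the real `Dtype` enum has many more variants and serde attributes, and the `size()` helper here is a simplified stand-in for the crate's byte-size accounting.

```rust
// Sketch of the Dtype extension this PR describes (hypothetical, trimmed).
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dtype {
    F8_E4M3,     // existing float8_e4m3fn
    F8_E5M2,     // existing float8_e5m2
    F8_E4M3FNUZ, // new: AMD-style e4m3, no -0.0 and no infinities
    F8_E5M2FNUZ, // new: AMD-style e5m2, no -0.0 and no infinities
}

impl Dtype {
    /// All float8 variants occupy a single byte.
    pub fn size(&self) -> usize {
        1
    }
}

fn main() {
    assert_eq!(Dtype::F8_E4M3FNUZ.size(), 1);
    println!("{:?}", Dtype::F8_E5M2FNUZ);
}
```

The Python-binding side of the change would then map these variants to the `torch.float8_e4m3fnuz` / `torch.float8_e5m2fnuz` dtype objects in the same way the existing float8 types are mapped.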

Motivation

AMD's CDNA3 architecture (MI300 series) natively uses the FNUZ (finite, unsigned zero) variants of FP8 rather than the e4m3fn/e5m2 formats used by NVIDIA H100 GPUs. These dtypes (torch.float8_e4m3fnuz and torch.float8_e5m2fnuz) differ by having no negative zero: that bit pattern encodes NaN instead.
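The bit-level difference can be illustrated with a small pure-Python decoder for the e4m3fnuz layout (1 sign bit, 4 exponent bits, 3 mantissa bits, exponent bias 8). This sketch is for illustration only and is not part of the PR; in particular it shows that 0x80, which would be -0.0 in float8_e4m3fn, decodes to NaN in the fnuz variant.

```python
import math

def decode_e4m3fnuz(byte: int) -> float:
    """Decode one float8_e4m3fnuz byte: 1 sign, 4 exponent, 3 mantissa bits, bias 8."""
    # The pattern 0x80 (sign=1, all other bits 0) is the only NaN;
    # there is no -0.0 and no infinity in this format.
    if byte == 0x80:
        return float("nan")
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0:
        # Subnormal: effective exponent is 1 - bias = -7.
        return sign * (man / 8.0) * 2.0 ** -7
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 8)

print(decode_e4m3fnuz(0x00))               # 0.0
print(decode_e4m3fnuz(0x7F))               # 240.0, the maximum finite value
print(math.isnan(decode_e4m3fnuz(0x80)))   # True (this pattern is -0.0 in e4m3fn)
```

Because the all-ones exponent encodes ordinary finite values rather than infinity, the dynamic range tops out at 240.0 instead of e4m3fn's 448.0, which is part of why the formats need distinct dtype tags.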

Test plan

  • cargo test in safetensors/ crate passes
  • pytest bindings/python/tests/test_pt_comparison.py -k float8 passes (torch >= 2.1 with fnuz support)

Add support for the FNUZ (finite, unsigned zero) float8 variants
used by PyTorch. These dtypes differ from the existing float8_e4m3fn
and float8_e5m2 types by having no negative zero representation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>