
Conversation

@zianglih commented Feb 3, 2026

Description

@HumansAnd

Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.
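For context, a minimal usage sketch. This is an illustration, not code from the PR; it assumes the env var is read when the module executes and that a non-delayed-scaling recipe (e.g. Float8CurrentScaling, if available in your TE version) is configured, since the flag is ignored under delayed scaling.

import os
os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"  # set before running the module

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

layer = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
    out = layer(x)        # fprop runs with FP8 GEMMs
out.sum().backward()      # dgrad/wgrad use unquantized weights and activations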

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add NVTE_KEEP_BACKWARD_UNQUANTIZED environment variable: when it is set, fprop stays quantized while dgrad & wgrad run in high precision

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot commented Feb 3, 2026

Greptile Overview

Greptile Summary

This PR adds a new environment variable NVTE_KEEP_BACKWARD_UNQUANTIZED that enables a quantized forward pass with a high-precision backward pass (dgrad & wgrad). When enabled, the forward pass uses FP8 quantization for inference-like performance, while the backward pass operates in high precision for better gradient quality.

Key changes:

  • Added FP8GlobalStateManager.keep_backward_unquantized() method that reads the env var and returns False when delayed scaling recipe is used (line 437 checks recipe.delayed())
  • Modified all linear modules to store high-precision copies of activations (ln_out_hp, act_out_hp) when the feature is enabled
  • Disabled columnwise quantization usage throughout the codebase when backward should remain unquantized
  • Disabled Userbuffers communication overlapping in backward pass when using high-precision backward
  • Used unquantized weights (weight instead of weight_fp8) for dgrad/wgrad GEMMs

Trade-offs:

  • Increased memory usage from storing both quantized and high-precision activation copies
  • Additional compute overhead from recomputing activations in high precision (e.g., layernorm_mlp.py:620)
  • Feature is disabled when delayed scaling recipe is used (documented incompatibility)

Confidence Score: 4/5

  • safe to merge with minor documentation improvements recommended
  • implementation is sound with proper safeguards against delayed scaling recipe conflicts, but increased memory usage and compute overhead should be monitored in production
  • transformer_engine/pytorch/module/layernorm_mlp.py requires attention due to activation recomputation overhead

Important Files Changed

  • transformer_engine/pytorch/quantization.py: adds the keep_backward_unquantized() method that reads the NVTE_KEEP_BACKWARD_UNQUANTIZED env var and returns False when delayed scaling is used (a hedged sketch of this gate follows this list)
  • transformer_engine/pytorch/module/linear.py: implements the high-precision backward pass by disabling FP8 quantization when keep_backward_unquantized is enabled; uses unquantized weights for the dgrad/wgrad GEMMs
  • transformer_engine/pytorch/module/layernorm_linear.py: stores both the quantized (ln_out) and high-precision (ln_out_hp) layernorm outputs when keep_backward_unquantized is enabled; uses the high-precision version in backward
  • transformer_engine/pytorch/module/layernorm_mlp.py: stores high-precision copies of ln_out and act_out when keep_backward_unquantized is enabled; recomputes the activation on line 620 to obtain the high-precision version
  • transformer_engine/pytorch/ops/basic/basic_linear.py: disables columnwise quantization usage when keep_backward_unquantized is True; saves unquantized tensors for the backward pass instead of quantized ones
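As referenced in the first item above, a hedged sketch of what the env-var gate looks like, consistent with the quantization.py fragment quoted later in this thread (the real method is a classmethod on FP8GlobalStateManager that consults cls.get_fp8_recipe(); exact details may differ):

import os

def keep_backward_unquantized(recipe) -> bool:
    # Sketch only: True when the env var is set and the active FP8 recipe is not
    # delayed scaling (the flag is ignored under delayed scaling).
    if os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") != "1":
        return False
    if recipe is not None and recipe.delayed():
        return False
    return True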

Sequence Diagram

sequenceDiagram
    participant User
    participant Env as Environment Variable
    participant FP8Manager as FP8GlobalStateManager
    participant Module as Linear/LayerNormLinear/MLP
    participant Forward as Forward Pass
    participant Backward as Backward Pass

    User->>Env: Set NVTE_KEEP_BACKWARD_UNQUANTIZED=1
    User->>Module: Forward pass with FP8 enabled
    Module->>FP8Manager: keep_backward_unquantized()
    FP8Manager->>FP8Manager: Check recipe.delayed()
    alt Delayed Scaling Recipe
        FP8Manager-->>Module: Return False (ignore env var)
    else Other Recipes
        FP8Manager-->>Module: Return True
    end
    
    alt keep_backward_unquantized=True
        Module->>Forward: Quantize input to FP8
        Forward->>Forward: Compute with FP8
        Forward->>Forward: Save high-precision copy (ln_out_hp, act_out_hp)
        Forward->>Forward: Disable columnwise quantization
        Forward-->>Module: FP8 output
        
        Module->>Backward: Start backward pass
        Backward->>Backward: Use high-precision saved tensors
        Backward->>Backward: Compute dgrad/wgrad without FP8 quantization
        Backward->>Backward: Disable Userbuffers communication
        Backward-->>Module: High-precision gradients
    else keep_backward_unquantized=False
        Module->>Forward: Quantize input to FP8
        Forward->>Forward: Compute with FP8
        Forward->>Forward: Save FP8 quantized tensors
        Forward-->>Module: FP8 output
        
        Module->>Backward: Start backward pass
        Backward->>Backward: Use FP8 quantized tensors
        Backward->>Backward: Compute dgrad/wgrad with FP8
        Backward-->>Module: FP8 gradients
    end

greptile-apps bot: 6 files reviewed, no comments


@zianglih (Author) commented Feb 3, 2026

I'll work on potential unit test breakage.

FP8GlobalStateManager.is_fp8_enabled()
and FP8GlobalStateManager.keep_backward_unquantized()
)
if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing():
Collaborator:
the reduce_and_update_fp8_tensors stuff is for delayed scaling, in which case we can just ignore it

keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized
use_quantized_bwd = use_fp8_bwd or ctx.debug
if keep_backward_unquantized:
@zhongbozhu (Collaborator) Feb 3, 2026:
this shouldn't be related?

edit: disabling user-buffer when mixing fp8 & bf16 in one layer makes sense here

greptile-apps bot: 5 files reviewed, no comments

Signed-off-by: Ziang Li <[email protected]>
Signed-off-by: Ziang Li <[email protected]>
greptile-apps bot: 5 files reviewed, 4 comments

Comment on lines 241 to 243
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
greptile-apps bot:
storing both ln_out (quantized) and ln_out_hp (high precision) significantly increases memory usage

verify this memory overhead is acceptable for large models, especially during training

Comment on lines 423 to 425

# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
greptile-apps bot:
recomputing activation (activation_func(fc1_out, None, **act_params)) adds compute overhead

consider whether this is necessary or if the quantized act_out could be reused in some cases


greptile-apps bot: 4 files reviewed, 1 comment

greptile-apps bot: 5 files reviewed, 1 comment

greptile-apps bot: 5 files reviewed, 2 comments

Comment on lines +620 to +621
if keep_backward_unquantized and is_grad_enabled and fc1_out is not None:
    act_out_hp = activation_func(fc1_out, None, **act_params)
greptile-apps bot:
recomputing activation_func(fc1_out, None, **act_params) adds compute overhead that could be significant for expensive activations like GELU

if fc1_out is already available (which it is on line 620), storing act_out in high precision during the initial computation would avoid redundant activation calculations

Collaborator:
this part of the logic looks a bit off to me too, since the activation function has already quantized the output in fusion, so act_out is already fp8 and we cannot save a high-precision version of it for backward.

greptile-apps bot:
You're correct. Looking at the flow:

  1. When fc1_out_quantizer is set, activation_func returns quantized FP8 output in act_out
  2. At line 620, act_out is already FP8, so there's nothing high-precision to save
  3. The recomputation on line 621 is the only way to get the high-precision activation values

The logic is actually correct - you need the recomputation because the original forward pass discarded the high-precision values when it quantized. The variable name act_out_hp makes this clearer (hp = high precision).

However, the overhead concern from my previous comment stands. A potential optimization would be to conditionally store high-precision activations during the initial forward pass when keep_backward_unquantized=True, rather than always quantizing and then recomputing. This would trade memory for compute - but that might be acceptable since this feature is opt-in via environment variable.
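To make the memory-vs-compute trade-off concrete, a self-contained toy in plain PyTorch (not TE internals; quantize below is a hypothetical placeholder, and the names mirror the diff only for readability, with float32 standing in for the high-precision dtype):

import torch
import torch.nn.functional as F

fc1_out = torch.randn(8, 1024)               # float32 stands in for the high-precision dtype

# Option A (what the diff does): forward fused activation + quantization, so only an
# FP8 act_out survived; the high-precision activation is recomputed before backward.
act_out_hp = F.gelu(fc1_out)                 # extra compute, no extra forward-pass memory

# Option B (possible optimization): compute the activation in high precision once in
# forward, keep it for the high-precision wgrad, and quantize a copy for the FC2 GEMM.
act_out_hp_saved = F.gelu(fc1_out)           # computed once in forward
# act_out_fp8 = quantize(act_out_hp_saved)   # hypothetical quantize step for the FC2 input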

ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
ln_out_hp = ln_out if keep_backward_unquantized else None
greptile-apps bot:
storing both ln_out (quantized) and ln_out_hp (high precision) doubles the memory footprint for this activation

verify this memory overhead is acceptable for your target models, especially during training with large batch sizes or long sequences
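A rough back-of-envelope for the overhead being flagged (illustrative only: the hidden size, token count, and dtypes below are assumptions, and quantization scale/amax metadata is ignored):

tokens, hidden = 8 * 4096, 8192              # assumed batch*seq and hidden size
fp8_bytes = tokens * hidden * 1              # ln_out kept as FP8 (1 byte per element)
hp_bytes = tokens * hidden * 2               # ln_out_hp kept as BF16 (2 bytes per element)
print(f"per layer: {fp8_bytes / 2**20:.0f} MiB (FP8) + {hp_bytes / 2**20:.0f} MiB (BF16 copy)")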

greptile-apps bot: 4 files reviewed, no comments

greptile-apps bot: 5 files reviewed, no comments

Signed-off-by: Ziang Li <[email protected]>
greptile-apps bot: 5 files reviewed, no comments



ctx.owns_input = saved_inputmat is not inp
if ctx.fp8 and requires_grad(inp, weight, bias):
if ctx.fp8 and not ctx.keep_backward_unquantized and requires_grad(inp, weight, bias):
Collaborator:
Nit: we can remove this line since it's about delayed scaling?

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
if ctx.grad_output_quantizer is not None and use_fp8_bwd:
Collaborator:
this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?

not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd
Collaborator:
same comment as above

if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
if (
ctx.fp8
and not ctx.keep_backward_unquantized
Collaborator:
same comment with linear.py, this seems to be delayed scaling only, can revert/ignore

recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
Collaborator:
Maybe it's better to assert an error for delayed scaling? Okay with both.
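A sketch of the stricter variant being suggested, failing loudly instead of silently returning False (hypothetical code, not what the PR currently does):

import os

def keep_backward_unquantized(recipe) -> bool:
    # Sketch: raise instead of silently ignoring the env var under delayed scaling.
    enabled = os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1"
    if enabled and recipe is not None and recipe.delayed():
        raise RuntimeError(
            "NVTE_KEEP_BACKWARD_UNQUANTIZED is not supported with the delayed scaling recipe"
        )
    return enabled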

if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
if (
ctx.fp8
and not ctx.keep_backward_unquantized
Collaborator:
same comment

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
if ctx.grad_output_quantizer is not None and use_fp8_bwd:
Collaborator:
this seems redundant too if we skip quant in grad_output_preprocess

# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
if (
use_fp8_bwd
Collaborator:
since we already disabled ub above, this should also be redundant?


# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd)
Collaborator:
same

# or 2) doing the recomputation with checkpointing
backwards_needs_fc1_input = fc1_weight.requires_grad and (
(is_grad_enabled and not checkpoint) or is_recomputation
backwards_needs_fc1_input = (
Collaborator:
backwards_needs_fc1_input should be orthogonal to keep_backward_unquantized?

inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias
if (
ctx.fp8
and not ctx.keep_backward_unquantized
Collaborator:
same old delayed scaling case


keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized
fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None
@zhongbozhu (Collaborator) Feb 4, 2026:
Nit: this fp8_recipe_bwd shouldn't be needed; the use_fp8_bwd flag is enough

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.fc2_grad_output_quantizer is not None:
if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd:
Collaborator:
same nit about grad_output_preprocess already skip quant

# Whether to set grad arg in general_gemm
grad_arg = True
if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
if use_fp8_bwd and fp8_recipe_bwd.float8_block_scaling():
@zhongbozhu (Collaborator) Feb 4, 2026:
just using ctx.fp8_recipe should be better (no strong opinion about this one)

if ctx.ub_overlap_rs_dgrad:
# Overlap DGRAD+RS
ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd)
Collaborator:
since ub is already disabled, this part should also be redundant
