7 changes: 7 additions & 0 deletions tests/pytorch/test_sanity.py
@@ -194,6 +194,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
continue
elif "weight" in name and p.requires_grad:
p.main_grad = torch.zeros_like(p)
p.grad_added_to_main_grad = False # Should be set to True after backward

use_fp8 = fp8_recipe is not None
with autocast(enabled=use_fp8, recipe=fp8_recipe):
@@ -203,13 +204,19 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
torch.cuda.synchronize()

failed_grads = []
failed_grad_added_flags = []
for name, p in block.named_parameters():
if "layer_norm_weight" in name:
continue
elif "weight" in name and p.requires_grad:
if not torch.count_nonzero(p.main_grad) > 0:
failed_grads.append(name)
if not getattr(p, "grad_added_to_main_grad", False):
failed_grad_added_flags.append(name)
assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
assert (
len(failed_grad_added_flags) == 0
), f"grad_added_to_main_grad not set to True for {failed_grad_added_flags}."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
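For readers unfamiliar with the Megatron-Core fusion contract the new assertion targets, here is a minimal, hypothetical sketch (a plain `torch.nn.Linear` stands in for the TE block) of the parameter state the test prepares and then checks:

```python
import torch

# Hypothetical stand-in for a TE block running with fuse_wgrad_accumulation enabled.
block = torch.nn.Linear(16, 16)

# MCore-style preparation mirrored by the test: each weight gets an fp32
# accumulation buffer and a flag that the fused backward is expected to set.
for name, p in block.named_parameters():
    if "weight" in name and p.requires_grad:
        p.main_grad = torch.zeros_like(p, dtype=torch.float32)
        p.grad_added_to_main_grad = False

# After forward/backward through the real fused module, the test asserts that
# p.main_grad is non-zero and p.grad_added_to_main_grad is True for every weight.
```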
71 changes: 34 additions & 37 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -6,6 +6,7 @@
from typing import Union, Optional, Callable, Tuple, List
from itertools import chain
import warnings
import weakref

import functools
import torch
@@ -239,23 +240,9 @@ def forward(
else:
inputmats = [None] * num_gemms

if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")

if ctx.grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes user
# sets for the weights. Because of this, it is not recommended to offload
# weights if weights are externally touched outside this module
ctx.weight_objects = []
for weight in weights:
ctx.weight_objects.append(weight)

tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats,
*weights_fp8,
*weights,
*biases,
)
ctx.save_for_backward(*tensors_to_save)
@@ -267,6 +254,13 @@

ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad:
# Keep weakrefs to the weights so backward can reach the original weight
# python objects and their attributes (e.g. main_grad) when it needs to modify them
ctx.origin_weight_refs = [weakref.ref(w) for w in weights]
# Save overwrite_main_grad flag now while we have access to weight objects
ctx.origin_weights_overwrite_main_grad = getattr(
weights[0], "overwrite_main_grad", False
)
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
# the main_grad buffer lazily before backprop
@@ -277,8 +271,6 @@
ctx.main_grad_funcs = [
lambda j=i: weights[j].main_grad for i in range(num_gemms)
]
else:
ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)]
ctx.device = device
ctx.output_quantizers = output_quantizers
ctx.m_splits = m_splits
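A side note on the `ctx.main_grad_funcs` lambdas kept in the hunk above: the `j=i` default argument is what binds the loop index per GEMM, since Python closures otherwise capture `i` by reference. A quick illustration of the difference:

```python
# Late binding: every lambda sees the final value of the loop variable.
funcs_late = [lambda: i for i in range(3)]
assert [f() for f in funcs_late] == [2, 2, 2]

# Default-argument trick (as in `lambda j=i: weights[j].main_grad`):
# the current index is frozen at definition time.
funcs_bound = [lambda j=i: j for i in range(3)]
assert [f() for f in funcs_bound] == [0, 1, 2]
```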
@@ -315,19 +307,24 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
N = ctx.num_gemms
inputmats = saved_tensors[:N]
weights = saved_tensors[N : 2 * N]
origin_weights = saved_tensors[2 * N : 3 * N]
biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]

if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
for i, weight in enumerate(ctx.weight_objects):
origin_weights[i] = ctx.weight_objects[i]
ctx.weight_objects[i] = None

if ctx.fuse_wgrad_accumulation:
for i in range(N):
origin_weights[i].main_grad = main_grads[i]
biases = saved_tensors[2 * N : 3 * N]

# Restore from weakrefs to get original weight python objects
# (preserves attributes like main_grad, grad_added_to_main_grad, etc.)
# Only needed when fuse_wgrad_accumulation is enabled.
origin_weights = [None] * N
main_grads = [None] * N
if ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad:
origin_weight_refs = ctx.origin_weight_refs
ctx.origin_weight_refs = None
origin_weights = [ref() if ref is not None else None for ref in origin_weight_refs]
assert all(
w is not None for w in origin_weights
), "weight was removed while fuse_wgrad_accumulation=True"
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
for origin_weight, main_grad in zip(origin_weights, main_grads):
if main_grad is not None:
origin_weight.main_grad = main_grad

# Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
@@ -464,7 +461,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
use_split_accumulator=wgrad_gemm_use_split_accumulator,
accumulate=(
accumulate_wgrad_into_param_main_grad
if not getattr(weights[0], "overwrite_main_grad", False)
if not getattr(ctx, "origin_weights_overwrite_main_grad", False)
else False
),
)
@@ -482,7 +479,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
# Deallocate input tensor
clear_tensor_data(*inputmats)

def handle_custom_ddp_from_mcore(weight, wgrad):
def handle_custom_ddp_from_mcore(weight, main_grad, wgrad):
if ctx.weights_requires_grad:
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(
@@ -491,14 +488,14 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
weight.grad_added_to_main_grad = True
if getattr(weight, "zero_out_wgrad", False):
wgrad = get_dummy_wgrad(
list(weight.main_grad.shape),
weight.dtype,
list(main_grad.shape),
main_grad.dtype,
zero=True,
)
else:
wgrad = get_dummy_wgrad(
list(weight.main_grad.shape),
weight.dtype,
list(main_grad.shape),
main_grad.dtype,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
@@ -507,8 +504,8 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
return wgrad

wgrad_list = [
handle_custom_ddp_from_mcore(weight, wgrad)
for weight, wgrad in zip(origin_weights, wgrad_list)
handle_custom_ddp_from_mcore(weight, main_grad, wgrad)
for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list)
]
else:
wgrad_list = [None] * ctx.num_gemms
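The weakref bookkeeping added above deserves a standalone illustration. A minimal sketch (all names illustrative) of why storing `weakref.ref(weight)` on `ctx` gives backward the original Parameter object, with its user-set attributes, without artificially extending its lifetime:

```python
import weakref
import torch

w = torch.nn.Parameter(torch.randn(4, 4))
w.main_grad = torch.zeros_like(w)   # attribute lives on the python object

ref = weakref.ref(w)                # what forward stashes on ctx
assert ref() is w                   # backward gets the same object back
assert hasattr(ref(), "main_grad")  # ...including user-set attributes

del w                               # if the parameter is dropped early,
assert ref() is None                # the ref dangles, which the new
                                    # assert in backward would catch
```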
59 changes: 31 additions & 28 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -5,6 +5,7 @@
"""LayerNormLinear API"""
import os
import warnings
import weakref
from typing import Callable, Dict, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
@@ -454,19 +455,10 @@ def forward(
ln_weight,
ln_bias,
)
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes user
# sets for the weights. Because of this, it is not recommended to offload
# weights if weights are externally touched outside this module
ctx.weight_object = weight

tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
weightmat,
weight,
bias,
ln_weight,
ln_out,
@@ -479,6 +471,13 @@
ctx.requires_wgrad = weight.requires_grad
ctx.is_weight_param_quantized = is_weight_param_quantized
if fuse_wgrad_accumulation and weight.requires_grad:
# Keep a weakref to the weight so backward can reach the original weight
# python object and its attributes (e.g. main_grad) when it needs to modify it
ctx.origin_weight_ref = weakref.ref(weight)
# Save overwrite_main_grad flag now while we have access to weight object
ctx.origin_weight_overwrites_main_grad = getattr(
weight, "overwrite_main_grad", False
)
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
# the main_grad buffer lazily before backprop
@@ -554,7 +553,6 @@ def backward(
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
weight,
origin_weight,
bias,
ln_weight,
ln_out,
@@ -566,12 +564,25 @@
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
ctx.main_grad_func()
if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
else None
# Restore from weakref to get original weight python object
# (preserves attributes like main_grad, grad_added_to_main_grad, etc.)
# Only needed when fuse_wgrad_accumulation is enabled.
origin_weight = None
origin_weight_overwrites_main_grad = getattr(
ctx, "origin_weight_overwrites_main_grad", False
)
main_grad = None
if ctx.fuse_wgrad_accumulation and ctx.requires_wgrad:
origin_weight_ref = ctx.origin_weight_ref
ctx.origin_weight_ref = None
origin_weight = origin_weight_ref() if origin_weight_ref is not None else None
assert (
origin_weight is not None
), "weight was removed while fuse_wgrad_accumulation=True"
# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = ctx.main_grad_func() if weight is not None else None
if main_grad is not None:
origin_weight.main_grad = main_grad

# Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
Expand All @@ -587,14 +598,6 @@ def backward(
)
nvtx_range_pop(f"{nvtx_label}.fsdp_gather")

# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# we need to connect them into one.
if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
origin_weight = ctx.weight_object
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad

# Configure Userbuffers communication (comm+GEMM overlap)
ctx.ub_obj_gradout = None
ub_obj_dgrad = None
@@ -868,7 +871,7 @@ def backward(
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(weight, "overwrite_main_grad", False)
if not origin_weight_overwrites_main_grad
else False
),
"layout": "NT",
@@ -1000,14 +1003,14 @@ def wgrad_gemm(
origin_weight.grad_added_to_main_grad = True
if getattr(origin_weight, "zero_out_wgrad", False):
wgrad = get_dummy_wgrad(
list(origin_weight.main_grad.shape),
origin_weight.dtype,
list(main_grad.shape),
main_grad.dtype,
zero=True,
)
else:
wgrad = get_dummy_wgrad(
list(origin_weight.main_grad.shape),
origin_weight.dtype,
list(main_grad.shape),
main_grad.dtype,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
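Putting the pieces together, here is a compressed, hypothetical end-to-end sketch of the pattern both modules now follow (this is not TE's actual API; names and shapes are illustrative): forward stashes a weak reference plus the `overwrite_main_grad` flag, and backward re-attaches to the original Parameter, accumulates (or overwrites) the weight gradient in `main_grad`, sets `grad_added_to_main_grad`, and returns a cheap dummy gradient to autograd:

```python
import weakref
import torch

class FusedWgradLinear(torch.autograd.Function):
    """Toy stand-in for the fuse_wgrad_accumulation path (illustrative only)."""

    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp)
        ctx.weight_ref = weakref.ref(weight)  # weak: do not extend the param's lifetime
        ctx.overwrite_main_grad = getattr(weight, "overwrite_main_grad", False)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        (inp,) = ctx.saved_tensors
        weight = ctx.weight_ref()
        assert weight is not None, "weight was removed before backward"
        wgrad = grad_out.t() @ inp
        if ctx.overwrite_main_grad:
            weight.main_grad.copy_(wgrad)     # overwrite instead of accumulate
        else:
            weight.main_grad.add_(wgrad)      # fused wgrad accumulation
        weight.grad_added_to_main_grad = True
        dgrad = grad_out @ weight
        # Autograd still expects a tensor of the weight's shape; hand back a dummy.
        return dgrad, torch.zeros_like(weight.main_grad)

# Usage sketch
w = torch.nn.Parameter(torch.randn(8, 8))
w.main_grad = torch.zeros(8, 8)
x = torch.randn(4, 8, requires_grad=True)
FusedWgradLinear.apply(x, w).sum().backward()
assert w.grad_added_to_main_grad and w.main_grad.abs().sum() > 0
```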