Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,26 @@ def __new__(cls, config: TransformerConfig):

class TENorm:
"""A conditional wrapper to initialize an instance of
Transformer-Engine's `LayerNorm` or `RMSNorm` based on input."""
Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.

# TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
Args:
config (TransformerConfig): Transformer config.

hidden_size (int): Transformer hidden dimension.

eps (float): Epsilon added to denominator, for numerical stability.

maybe_fuse_quantize (bool): Whether to fuse quantization into the norm. This only
    takes effect when FP8/FP4 is enabled; otherwise it is a no-op.
"""

def __new__(
cls,
config: TransformerConfig,
hidden_size: int,
eps: float = 1e-5,
maybe_fuse_quantize: bool = False,
):
if not HAVE_TE:
raise ImportError(
"Transformer Engine is not installed. "
Expand Down Expand Up @@ -237,6 +253,18 @@ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5)
else:
raise Exception("Only LayerNorm and RMSNorm are curently supported")

# Ideally, we should use `LayerNormLinear` if we want to fuse normalization and
# quantization. But there are exceptions, e.g., in MLA, where we use a separate
# input layernorm and `Linear` layers. So we provide this option to allow users to
# fuse quantization into LayerNorm/RMSNorm.
if maybe_fuse_quantize:
assert is_te_min_version("2.10.0"), (
"Only TE >=2.10.0 supports fusing quantization in LayerNorm/RMSNorm."
)
instance = te.pytorch.ops.Sequential(
instance, te.pytorch.ops.Quantize(forward=True, backward=False)
)

return instance


Expand Down
31 changes: 13 additions & 18 deletions megatron/core/fusions/fused_layer_norm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import importlib
import inspect
import numbers

Expand Down Expand Up @@ -31,32 +29,22 @@ class FusedLayerNorm(torch.nn.Module):
"""Layer Norm, fused into a single CUDA kernel.

Args:
config (TransformerConfig): Transformer config.

hidden_size (int): Transformer hidden dimension.

eps (float): Epsilon added to denominator, for numerical stability.

persist_layer_norm (bool): Use persistent fused layer norm kernel.
This kernel supports only a set of hidden sizes. Please
check persist_ln_hidden_sizes if your hidden size is supported.

zero_centered_gamma (bool): Adjust LayerNorm weights such that they are
centered around zero. This improves numerical stability.

config (TransformerConfig): Transformer config. Include to match custom
layer norm interfaces.

normalization (str): Normalization type, used for Transformer Engine.
Must equal 'LayerNorm' here.
maybe_fuse_quantize (bool): Whether to fuse quantization. Accepted for
    compatibility with the Transformer Engine norm interface.
"""

def __init__(
self,
config: TransformerConfig,
hidden_size: int,
eps: float = 1e-5,
persist_layer_norm: bool = True,
zero_centered_gamma: bool = False,
normalization: str = "LayerNorm", # included to match TE interface
maybe_fuse_quantize: bool = False,
):
super().__init__()

Expand Down Expand Up @@ -108,7 +96,8 @@ def __init__(
hidden_size = (hidden_size,)
self.hidden_size = torch.Size(hidden_size)
self.eps = eps
# Parameters need to be initialized with torch.empty rather than torch.Tensor for correct device placement with nemo2.
# Parameters need to be initialized with torch.empty rather than torch.Tensor
# for correct device placement with nemo2.
self.weight = Parameter(torch.empty(*hidden_size))
self.bias = Parameter(torch.empty(*hidden_size))
self.reset_parameters()
Expand All @@ -120,6 +109,9 @@ def __init__(
setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

def reset_parameters(self):
"""
Reset the parameters of the layer norm.
"""

if self.zero_centered_gamma:
init.zeros_(self.weight)
Expand All @@ -129,6 +121,9 @@ def reset_parameters(self):
init.zeros_(self.bias)

def forward(self, input: Tensor) -> Tensor:
"""
Forward pass of the layer norm.
"""

weight = self.weight + 1 if self.zero_centered_gamma else self.weight

Expand Down
16 changes: 11 additions & 5 deletions megatron/core/transformer/torch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,24 @@ class WrappedTorchNorm:
"""
A conditional wrapper to initialize an instance of PyTorch's
`LayerNorm` or `RMSNorm` based on input

Args:
config (TransformerConfig): Transformer config.

hidden_size (int): Transformer hidden dimension.

eps (float): Epsilon added to denominator, for numerical stability.

maybe_fuse_quantize (bool): Whether to fuse quantization. Accepted for
    compatibility with the Transformer Engine norm interface.
"""

def __new__(
cls,
config: TransformerConfig,
hidden_size: int,
eps: float = 1e-5,
# TODO: unused arguments.
# See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/issues/223
persist_layer_norm: bool = False,
zero_centered_gamma: bool = False,
normalization: str = "LayerNorm",
maybe_fuse_quantize: bool = False,
):
assert (
not config.layernorm_zero_centered_gamma
Expand Down
1 change: 1 addition & 0 deletions megatron/core/transformer/transformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def __init__(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
maybe_fuse_quantize=self.config.fp8 or self.config.fp4,
)

attention_optional_kwargs = {}
Expand Down