Commit b9883f4

switch to MixedFusedLayerNorm (#262)
1 parent dd06ea3 commit b9883f4

File tree

1 file changed (+2, -1)

megatron/mpu/layers.py

Lines changed: 2 additions & 1 deletion
@@ -36,6 +36,7 @@
 from .utils import divide
 from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
+from ..model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm
 from megatron import get_args, mpu
 import deepspeed.runtime.activation_checkpointing.checkpointing as ds_checkpointing

@@ -191,7 +192,7 @@ def __init__(self, num_embeddings, embedding_dim,
         # only the first stage embedding runs this class' forward. The head's embedding does its own
         # thing, so don't waste memory allocating LN weights.
         if mpu.is_pipeline_first_stage() and (args.use_bnb_optimizer or args.embed_layernorm):
-            self.norm = torch.nn.LayerNorm(embedding_dim)
+            self.norm = LayerNorm(embedding_dim)

         if args.use_bnb_optimizer:
             # for BNB we ignore the passed init_method and use torch.nn.init.xavier_uniform_
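
Below is a minimal usage sketch (not part of the commit) of the drop-in swap this change makes: MixedFusedLayerNorm takes the same normalized-shape argument as torch.nn.LayerNorm, so only the import and the constructed class change. Inside the repository the class comes from megatron/model/fused_layer_norm.py; the apex.normalization import path used here is an assumption made so the sketch can run standalone.

# Hedged sketch: assumes the fused class is importable from NVIDIA Apex when run
# outside the Megatron-DeepSpeed tree; otherwise falls back to stock PyTorch.
import torch

try:
    # Fused LayerNorm kernel (assumed apex import path).
    from apex.normalization import MixedFusedLayerNorm as LayerNorm
except ImportError:
    # Stock PyTorch implementation as a fallback.
    from torch.nn import LayerNorm

embedding_dim = 1024
norm = LayerNorm(embedding_dim)  # same constructor call as in the diff above

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, embedding_dim, device=device)
print(norm.to(device)(x).shape)  # torch.Size([4, 1024])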
