
Commit dd06ea3

allocate embed norm only on pp0 (#261)
* allocate embed norm only on pp0
* text
1 parent 1cb76a6 commit dd06ea3

File tree: 1 file changed, +4 −2 lines

megatron/mpu/layers.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -36,7 +36,7 @@
 from .utils import divide
 from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
-from megatron import get_args
+from megatron import get_args, mpu
 import deepspeed.runtime.activation_checkpointing.checkpointing as ds_checkpointing
 
 
@@ -188,7 +188,9 @@ def __init__(self, num_embeddings, embedding_dim,
         # Allocate weights and initialize.
         args = get_args()
 
-        if args.use_bnb_optimizer or args.embed_layernorm:
+        # only the first stage embedding runs this class' forward. The head's embedding does its own
+        # thing, so don't waste memory allocating LN weights.
+        if mpu.is_pipeline_first_stage() and (args.use_bnb_optimizer or args.embed_layernorm):
             self.norm = torch.nn.LayerNorm(embedding_dim)
 
         if args.use_bnb_optimizer:
```
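
Below is a minimal, self-contained sketch (not the Megatron-DeepSpeed source) of what this change accomplishes: the optional embedding LayerNorm is allocated only when the module lives on the first pipeline stage, since only that stage's embedding runs this forward path. The `ToyVocabEmbedding` class, the `Args` holder, and the `is_pipeline_first_stage` helper taking an explicit stage rank are illustrative assumptions standing in for Megatron's `VocabParallelEmbedding`, `get_args()`, and `mpu.is_pipeline_first_stage()`.

```python
import torch


class Args:
    # Toy stand-in for Megatron's parsed args; assumed flags from the diff.
    use_bnb_optimizer = False
    embed_layernorm = True


def is_pipeline_first_stage(stage_rank: int) -> bool:
    # Stand-in for mpu.is_pipeline_first_stage(); here the stage rank is passed explicitly.
    return stage_rank == 0


class ToyVocabEmbedding(torch.nn.Module):
    """Illustrative embedding that only allocates LN weights on pipeline stage 0."""

    def __init__(self, num_embeddings, embedding_dim, args, stage_rank):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(num_embeddings, embedding_dim))
        torch.nn.init.normal_(self.weight)

        # Only the first stage's embedding runs this forward; later stages
        # (e.g. the tied output head) skip the LayerNorm allocation and save memory.
        if is_pipeline_first_stage(stage_rank) and (args.use_bnb_optimizer or args.embed_layernorm):
            self.norm = torch.nn.LayerNorm(embedding_dim)

    def forward(self, input_ids):
        out = torch.nn.functional.embedding(input_ids, self.weight)
        if hasattr(self, "norm"):
            out = self.norm(out)
        return out


if __name__ == "__main__":
    args = Args()
    first = ToyVocabEmbedding(100, 16, args, stage_rank=0)
    last = ToyVocabEmbedding(100, 16, args, stage_rank=3)
    print(hasattr(first, "norm"))  # True  -> LN weights allocated on pp0
    print(hasattr(last, "norm"))   # False -> no LN weights on later stages
```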

0 commit comments