
Commit ef13d09

Introduce LayerNorm optimization from latest Apex (bigscience-workshop#277)
* Introduce LayerNorm optimization from NVIDIA/apex#1715
* Fix args call
* Ad-hoc apex version check
* Remove unnecessary TransformerConfig arg
1 parent a7b7cb7 commit ef13d09

3 files changed: +21 −6 lines

megatron/arguments.py

Lines changed: 3 additions & 0 deletions
@@ -626,6 +626,9 @@ def _add_network_size_args(parser):
     group.add_argument('--apply-layernorm-1p', action='store_true',
                        help='Adjust LayerNorm weights such that they are centered '
                        'around zero. This improves numerical stability.')
+    group.add_argument('--disable-mem-efficient-ln', action='store_false',
+                       help='Disable the memory-efficient fused LayerNorm optimization '
+                       'introduced in https://github.com/NVIDIA/apex/pull/1715')
     group.add_argument('--apply-residual-connection-post-layernorm',
                        action='store_true',
                        help='If set, use original BERT residula connection '
megatron/model/fused_layer_norm.py

Lines changed: 10 additions & 2 deletions
@@ -10,6 +10,7 @@
 from torch.nn import init
 import importlib
 from torch.nn import functional as F
+import inspect

 from megatron.core.utils import make_viewless_tensor

@@ -31,10 +32,12 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5,
                  no_persist_layer_norm=True,
                  sequence_parallel=False,
-                 apply_layernorm_1p=False):
+                 apply_layernorm_1p=False,
+                 mem_efficient_ln=True):
         super(MixedFusedLayerNorm, self).__init__()

         self.apply_layernorm_1p = apply_layernorm_1p
+        self.mem_efficient_ln = mem_efficient_ln

         global fused_layer_norm_cuda
         fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

@@ -83,7 +86,12 @@ def forward(self, input):
             return F.layer_norm(input, self.normalized_shape, weight, self.bias, self.eps)

         if self.no_persist_layer_norm:
-            return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
+            # Apex does not have versions yet (https://github.com/NVIDIA/apex/pull/1648), so we need to inspect
+            # the function manually on whether the extra arg introduced in https://github.com/NVIDIA/apex/pull/1715 exists yet
+            if 'memory_efficient' in inspect.getfullargspec(FusedLayerNormAffineFunction.forward).args:
+                return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps, self.mem_efficient_ln)
+            else:
+                return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
         else:
             output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
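
Because Apex does not yet expose a version number (see NVIDIA/apex#1648 referenced in the comment above), the commit feature-detects the new memory_efficient argument by inspecting the signature of FusedLayerNormAffineFunction.forward. Below is a standalone sketch of that detection pattern; old_forward and new_forward are made-up stand-ins for the pre- and post-NVIDIA/apex#1715 signatures, not real Apex code:

import inspect

def old_forward(ctx, input, weight, bias, normalized_shape, eps):
    """Stand-in for the forward signature before NVIDIA/apex#1715."""

def new_forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient):
    """Stand-in for the forward signature after NVIDIA/apex#1715."""

def supports_memory_efficient(forward_fn):
    # Same check as in MixedFusedLayerNorm.forward: look for the extra
    # 'memory_efficient' name among the function's positional arguments.
    return 'memory_efficient' in inspect.getfullargspec(forward_fn).args

print(supports_memory_efficient(old_forward))   # False
print(supports_memory_efficient(new_forward))   # True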

megatron/model/transformer.py

Lines changed: 8 additions & 4 deletions
@@ -919,7 +919,8 @@ def __init__(self, config,
                 eps=config.layernorm_epsilon,
                 no_persist_layer_norm=args.no_persist_layer_norm,
                 sequence_parallel=config.sequence_parallel,
-                apply_layernorm_1p=args.apply_layernorm_1p)
+                apply_layernorm_1p=args.apply_layernorm_1p,
+                mem_efficient_ln=not args.disable_mem_efficient_ln)
         else:
             self.input_layernorm = LayerNorm(
                 config.hidden_size,

@@ -944,7 +945,8 @@ def __init__(self, config,
                 eps=config.layernorm_epsilon,
                 no_persist_layer_norm=not config.persist_layer_norm,
                 sequence_parallel=config.sequence_parallel,
-                apply_layernorm_1p=args.apply_layernorm_1p)
+                apply_layernorm_1p=args.apply_layernorm_1p,
+                mem_efficient_ln=not args.disable_mem_efficient_ln)
         else:
             self.post_attention_layernorm = LayerNorm(
                 config.hidden_size,

@@ -967,7 +969,8 @@ def __init__(self, config,
                 eps=config.layernorm_epsilon,
                 no_persist_layer_norm=not config.persist_layer_norm,
                 sequence_parallel=config.sequence_parallel,
-                apply_layernorm_1p=args.apply_layernorm_1p)
+                apply_layernorm_1p=args.apply_layernorm_1p,
+                mem_efficient_ln=not args.disable_mem_efficient_ln)
         else:
             self.post_inter_attention_layernorm = MixedFusedRMSNorm(config.hidden_size, config.layernorm_epsilon)

@@ -1726,7 +1729,8 @@ def build_layer(layer_number, n_e):
                 eps=config.layernorm_epsilon,
                 no_persist_layer_norm=args.no_persist_layer_norm,
                 sequence_parallel=config.sequence_parallel,
-                apply_layernorm_1p=args.apply_layernorm_1p)
+                apply_layernorm_1p=args.apply_layernorm_1p,
+                mem_efficient_ln=not args.disable_mem_efficient_ln)
         else:
             self.final_layernorm = LayerNorm(
                 config.hidden_size,
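
All four LayerNorm construction sites now forward mem_efficient_ln=not args.disable_mem_efficient_ln. The sketch below shows that wiring in isolation; FakeLayerNorm and FakeArgs are stand-ins invented for this example so it runs without Megatron or Apex:

class FakeLayerNorm:
    # Stand-in for MixedFusedLayerNorm that only records the new keyword.
    def __init__(self, hidden_size, mem_efficient_ln=True, **kwargs):
        self.mem_efficient_ln = mem_efficient_ln

class FakeArgs:
    # argparse default for the store_false flag when it is omitted.
    disable_mem_efficient_ln = True

args = FakeArgs()

# Same expression used at the four construction sites in transformer.py.
layer = FakeLayerNorm(4096, mem_efficient_ln=not args.disable_mem_efficient_ln)
print(layer.mem_efficient_ln)   # False when the flag is omitted, since 'not' negates the store_false default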
