
Commit a8a684d

Fix layernorm arg parsing (bigscience-workshop#281)
Parent: ef13d09

File tree

2 files changed: 5 additions & 5 deletions

megatron/arguments.py

Lines changed: 1 addition & 1 deletion
@@ -628,7 +628,7 @@ def _add_network_size_args(parser):
                        'around zero. This improves numerical stability.')
     group.add_argument('--disable-mem-efficient-ln', action='store_false',
                        help='Disable the memory-efficient fused LayerNorm optimization '
-                       'introduced in https://github.com/NVIDIA/apex/pull/1715')
+                       'introduced in https://github.com/NVIDIA/apex/pull/1715', dest='mem_efficient_ln')
     group.add_argument('--apply-residual-connection-post-layernorm',
                        action='store_true',
                        help='If set, use original BERT residula connection '
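Why the added dest= matters: with action='store_false', argparse derives the attribute name from the flag (disable_mem_efficient_ln) and defaults it to True, so the old call sites' `not args.disable_mem_efficient_ln` evaluated to False by default and to True when the flag was passed, the opposite of the intent. A minimal sketch of that argparse behaviour, using illustrative flag names rather than the real Megatron options:

import argparse

# Minimal sketch of the argparse behaviour behind this fix; the flag names here
# are illustrative, not the real Megatron options.
parser = argparse.ArgumentParser()

# Old pattern: no dest=, so the value lands in args.disable_feature (default True).
parser.add_argument('--disable-feature', action='store_false')

# Fixed pattern: dest= stores the value under the name the call sites read;
# it still defaults to True and flips to False when the flag is passed.
parser.add_argument('--disable-other', action='store_false', dest='other_enabled')

args = parser.parse_args([])          # nothing passed on the command line
print(args.disable_feature)           # True
print(not args.disable_feature)       # False -> old call sites disabled the feature by default
print(args.other_enabled)             # True  -> fixed call sites keep it enabled by default

args = parser.parse_args(['--disable-feature', '--disable-other'])
print(not args.disable_feature)       # True  -> the old flag actually *enabled* the feature
print(args.other_enabled)             # False -> the fixed flag disables it, as documented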

megatron/model/transformer.py

Lines changed: 4 additions & 4 deletions
@@ -920,7 +920,7 @@ def __init__(self, config,
                 no_persist_layer_norm=args.no_persist_layer_norm,
                 sequence_parallel=config.sequence_parallel,
                 apply_layernorm_1p=args.apply_layernorm_1p,
-                mem_efficient_ln=not args.disable_mem_efficient_ln)
+                mem_efficient_ln=args.mem_efficient_ln)
         else:
             self.input_layernorm = LayerNorm(
                 config.hidden_size,
@@ -946,7 +946,7 @@ def __init__(self, config,
                 no_persist_layer_norm=not config.persist_layer_norm,
                 sequence_parallel=config.sequence_parallel,
                 apply_layernorm_1p=args.apply_layernorm_1p,
-                mem_efficient_ln=not args.disable_mem_efficient_ln)
+                mem_efficient_ln=args.mem_efficient_ln)
         else:
             self.post_attention_layernorm = LayerNorm(
                 config.hidden_size,
@@ -970,7 +970,7 @@ def __init__(self, config,
                 no_persist_layer_norm=not config.persist_layer_norm,
                 sequence_parallel=config.sequence_parallel,
                 apply_layernorm_1p=args.apply_layernorm_1p,
-                mem_efficient_ln=not args.disable_mem_efficient_ln)
+                mem_efficient_ln=args.mem_efficient_ln)
         else:
             self.post_inter_attention_layernorm = MixedFusedRMSNorm(config.hidden_size, config.layernorm_epsilon)

@@ -1730,7 +1730,7 @@ def build_layer(layer_number, n_e):
                 no_persist_layer_norm=args.no_persist_layer_norm,
                 sequence_parallel=config.sequence_parallel,
                 apply_layernorm_1p=args.apply_layernorm_1p,
-                mem_efficient_ln=not args.disable_mem_efficient_ln)
+                mem_efficient_ln=args.mem_efficient_ln)
         else:
             self.final_layernorm = LayerNorm(
                 config.hidden_size,
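All four transformer.py call sites make the same change: they now read args.mem_efficient_ln directly instead of negating args.disable_mem_efficient_ln. A condensed end-to-end sketch of the before/after wiring, where FakeLayerNorm is a stand-in for the real fused LayerNorm used in megatron/model/transformer.py:

import argparse

# Condensed sketch of the call-site wiring; FakeLayerNorm stands in for the fused
# LayerNorm used in megatron/model/transformer.py (this is not the real class).
class FakeLayerNorm:
    def __init__(self, hidden_size, mem_efficient_ln=True):
        self.mem_efficient_ln = mem_efficient_ln

def build_args(fixed, argv):
    parser = argparse.ArgumentParser()
    if fixed:   # after this commit: value stored as args.mem_efficient_ln
        parser.add_argument('--disable-mem-efficient-ln', action='store_false',
                            dest='mem_efficient_ln')
    else:       # before this commit: value stored as args.disable_mem_efficient_ln
        parser.add_argument('--disable-mem-efficient-ln', action='store_false')
    return parser.parse_args(argv)

for fixed in (False, True):
    for argv in ([], ['--disable-mem-efficient-ln']):
        args = build_args(fixed, argv)
        # Old call sites negated the flag; fixed call sites pass it through.
        value = args.mem_efficient_ln if fixed else not args.disable_mem_efficient_ln
        ln = FakeLayerNorm(1024, mem_efficient_ln=value)
        print(fixed, argv, ln.mem_efficient_ln)

# Before the fix: mem_efficient_ln comes out False with no flags and True when the
# "disable" flag is passed (inverted). After the fix: True by default, and False
# only when --disable-mem-efficient-ln is given.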
