@@ -920,7 +920,7 @@ def __init__(self, config,
920920 no_persist_layer_norm = args .no_persist_layer_norm ,
921921 sequence_parallel = config .sequence_parallel ,
922922 apply_layernorm_1p = args .apply_layernorm_1p ,
923- mem_efficient_ln = not args .disable_mem_efficient_ln )
923+ mem_efficient_ln = args .mem_efficient_ln )
924924 else :
925925 self .input_layernorm = LayerNorm (
926926 config .hidden_size ,
@@ -946,7 +946,7 @@ def __init__(self, config,
946946 no_persist_layer_norm = not config .persist_layer_norm ,
947947 sequence_parallel = config .sequence_parallel ,
948948 apply_layernorm_1p = args .apply_layernorm_1p ,
949- mem_efficient_ln = not args .disable_mem_efficient_ln )
949+ mem_efficient_ln = args .mem_efficient_ln )
950950 else :
951951 self .post_attention_layernorm = LayerNorm (
952952 config .hidden_size ,
@@ -970,7 +970,7 @@ def __init__(self, config,
970970 no_persist_layer_norm = not config .persist_layer_norm ,
971971 sequence_parallel = config .sequence_parallel ,
972972 apply_layernorm_1p = args .apply_layernorm_1p ,
973- mem_efficient_ln = not args .disable_mem_efficient_ln )
973+ mem_efficient_ln = args .mem_efficient_ln )
974974 else :
975975 self .post_inter_attention_layernorm = MixedFusedRMSNorm (config .hidden_size , config .layernorm_epsilon )
976976
@@ -1730,7 +1730,7 @@ def build_layer(layer_number, n_e):
17301730 no_persist_layer_norm = args .no_persist_layer_norm ,
17311731 sequence_parallel = config .sequence_parallel ,
17321732 apply_layernorm_1p = args .apply_layernorm_1p ,
1733- mem_efficient_ln = not args .disable_mem_efficient_ln )
1733+ mem_efficient_ln = args .mem_efficient_ln )
17341734 else :
17351735 self .final_layernorm = LayerNorm (
17361736 config .hidden_size ,
0 commit comments