@@ -919,7 +919,8 @@ def __init__(self, config,
                     eps=config.layernorm_epsilon,
                     no_persist_layer_norm=args.no_persist_layer_norm,
                     sequence_parallel=config.sequence_parallel,
-                    apply_layernorm_1p=args.apply_layernorm_1p)
+                    apply_layernorm_1p=args.apply_layernorm_1p,
+                    mem_efficient_ln=not args.disable_mem_efficient_ln)
             else:
                 self.input_layernorm = LayerNorm(
                     config.hidden_size,
@@ -944,7 +945,8 @@ def __init__(self, config,
                     eps=config.layernorm_epsilon,
                     no_persist_layer_norm=not config.persist_layer_norm,
                     sequence_parallel=config.sequence_parallel,
-                    apply_layernorm_1p=args.apply_layernorm_1p)
+                    apply_layernorm_1p=args.apply_layernorm_1p,
+                    mem_efficient_ln=not args.disable_mem_efficient_ln)
             else:
                 self.post_attention_layernorm = LayerNorm(
                     config.hidden_size,
@@ -967,7 +969,8 @@ def __init__(self, config,
                     eps=config.layernorm_epsilon,
                     no_persist_layer_norm=not config.persist_layer_norm,
                     sequence_parallel=config.sequence_parallel,
-                    apply_layernorm_1p=args.apply_layernorm_1p)
+                    apply_layernorm_1p=args.apply_layernorm_1p,
+                    mem_efficient_ln=not args.disable_mem_efficient_ln)
             else:
                 self.post_inter_attention_layernorm = MixedFusedRMSNorm(config.hidden_size, config.layernorm_epsilon)
 
@@ -1726,7 +1729,8 @@ def build_layer(layer_number, n_e):
                     eps=config.layernorm_epsilon,
                     no_persist_layer_norm=args.no_persist_layer_norm,
                     sequence_parallel=config.sequence_parallel,
-                    apply_layernorm_1p=args.apply_layernorm_1p)
+                    apply_layernorm_1p=args.apply_layernorm_1p,
+                    mem_efficient_ln=not args.disable_mem_efficient_ln)
             else:
                 self.final_layernorm = LayerNorm(
                     config.hidden_size,
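
The hunks above all make the same change: each fused LayerNorm constructor gains a `mem_efficient_ln` keyword that is on by default and switched off via `not args.disable_mem_efficient_ln`. Below is a minimal, self-contained sketch of how that flag and keyword could be plumbed together; the argparse helper and the `LayerNorm` wrapper internals shown here are assumptions for illustration, not the repository's actual fused implementation.

```python
# Sketch only: threads a --disable-mem-efficient-ln CLI flag (name taken from
# the diff) into a LayerNorm wrapper. The wrapper body is hypothetical; a real
# fused implementation would use mem_efficient_ln to pick a kernel that
# recomputes statistics in backward instead of caching extra activations.
import argparse
import torch


def add_mem_efficient_ln_args(parser):
    group = parser.add_argument_group(title='layernorm')
    # Memory-efficient LayerNorm is enabled by default; the diff passes
    # `not args.disable_mem_efficient_ln` into every LayerNorm constructor.
    group.add_argument('--disable-mem-efficient-ln', action='store_true',
                       help='Disable the memory-efficient fused LayerNorm path.')
    return parser


class LayerNorm(torch.nn.LayerNorm):
    """Hypothetical wrapper that records the constructor options from the diff."""

    def __init__(self, hidden_size, eps=1e-5, no_persist_layer_norm=True,
                 sequence_parallel=False, apply_layernorm_1p=False,
                 mem_efficient_ln=True):
        super().__init__(hidden_size, eps=eps)
        self.mem_efficient_ln = mem_efficient_ln
        self.apply_layernorm_1p = apply_layernorm_1p
        self.sequence_parallel = sequence_parallel


if __name__ == '__main__':
    args = add_mem_efficient_ln_args(argparse.ArgumentParser()).parse_args([])
    ln = LayerNorm(1024, eps=1e-5,
                   apply_layernorm_1p=False,
                   mem_efficient_ln=not args.disable_mem_efficient_ln)
    print(ln.mem_efficient_ln)  # True unless --disable-mem-efficient-ln is passed
```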