 
 import deepspeed
 import torch
+from torch import nn
+import torch.nn.functional as F
+
+from megatron.model.fused_layer_norm import MixedFusedLayerNorm
+from packaging import version
 
 from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
-from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments
+from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, torch_assert_equal
 from megatron.training import setup_model_and_optimizer
 from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
 from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
@@ -51,9 +56,6 @@ def get_default_args():
     }
 
 
-
-
-
 def equal_vectors(tensor1, tensor2, dim=-1):
     """View tensor1 and tensor2 as a list of vectors, and compute equality"""
     return torch.linalg.norm(tensor1 - tensor2, dim=dim) == 0
@@ -109,9 +111,7 @@ def test_gpt(self):
                 output_changed = model(input_token_ids_changed, *input_batch[1:])
 
                 # All tokens in the past should be unchanged
-                self.assertTrue(
-                    torch.all(equal_vectors(output[:, :changed_index], output_changed[:, :changed_index]))
-                )
+                torch_assert_equal(output[:, :changed_index], output_changed[:, :changed_index])
                 # All tokens in the future should have changed
                 self.assertFalse(
                     torch.any(equal_vectors(output[:, changed_index:], output_changed[:, changed_index:]))
@@ -173,23 +173,15 @@ def test_prefix_lm_reset_attention_mask(self):
                 output_changed_target = model(token_ids_changed_target, *input_batch[1:])
 
                 # All tokens in the past should be unchanged
-                self.assertTrue(
-                    torch.all(
-                        equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
-                    )
-                )
+                torch_assert_equal(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
                 # All tokens in the future should have changed
                 self.assertFalse(
                     torch.any(
                         equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
                     )
                 )
                 # Unchanged rows should not change either
-                self.assertTrue(
-                    torch.all(
-                        equal_vectors(output[1, :], output_changed_target[1, :])
-                    )
-                )
+                torch_assert_equal(output[1, :], output_changed_target[1, :])
 
                 ## --------------- CHANGE AN INPUT TOKEN ---------------------------
                 # Let's change the last prefix token and make sure that the first token changed
@@ -212,11 +204,7 @@ def test_prefix_lm_reset_attention_mask(self):
                     )
                 )
                 # Unchanged rows should not change either
-                self.assertTrue(
-                    torch.all(
-                        equal_vectors(output[1, :], output_changed_input[1, :])
-                    )
-                )
+                torch_assert_equal(output[1, :], output_changed_input[1, :])
 
     def test_prefix_lm_wo_reset_attention_mask(self):
         """
@@ -282,6 +270,43 @@ def test_gpt_rotary_embeddings(self):
 
         #TODO: Check all invariants
 
+    def test_fused_layer_norm(self):
+        command_args = get_default_args()
+
+        # Run in bf16 (and without fp16) so the custom cuda kernel path is used
+        command_args["--bf16"] = ""
+        del command_args["--fp16"]
+
+        with patch('sys.argv', flatten_arguments(command_args)):
+            with mockenv_context(**self.dist_env_1_gpu):
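+                # initialize_megatron() parses the patched sys.argv above and populates the
+                # global args object that get_args() returns.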
+                initialize_megatron()
+                args = get_args()
+
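+                # Input shaped [micro_batch_size, seq_length, hidden_size]; LayerNorm
+                # normalizes over the trailing hidden dimension.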
+                dummy_input = torch.randn(args.micro_batch_size, args.seq_length, args.hidden_size, device="cuda", dtype=torch.bfloat16)
+
+                normalized_shape = (args.hidden_size,)
+                epsilon = 1e-5
+                mfln = MixedFusedLayerNorm(normalized_shape, eps=epsilon)
+
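+                # Sanity checks: the fused CUDA kernel (rather than the torch fallback) must
+                # have been selected, and we must actually be running in bf16.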
+                self.assertTrue(mfln.use_meg_ds_fused_layer_norm, "Expected model to use Megatron-DeepSpeed custom cuda kernel for LayerNorm.")
+                self.assertTrue(args.bf16, "Test has to be done in half precision.")
+
+                # Set the weight and bias manually to simulate a state other than the initialisation
+                weight = torch.randn(args.hidden_size, device="cuda", dtype=torch.bfloat16)
+                bias = torch.randn(args.hidden_size, device="cuda", dtype=torch.bfloat16)
+                mfln.weight = nn.Parameter(weight)
+                mfln.bias = nn.Parameter(bias)
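+                # (The default init, weight=1 and bias=0, could mask scale/shift bugs.)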
+
+                mfln_output = mfln(dummy_input)
+                # Check that our layernorm matches pytorch's from 1.11 onwards
+                if version.parse(torch.__version__) >= version.parse("1.11.0"):
+                    torch_layer_norm_output = F.layer_norm(dummy_input, normalized_shape, weight, bias, eps=epsilon)
+                else:
+                    # On older torch versions, check against the fp32 computation downcast to bf16 instead
+                    torch_layer_norm_output = F.layer_norm(dummy_input.float(), normalized_shape, weight.float(), bias.float(), eps=epsilon).to(torch.bfloat16)
+
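+                # torch_assert_equal expects exact equality, so the fused kernel has to match
+                # torch's reference bit-for-bit under these settings.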
+                torch_assert_equal(mfln_output, torch_layer_norm_output)
+
 
 if __name__ == '__main__':
     unittest.main()