@@ -207,11 +207,11 @@ def test_against_original_open_llama_3b(device, dtype):
 @pytest.mark.parametrize(
     "ours_kwargs",
     [
-        {"name": "Llama-2-7b-hf"},
-        {"name": "CodeLlama-7b-hf"},
-        {"name": "Llama-2-70b-chat-hf", "n_query_groups": 1},
-        {"name": "Llama-3-8B"},
-        {"name": "Llama-3-8B-Instruct"}
+        {"name": "Llama-2-7b-hf"},
+        {"name": "CodeLlama-7b-hf"},
+        {"name": "Llama-2-70b-chat-hf", "n_query_groups": 1},
+        {"name": "Llama-3-8B"},
+        {"name": "Llama-3-8B-Instruct"},
     ],
 )
 @pytest.mark.parametrize(
@@ -267,6 +267,7 @@ def test_against_hf_llama_2_and_3(ours_kwargs, device, dtype):
 
 
 @torch.inference_mode()
+@pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2"))
 @pytest.mark.parametrize(
     ("device", "dtype"),
     [
@@ -278,86 +279,14 @@ def test_against_hf_llama_2_and_3(ours_kwargs, device, dtype):
         ),
     ],
 )
-def test_against_hf_phi_1_5(device, dtype):
-    wd = Path(__file__).parent.parent.resolve()
-    workdir = wd / "tests" / "reference_models"
-    workdir.mkdir(parents=True, exist_ok=True)
-    file_paths = [workdir / "original_phi_1_5.py", workdir / "configuration_phi.py"]
-    urls = [
-        "https://huggingface.co/microsoft/phi-1_5/raw/main/modeling_phi.py",
-        "https://huggingface.co/microsoft/phi-1_5/raw/main/configuration_phi.py",
-    ]
-    for file_path, url in zip(file_paths, urls):
-        if not file_path.is_file():
-            urlretrieve(url=url, filename=file_path)
-
-    from reference_models.configuration_phi import PhiConfig
-    from reference_models.original_phi_1_5 import PhiForCausalLM
+def test_against_hf_phi(model_name, device, dtype):
+    from transformers.models.phi.configuration_phi import PhiConfig
+    from transformers.models.phi.modeling_phi import PhiForCausalLM
 
     torch.set_default_dtype(dtype)
 
     ours_config = Config.from_name(
-        "phi-1_5", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
-    )
-    T = 5
-    theirs_config = PhiConfig(
-        vocab_size=ours_config.padded_vocab_size,
-        max_position_embeddings=ours_config.block_size,
-        hidden_size=ours_config.n_embd,
-        intermediate_size=ours_config.intermediate_size,
-        num_attention_heads=ours_config.n_head,
-        num_hidden_layers=ours_config.n_layer,
-        partial_rotary_factor=ours_config.rotary_percentage,
-        torch_dtype=dtype,
-    )
-
-    theirs_model = PhiForCausalLM(theirs_config).to(device)
-    theirs_state_dict = theirs_model.state_dict()
-    state_dict = {}
-    copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict)
-    ours_model = GPT(ours_config).to(device)
-    ours_model.load_state_dict(state_dict)
-
-    # test end to end
-    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
-    assert x.size(1) == T
-    ours_y = ours_model(x)
-    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
-    torch.testing.assert_close(ours_y, theirs_y)
-
-
-@torch.inference_mode()
-@pytest.mark.parametrize(
-    ("device", "dtype"),
-    [
-        (torch.device("cpu"), torch.float32),
-        pytest.param(
-            torch.device("cuda"),
-            torch.float16,
-            marks=[pytest.mark.xfail(raises=AssertionError, strict=False), RunIf(min_cuda_gpus=1)],
-        ),
-    ],
-)
-def test_against_hf_phi_2(device, dtype):
-    wd = Path(__file__).parent.parent.resolve()
-    workdir = wd / "tests" / "reference_models"
-    workdir.mkdir(parents=True, exist_ok=True)
-    file_paths = [workdir / "original_phi_2.py", workdir / "configuration_phi.py"]
-    urls = [
-        "https://huggingface.co/microsoft/phi-2/raw/main/modeling_phi.py",
-        "https://huggingface.co/microsoft/phi-2/raw/main/configuration_phi.py",
-    ]
-    for file_path, url in zip(file_paths, urls):
-        if not file_path.is_file():
-            urlretrieve(url=url, filename=file_path)
-
-    from reference_models.configuration_phi import PhiConfig
-    from reference_models.original_phi_2 import PhiForCausalLM
-
-    torch.set_default_dtype(dtype)
-
-    ours_config = Config.from_name(
-        "phi-2", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
+        model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
     )
     T = 5
     theirs_config = PhiConfig(
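
The consolidation works because stacked @pytest.mark.parametrize decorators multiply: pytest collects one test case per combination of model_name and (device, dtype), so a single test body now covers what the two removed near-duplicate tests did. A minimal standalone sketch of that mechanism (the test and parameter names here are illustrative, not from this PR):

    import pytest

    @pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2"))
    @pytest.mark.parametrize("precision", ("float32", "float16"))
    def test_cartesian_product(model_name, precision):
        # pytest collects 2 x 2 = 4 cases, one per combination, so one
        # test body exercises both Phi checkpoints at both precisions.
        assert model_name in ("phi-1_5", "phi-2")
        assert precision in ("float32", "float16")

Switching the imports to transformers.models.phi also drops the need to download modeling_phi.py/configuration_phi.py reference files from the Hub at test time, since transformers ships its own PhiForCausalLM that both checkpoints can be compared against.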