Skip to content

Commit 5d96660

Browse files
Phi (tests): create a class directly from HF (#1343)
1 parent a91b520 commit 5d96660

File tree

3 files changed

+15
-148
lines changed

3 files changed

+15
-148
lines changed

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,5 +16,3 @@ checkpoints
1616
out
1717
wandb
1818
events.out.tfevents*
19-
20-
tests/reference_models

tests/test_convert_lit_checkpoint.py

Lines changed: 5 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -222,73 +222,13 @@ def test_against_original_open_llama_3b():
222222

223223

224224
@torch.inference_mode()
225-
def test_against_hf_phi_1_5():
226-
wd = Path(__file__).parent.parent.absolute()
227-
workdir = wd / "tests" / "reference_models"
228-
workdir.mkdir(parents=True, exist_ok=True)
229-
file_paths = [workdir / "original_phi_1_5.py", workdir / "configuration_phi.py"]
230-
urls = [
231-
"https://huggingface.co/microsoft/phi-1_5/raw/main/modeling_phi.py",
232-
"https://huggingface.co/microsoft/phi-1_5/raw/main/configuration_phi.py",
233-
]
234-
for file_path, url in zip(file_paths, urls):
235-
if not file_path.is_file():
236-
urlretrieve(url=url, filename=file_path)
237-
238-
from reference_models.configuration_phi import PhiConfig
239-
from reference_models.original_phi_1_5 import PhiForCausalLM
225+
@pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2"))
226+
def test_against_hf_phi(model_name):
227+
from transformers.models.phi.configuration_phi import PhiConfig
228+
from transformers.models.phi.modeling_phi import PhiForCausalLM
240229

241230
ours_config = Config.from_name(
242-
"phi-1_5", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
243-
)
244-
T = 5
245-
theirs_config = PhiConfig(
246-
vocab_size=ours_config.padded_vocab_size,
247-
max_position_embeddings=ours_config.block_size,
248-
hidden_size=ours_config.n_embd,
249-
intermediate_size=ours_config.intermediate_size,
250-
num_attention_heads=ours_config.n_head,
251-
num_hidden_layers=ours_config.n_layer,
252-
partial_rotary_factor=ours_config.rotary_percentage,
253-
)
254-
255-
ours_model = GPT(ours_config)
256-
ours_state_dict = ours_model.state_dict()
257-
theirs_state_dict = {}
258-
copy_weights_phi(ours_config, theirs_state_dict, ours_state_dict)
259-
theirs_model = PhiForCausalLM(theirs_config)
260-
# strict=False because we don't save the rotary embeddings inv frequency
261-
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
262-
assert not keys.unexpected_keys
263-
assert all("inv_freq" in k for k in keys.missing_keys)
264-
265-
# test end to end
266-
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
267-
assert x.size(1) == T
268-
ours_y = ours_model(x)
269-
theirs_y = theirs_model(x)["logits"]
270-
torch.testing.assert_close(ours_y, theirs_y)
271-
272-
273-
@torch.inference_mode()
274-
def test_against_hf_phi_2():
275-
wd = Path(__file__).parent.parent.absolute()
276-
workdir = wd / "tests" / "reference_models"
277-
workdir.mkdir(parents=True, exist_ok=True)
278-
file_paths = [workdir / "original_phi_2.py", workdir / "configuration_phi.py"]
279-
urls = [
280-
"https://huggingface.co/microsoft/phi-2/raw/main/modeling_phi.py",
281-
"https://huggingface.co/microsoft/phi-2/raw/main/configuration_phi.py",
282-
]
283-
for file_path, url in zip(file_paths, urls):
284-
if not file_path.is_file():
285-
urlretrieve(url=url, filename=file_path)
286-
287-
from reference_models.configuration_phi import PhiConfig
288-
from reference_models.original_phi_2 import PhiForCausalLM
289-
290-
ours_config = Config.from_name(
291-
"phi-2", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
231+
model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
292232
)
293233
T = 5
294234
theirs_config = PhiConfig(

tests/test_model.py

Lines changed: 10 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -207,11 +207,11 @@ def test_against_original_open_llama_3b(device, dtype):
207207
@pytest.mark.parametrize(
208208
"ours_kwargs",
209209
[
210-
{"name": "Llama-2-7b-hf"},
211-
{"name": "CodeLlama-7b-hf"},
212-
{"name": "Llama-2-70b-chat-hf", "n_query_groups": 1},
213-
{"name": "Llama-3-8B"},
214-
{"name": "Llama-3-8B-Instruct"}
210+
{"name": "Llama-2-7b-hf"},
211+
{"name": "CodeLlama-7b-hf"},
212+
{"name": "Llama-2-70b-chat-hf", "n_query_groups": 1},
213+
{"name": "Llama-3-8B"},
214+
{"name": "Llama-3-8B-Instruct"},
215215
],
216216
)
217217
@pytest.mark.parametrize(
@@ -267,6 +267,7 @@ def test_against_hf_llama_2_and_3(ours_kwargs, device, dtype):
267267

268268

269269
@torch.inference_mode()
270+
@pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2"))
270271
@pytest.mark.parametrize(
271272
("device", "dtype"),
272273
[
@@ -278,86 +279,14 @@ def test_against_hf_llama_2_and_3(ours_kwargs, device, dtype):
278279
),
279280
],
280281
)
281-
def test_against_hf_phi_1_5(device, dtype):
282-
wd = Path(__file__).parent.parent.resolve()
283-
workdir = wd / "tests" / "reference_models"
284-
workdir.mkdir(parents=True, exist_ok=True)
285-
file_paths = [workdir / "original_phi_1_5.py", workdir / "configuration_phi.py"]
286-
urls = [
287-
"https://huggingface.co/microsoft/phi-1_5/raw/main/modeling_phi.py",
288-
"https://huggingface.co/microsoft/phi-1_5/raw/main/configuration_phi.py",
289-
]
290-
for file_path, url in zip(file_paths, urls):
291-
if not file_path.is_file():
292-
urlretrieve(url=url, filename=file_path)
293-
294-
from reference_models.configuration_phi import PhiConfig
295-
from reference_models.original_phi_1_5 import PhiForCausalLM
282+
def test_against_hf_phi(model_name, device, dtype):
283+
from transformers.models.phi.configuration_phi import PhiConfig
284+
from transformers.models.phi.modeling_phi import PhiForCausalLM
296285

297286
torch.set_default_dtype(dtype)
298287

299288
ours_config = Config.from_name(
300-
"phi-1_5", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
301-
)
302-
T = 5
303-
theirs_config = PhiConfig(
304-
vocab_size=ours_config.padded_vocab_size,
305-
max_position_embeddings=ours_config.block_size,
306-
hidden_size=ours_config.n_embd,
307-
intermediate_size=ours_config.intermediate_size,
308-
num_attention_heads=ours_config.n_head,
309-
num_hidden_layers=ours_config.n_layer,
310-
partial_rotary_factor=ours_config.rotary_percentage,
311-
torch_dtype=dtype,
312-
)
313-
314-
theirs_model = PhiForCausalLM(theirs_config).to(device)
315-
theirs_state_dict = theirs_model.state_dict()
316-
state_dict = {}
317-
copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict)
318-
ours_model = GPT(ours_config).to(device)
319-
ours_model.load_state_dict(state_dict)
320-
321-
# test end to end
322-
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
323-
assert x.size(1) == T
324-
ours_y = ours_model(x)
325-
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
326-
torch.testing.assert_close(ours_y, theirs_y)
327-
328-
329-
@torch.inference_mode()
330-
@pytest.mark.parametrize(
331-
("device", "dtype"),
332-
[
333-
(torch.device("cpu"), torch.float32),
334-
pytest.param(
335-
torch.device("cuda"),
336-
torch.float16,
337-
marks=[pytest.mark.xfail(raises=AssertionError, strict=False), RunIf(min_cuda_gpus=1)],
338-
),
339-
],
340-
)
341-
def test_against_hf_phi_2(device, dtype):
342-
wd = Path(__file__).parent.parent.resolve()
343-
workdir = wd / "tests" / "reference_models"
344-
workdir.mkdir(parents=True, exist_ok=True)
345-
file_paths = [workdir / "original_phi_2.py", workdir / "configuration_phi.py"]
346-
urls = [
347-
"https://huggingface.co/microsoft/phi-2/raw/main/modeling_phi.py",
348-
"https://huggingface.co/microsoft/phi-2/raw/main/configuration_phi.py",
349-
]
350-
for file_path, url in zip(file_paths, urls):
351-
if not file_path.is_file():
352-
urlretrieve(url=url, filename=file_path)
353-
354-
from reference_models.configuration_phi import PhiConfig
355-
from reference_models.original_phi_2 import PhiForCausalLM
356-
357-
torch.set_default_dtype(dtype)
358-
359-
ours_config = Config.from_name(
360-
"phi-2", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
289+
model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
361290
)
362291
T = 5
363292
theirs_config = PhiConfig(

0 commit comments

Comments (0)