
Commit c129020

Fix issues with fake_llama
1 parent fce8c26 commit c129020

6 files changed (+170 lines, -74 lines)


src/tests/conftest.py

Lines changed: 22 additions & 40 deletions

@@ -3,15 +3,13 @@
 
 import pytest
 
-from transformers import AutoTokenizer
+from utils import maybe_tokenizer
 
-ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
+ACCESS_ERROR_MSG = "Could not access tokenizer. Did you log into huggingface hub and provided the correct token?"
 
-try:
-    AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-    LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
-except OSError:
-    LLAMA_VERSIONS = ["fake_llama"]
+LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct", "fake_llama"]
+
+LLAMA_TOKENIZERS = {k: maybe_tokenizer(k) for k in LLAMA_VERSIONS}
 
 
 @pytest.fixture(params=LLAMA_VERSIONS)
 def llama_version(request):
@@ -22,35 +20,10 @@ def llama_version(request):
 def model_type(request):
     return request.param
 
-class FakeTokenier(object):
-    def __init__(self):
-        self.pad_token_id = 0
-        self.bos_token_id = 1
-        self.eos_token_id = 2
-        self.sep_token_id = 3
-
-        self.pad_token = "<|pad_id|>"
-        self.bos_token = "<|bos_id|>"
-        self.eos_token = "<|eos_id|>"
-        self.sep_token = "<|sep_id|>"
-
-    def __call__(self, *args, **kwargs):
-        return self.encode(*args, **kwargs)
-
-    def encode(self, text, *args, **kwargs):
-        breakpoint()
-        return [len(c) for c in text.split(" ")]
-
-    def __len__(self):
-        return 128256
-
 
 @pytest.fixture(scope="module")
 def llama_tokenizer(request):
-    if LLAMA_VERSIONS == ["fake_llama"]:
-        return {"fake_llama": FakeTokenier()}
-    else:
-        return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
+    return LLAMA_TOKENIZERS
 
 
 @pytest.fixture
@@ -61,6 +34,13 @@ def _helper(tokenizer_mock):
 
     return _helper
 
+@pytest.fixture
+def setup_processor(llama_tokenizer, llama_version):
+    def _helper(processor_mock):
+        processor_mock.from_pretrained.return_value.tokenizer = llama_tokenizer[llama_version]
+
+    return _helper
+
 
 def pytest_addoption(parser):
     parser.addoption(
@@ -73,16 +53,18 @@ def pytest_configure(config):
 
 
 def pytest_collection_modifyitems(config, items):
+    #skip tests marked with skip_missing_tokenizer if tokenizer is unavailable unless --unskip-missing-tokenizer is passed
     if config.getoption("--unskip-missing-tokenizer"):
         return
 
-    try:
-        AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-        tokenizer_available = True
-    except OSError:
-        tokenizer_available = False
-
     skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
     for item in items:
-        if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:
+        # get the tokenizer for the test
+        version = [v for v in LLAMA_VERSIONS for i in item.keywords if v in i]
+        if len(version) == 0:
+            # no tokenizer used in this test
+            continue
+        version = version.pop()
+        assert version in LLAMA_TOKENIZERS
+        if "skip_missing_tokenizer" in item.keywords and LLAMA_TOKENIZERS[version] is None:
             item.add_marker(skip_missing_tokenizer)
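Note: the maybe_tokenizer helper imported at the top of conftest.py lives in the tests' utils module and is not shown in this commit. Below is a minimal sketch of what it plausibly does, assuming it returns a stub tokenizer (along the lines of the FakeTokenier class removed above) for "fake_llama" and None when the real tokenizer cannot be downloaded; the None case is what pytest_collection_modifyitems checks before adding the skip marker.

# Sketch only, not part of this commit; names and behavior are assumptions.
from transformers import AutoTokenizer


class FakeTokenizer:
    # Stand-in modeled on the FakeTokenier class this commit removes from conftest.py.
    pad_token_id, bos_token_id, eos_token_id, sep_token_id = 0, 1, 2, 3

    def __call__(self, text, *args, **kwargs):
        return self.encode(text, *args, **kwargs)

    def encode(self, text, *args, **kwargs):
        # Crude stand-in tokenization: one "token id" per whitespace-separated word.
        return [len(c) for c in text.split(" ")]

    def __len__(self):
        return 128256


def maybe_tokenizer(name):
    if name == "fake_llama":
        return FakeTokenizer()
    try:
        return AutoTokenizer.from_pretrained(name)
    except OSError:
        # Gated repo or missing Hugging Face token: the fixture value stays None
        # and tests marked skip_missing_tokenizer are skipped for this version.
        return None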

src/tests/datasets/test_custom_dataset.py

Lines changed: 6 additions & 2 deletions

@@ -2,6 +2,7 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import pytest
+from contextlib import nullcontext
 from unittest.mock import patch
 
 from transformers import LlamaTokenizer
@@ -133,13 +134,16 @@ def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version)
         {"role":"assistant", "content":"Romans"},
     ]
 
-    result = tokenize_dialog(dialog, tokenizer)
+    c = pytest.raises(AttributeError) if llama_version == "fake_llama" else nullcontext()
+
+    with c:
+        result = tokenize_dialog(dialog, tokenizer)
 
     if "Llama-2" in llama_version:
         assert result["labels"][:12] == [-100] * 12
         assert result["labels"][17:28] == [-100] * 11
         assert result["labels"].count(-100) == 11 + 12
-    else:
+    elif "Llama-3" in llama_version:
        assert result["labels"][:38] == [-100] * 38
        assert result["labels"][43:54] == [-100] * 11
        assert result["labels"].count(-100) == 38 + 11
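Note on the pattern introduced above: selecting the context manager up front (pytest.raises versus nullcontext) keeps a single call site for tokenize_dialog whether or not an exception is expected. A standalone illustration of the idiom, independent of the repo's code and using hypothetical values:

from contextlib import nullcontext

import pytest


@pytest.mark.parametrize("should_fail", [True, False])
def test_conditional_raise(should_fail):
    # Pick the expectation up front, then write the body once.
    ctx = pytest.raises(ValueError) if should_fail else nullcontext()
    with ctx:
        if should_fail:
            raise ValueError("expected failure")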

src/tests/datasets/test_samsum_datasets.py

Lines changed: 23 additions & 1 deletion

@@ -2,20 +2,42 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import pytest
+from dataclasses import dataclass
 from functools import partial
 from unittest.mock import patch
 
+@dataclass
+class Config:
+    model_type: str = "llama"
+
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
+def test_samsum_dataset(
+    step_lr,
+    optimizer,
+    get_model,
+    get_mmodel,
+    processor,
+    get_config,
+    tokenizer,
+    train,
+    mocker,
+    setup_tokenizer,
+    llama_version,
+):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config()
 
     BATCH_SIZE = 8
     kwargs = {

src/tests/test_batching.py

Lines changed: 66 additions & 29 deletions

@@ -3,6 +3,7 @@
 
 import pytest
 from dataclasses import dataclass
+from contextlib import nullcontext
 from unittest.mock import patch
 
 @dataclass
@@ -19,32 +20,39 @@ class Config:
         "eval": 34,
     },
     "fake_llama": {
-        "train": 48,
-        "eval": 34,
+        "train": 50,
+        "eval": 21,
     }
 }
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
 @patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_packing(
     step_lr,
     optimizer,
     get_model,
+    get_mmodel,
+    processor,
     get_config,
     tokenizer,
     train,
     setup_tokenizer,
+    setup_processor,
     llama_version,
     model_type,
 ):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
+    setup_processor(processor)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
 
     kwargs = {
@@ -56,48 +64,73 @@ def test_packing(
         "batching_strategy": "packing",
     }
 
-    main(**kwargs)
+    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
 
-    assert train.call_count == 1
+    with c:
+        main(**kwargs)
+
+    if model_type == "llama":
+        assert train.call_count == 1
 
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+        args, kwargs = train.call_args
+        train_dataloader = args[1]
+        eval_dataloader = args[2]
 
-    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
-    # assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
-    # print(f"{len(eval_dataloader)=}")
+        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
+        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
 
-    # batch = next(iter(train_dataloader))
+        batch = next(iter(train_dataloader))
 
-    # assert "labels" in batch.keys()
-    # assert "input_ids" in batch.keys()
-    # assert "attention_mask" in batch.keys()
+        assert "labels" in batch.keys()
+        assert "input_ids" in batch.keys()
+        assert "attention_mask" in batch.keys()
 
-    # # assert batch["labels"][0].size(0) == 4096
-    # # assert batch["input_ids"][0].size(0) == 4096
-    # # assert batch["attention_mask"][0].size(0) == 4096
-    # print(batch["labels"][0].size(0))
-    # print(batch["input_ids"][0].size(0))
-    # print(batch["attention_mask"][0].size(0))
-
+        assert batch["labels"][0].size(0) == 4096
+        assert batch["input_ids"][0].size(0) == 4096
+        assert batch["attention_mask"][0].size(0) == 4096
 
 
+@patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.setup')
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
-def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
+def test_distributed_packing(
+    dist,
+    is_initialized,
+    fsdp,
+    setup,
+    step_lr,
+    optimizer,
+    get_model,
+    get_mmodel,
+    processor,
+    get_config,
+    tokenizer,
+    train,
+    cuda_is_available,
+    setup_tokenizer,
+    setup_processor,
+    llama_version,
+    model_type,
+):
     import os
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
+    setup_processor(processor)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
+    cuda_is_available.return_value = False
 
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
@@ -120,13 +153,17 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     dist.get_rank.return_value = rank
     dist.get_world_size.return_value = 2
 
-    main(**kwargs)
+    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
+
+    with c:
+        main(**kwargs)
 
-    assert train.call_count == 1
+    if model_type == "llama":
+        assert train.call_count == 1
 
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+        args, kwargs = train.call_args
+        train_dataloader = args[1]
+        eval_dataloader = args[2]
 
-    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
-    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2
+        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
+        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2

src/tests/test_chat_completion.py

Lines changed: 2 additions & 2 deletions

@@ -52,8 +52,8 @@ def _format_tokens_llama3(dialogs, tokenizer):
 def test_chat_completion(
     load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
 ):
-    if "Llama-2" in llama_version:
-        pytest.skip("skipping test for Llama-2")
+    if "Llama-2" in llama_version or llama_version == "fake_llama":
+        pytest.skip(f"skipping test for {llama_version}")
 
     from chat_completion import main
 