Skip to content

Commit fce8c26

Browse files
committed
[WIP]add fake tokenizer
1 parent fac71cd commit fce8c26

File tree

2 files changed

+67
-13
lines changed

2 files changed

+67
-13
lines changed

src/tests/conftest.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
from transformers import AutoTokenizer
77

88
ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
9-
LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
9+
10+
try:
11+
AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
12+
LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
13+
except OSError:
14+
LLAMA_VERSIONS = ["fake_llama"]
1015

1116
@pytest.fixture(params=LLAMA_VERSIONS)
1217
def llama_version(request):
@@ -17,10 +22,35 @@ def llama_version(request):
1722
def model_type(request):
1823
return request.param
1924

25+
class FakeTokenier(object):
26+
def __init__(self):
27+
self.pad_token_id = 0
28+
self.bos_token_id = 1
29+
self.eos_token_id = 2
30+
self.sep_token_id = 3
31+
32+
self.pad_token = "<|pad_id|>"
33+
self.bos_token = "<|bos_id|>"
34+
self.eos_token = "<|eos_id|>"
35+
self.sep_token = "<|sep_id|>"
36+
37+
def __call__(self, *args, **kwargs):
38+
return self.encode(*args, **kwargs)
39+
40+
def encode(self, text, *args, **kwargs):
41+
breakpoint()
42+
return [len(c) for c in text.split(" ")]
43+
44+
def __len__(self):
45+
return 128256
46+
2047

2148
@pytest.fixture(scope="module")
2249
def llama_tokenizer(request):
23-
return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
50+
if LLAMA_VERSIONS == ["fake_llama"]:
51+
return {"fake_llama": FakeTokenier()}
52+
else:
53+
return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
2454

2555

2656
@pytest.fixture

src/tests/test_batching.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@
22
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
33

44
import pytest
5+
from dataclasses import dataclass
56
from unittest.mock import patch
67

8+
@dataclass
9+
class Config:
10+
model_type: str = "llama"
11+
712
EXPECTED_SAMPLE_NUMBER ={
813
"meta-llama/Llama-2-7b-hf": {
914
"train": 96,
@@ -12,20 +17,35 @@
1217
"meta-llama/Meta-Llama-3.1-8B-Instruct": {
1318
"train": 79,
1419
"eval": 34,
20+
},
21+
"fake_llama": {
22+
"train": 48,
23+
"eval": 34,
1524
}
1625
}
1726

18-
@pytest.mark.skip_missing_tokenizer
1927
@patch('llama_recipes.finetuning.train')
2028
@patch('llama_recipes.finetuning.AutoTokenizer')
29+
@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
2130
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
2231
@patch('llama_recipes.finetuning.optim.AdamW')
2332
@patch('llama_recipes.finetuning.StepLR')
24-
def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
33+
def test_packing(
34+
step_lr,
35+
optimizer,
36+
get_model,
37+
get_config,
38+
tokenizer,
39+
train,
40+
setup_tokenizer,
41+
llama_version,
42+
model_type,
43+
):
2544
from llama_recipes.finetuning import main
2645

2746
setup_tokenizer(tokenizer)
2847
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
48+
get_config.return_value = Config(model_type=model_type)
2949

3050
kwargs = {
3151
"model_name": llama_version,
@@ -45,20 +65,24 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenize
4565
eval_dataloader = args[2]
4666

4767
assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
48-
assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
68+
# assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
69+
# print(f"{len(eval_dataloader)=}")
4970

50-
batch = next(iter(train_dataloader))
71+
# batch = next(iter(train_dataloader))
5172

52-
assert "labels" in batch.keys()
53-
assert "input_ids" in batch.keys()
54-
assert "attention_mask" in batch.keys()
73+
# assert "labels" in batch.keys()
74+
# assert "input_ids" in batch.keys()
75+
# assert "attention_mask" in batch.keys()
5576

56-
assert batch["labels"][0].size(0) == 4096
57-
assert batch["input_ids"][0].size(0) == 4096
58-
assert batch["attention_mask"][0].size(0) == 4096
77+
# # assert batch["labels"][0].size(0) == 4096
78+
# # assert batch["input_ids"][0].size(0) == 4096
79+
# # assert batch["attention_mask"][0].size(0) == 4096
80+
# print(batch["labels"][0].size(0))
81+
# print(batch["input_ids"][0].size(0))
82+
# print(batch["attention_mask"][0].size(0))
83+
5984

6085

61-
@pytest.mark.skip_missing_tokenizer
6286
@patch('llama_recipes.finetuning.train')
6387
@patch('llama_recipes.finetuning.AutoTokenizer')
6488
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')

0 commit comments

Comments
 (0)