Skip to content

Commit b554b24

Browse files
authored
Fix/unit test 3.2 (meta-llama#726)
2 parents a8e9f4e + d9ca099 commit b554b24

File tree

13 files changed

+393
-265
lines changed

13 files changed

+393
-265
lines changed

.github/workflows/pytest_cpu_gha_runner.yaml

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
name: "[GHA][CPU] llama-recipes Pytest tests on CPU GitHub hosted runner."
22
on:
33
pull_request:
4-
branches:
4+
branches:
55
- 'main'
6-
paths:
7-
- 'src/llama-recipes/configs/*.py'
8-
- 'src/llama-recipes/utils/*.py'
9-
- 'src/llama-recipes/datasets/*.py'
10-
- 'src/llama-recipes/data/*.py'
11-
- 'src/llama-recipes/*.py'
126

13-
# triggers workflow manually for debugging purposes.
7+
# triggers workflow manually for debugging purposes.
148
workflow_dispatch:
159
inputs:
1610
runner:
@@ -23,8 +17,8 @@ on:
2317
required: false
2418
default: "true"
2519

26-
env:
27-
PYTORCH_WHEEL_URL: https://download.pytorch.org/whl/test/cu118
20+
env:
21+
PYTORCH_WHEEL_URL: https://download.pytorch.org/whl/test/cu118
2822

2923
jobs:
3024
execute_workflow:
@@ -63,19 +57,18 @@ jobs:
6357
id: install_llama_recipes_package
6458
run: |
6559
echo "Installing 'llama-recipes' project (re: https://github.com/facebookresearch/llama-recipes?tab=readme-ov-file#install-with-optional-dependencies)"
66-
pip install --extra-index-url ${PYTORCH_WHEEL_URL} -e '.[tests]'
60+
pip install --extra-index-url ${PYTORCH_WHEEL_URL} -e '.[tests]'
6761
6862
6963
- name: "Running PyTest tests on GHA CPU Runner"
7064
id: pytest
7165
run: |
7266
echo "Running PyTest tests at 'GITHUB_WORKSPACE' path: ${GITHUB_WORKSPACE}"
7367
cd $GITHUB_WORKSPACE && python3 -m pytest --junitxml="$GITHUB_WORKSPACE/result.xml"
74-
68+
7569
- name: Publish Test Summary
7670
id: test_summary
7771
uses: test-summary/action@v2
7872
with:
7973
paths: "**/*.xml"
8074
if: always()
81-

src/llama_recipes/configs/datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ class samsum_dataset:
99
dataset: str = "samsum_dataset"
1010
train_split: str = "train"
1111
test_split: str = "validation"
12-
trust_remote_code: bool = False
1312

1413

1514
@dataclass

src/llama_recipes/datasets/samsum_dataset.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,22 @@
66
import copy
77
import datasets
88

9+
from unittest.mock import patch
10+
11+
@patch('builtins.input', return_value="N")
12+
def load_samsum(split, _):
13+
try:
14+
ds = datasets.load_dataset("Samsung/samsum", split=split)
15+
except ValueError as e:
16+
if "trust_remote_code" in str(e):
17+
raise ValueError("Loading Samsung/samsum requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set HF_DATASETS_TRUST_REMOTE_CODE env variable to True.") from e
18+
else:
19+
raise e
20+
return ds
21+
922

1023
def get_preprocessed_samsum(dataset_config, tokenizer, split):
11-
if not hasattr(dataset_config, "trust_remote_code") or not dataset_config.trust_remote_code:
12-
raise ValueError("The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum. To activate `trust_remote_code` option use this config: --samsum_dataset.trust_remote_code=True")
13-
dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=dataset_config.trust_remote_code)
24+
dataset = load_samsum(split)
1425

1526
prompt = (
1627
f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"

src/llama_recipes/finetuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def main(**kwargs):
289289
)
290290
print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
291291
if len(eval_dataloader) == 0:
292-
raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
292+
raise ValueError(f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})")
293293
else:
294294
print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
295295

src/tests/conftest.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,27 @@
33

44
import pytest
55

6-
from transformers import AutoTokenizer
6+
from utils import maybe_tokenizer
77

8-
ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
9-
LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
8+
ACCESS_ERROR_MSG = "Could not access tokenizer. Did you log into huggingface hub and provided the correct token?"
9+
10+
LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct", "fake_llama"]
11+
12+
LLAMA_TOKENIZERS = {k: maybe_tokenizer(k) for k in LLAMA_VERSIONS}
1013

1114
@pytest.fixture(params=LLAMA_VERSIONS)
1215
def llama_version(request):
1316
return request.param
1417

1518

19+
@pytest.fixture(params=["mllama", "llama"])
20+
def model_type(request):
21+
return request.param
22+
23+
1624
@pytest.fixture(scope="module")
1725
def llama_tokenizer(request):
18-
return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
26+
return LLAMA_TOKENIZERS
1927

2028

2129
@pytest.fixture
@@ -26,6 +34,13 @@ def _helper(tokenizer_mock):
2634

2735
return _helper
2836

37+
@pytest.fixture
38+
def setup_processor(llama_tokenizer, llama_version):
39+
def _helper(processor_mock):
40+
processor_mock.from_pretrained.return_value.tokenizer = llama_tokenizer[llama_version]
41+
42+
return _helper
43+
2944

3045
def pytest_addoption(parser):
3146
parser.addoption(
@@ -38,16 +53,18 @@ def pytest_configure(config):
3853

3954

4055
def pytest_collection_modifyitems(config, items):
56+
#skip tests marked with skip_missing_tokenizer if tokenizer is unavailable unless --unskip-missing-tokenizer is passed
4157
if config.getoption("--unskip-missing-tokenizer"):
4258
return
4359

44-
try:
45-
AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
46-
tokenizer_available = True
47-
except OSError:
48-
tokenizer_available = False
49-
5060
skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
5161
for item in items:
52-
if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:
62+
# get the tokenizer for the test
63+
version = [v for v in LLAMA_VERSIONS for i in item.keywords if v in i]
64+
if len(version) == 0:
65+
# no tokenizer used in this test
66+
continue
67+
version = version.pop()
68+
assert version in LLAMA_TOKENIZERS
69+
if "skip_missing_tokenizer" in item.keywords and LLAMA_TOKENIZERS[version] is None:
5370
item.add_marker(skip_missing_tokenizer)

src/tests/datasets/test_custom_dataset.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
33

44
import pytest
5+
from contextlib import nullcontext
56
from unittest.mock import patch
67

78
from transformers import LlamaTokenizer
@@ -96,15 +97,17 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
9697

9798

9899
@patch('llama_recipes.finetuning.train')
100+
@patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
99101
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
100102
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
101103
@patch('llama_recipes.finetuning.optim.AdamW')
102104
@patch('llama_recipes.finetuning.StepLR')
103-
def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker, llama_version):
105+
def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, get_config, train, mocker, llama_version):
104106
from llama_recipes.finetuning import main
105107

106108
tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
107109
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
110+
get_config.return_value.model_type = "llama"
108111

109112
kwargs = {
110113
"dataset": "custom_dataset",
@@ -131,13 +134,16 @@ def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version)
131134
{"role":"assistant", "content":"Romans"},
132135
]
133136

134-
result = tokenize_dialog(dialog, tokenizer)
137+
c = pytest.raises(AttributeError) if llama_version == "fake_llama" else nullcontext()
138+
139+
with c:
140+
result = tokenize_dialog(dialog, tokenizer)
135141

136142
if "Llama-2" in llama_version:
137143
assert result["labels"][:12] == [-100] * 12
138144
assert result["labels"][17:28] == [-100] * 11
139145
assert result["labels"].count(-100) == 11 + 12
140-
else:
146+
elif "Llama-3" in llama_version:
141147
assert result["labels"][:38] == [-100] * 38
142148
assert result["labels"][43:54] == [-100] * 11
143149
assert result["labels"].count(-100) == 38 + 11

src/tests/datasets/test_grammar_datasets.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,27 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
33

4+
from pathlib import Path
45
import pytest
56
from unittest.mock import patch
67

7-
8-
EXPECTED_RESULTS = {
9-
"meta-llama/Llama-2-7b-hf":{
10-
"label": 1152,
11-
"pos": 31,
12-
},
13-
"meta-llama/Meta-Llama-3.1-8B":{
14-
"label": 40,
15-
"pos": 26,
16-
},
17-
}
8+
DATA_DIR = Path(__file__).parents[2] / "llama_recipes/datasets/grammar_dataset/"
189

1910
@pytest.mark.skip_missing_tokenizer
11+
@pytest.mark.skipif(not Path(DATA_DIR / "grammar_validation.csv").exists(), reason="grammar_validation.csv not found")
12+
@pytest.mark.skipif(not Path(DATA_DIR / "gtrain_10k.csv").exists(), reason="gtrain_10k.csv not found")
2013
@patch('llama_recipes.finetuning.train')
2114
@patch('llama_recipes.finetuning.AutoTokenizer')
15+
@patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
2216
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
2317
@patch('llama_recipes.finetuning.optim.AdamW')
2418
@patch('llama_recipes.finetuning.StepLR')
25-
def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
19+
def test_grammar_dataset(step_lr, optimizer, get_model, get_config, tokenizer, train, setup_tokenizer, llama_version):
2620
from llama_recipes.finetuning import main
2721

2822
setup_tokenizer(tokenizer)
2923
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
24+
get_config.return_value.model_type = "llama"
3025

3126
BATCH_SIZE = 8
3227
kwargs = {
@@ -58,9 +53,6 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_
5853
assert "input_ids" in batch.keys()
5954
assert "attention_mask" in batch.keys()
6055

61-
assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
62-
assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
63-
6456
token = args[3]
6557
assert batch["input_ids"][0][0] == token.bos_token_id
6658
assert batch["labels"][0][-1] == token.eos_token_id

src/tests/datasets/test_samsum_datasets.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,50 @@
22
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
33

44
import pytest
5+
from dataclasses import dataclass
56
from functools import partial
67
from unittest.mock import patch
8+
from datasets import load_dataset
79

8-
EXPECTED_RESULTS = {
9-
"meta-llama/Llama-2-7b-hf":{
10-
"label": 8432,
11-
"pos": 242,
12-
},
13-
"meta-llama/Meta-Llama-3.1-8B":{
14-
"label": 2250,
15-
"pos": 211,
16-
},
17-
}
10+
@dataclass
11+
class Config:
12+
model_type: str = "llama"
1813

14+
try:
15+
load_dataset("Samsung/samsum")
16+
SAMSUM_UNAVAILABLE = False
17+
except ValueError:
18+
SAMSUM_UNAVAILABLE = True
19+
20+
@pytest.mark.skipif(SAMSUM_UNAVAILABLE, reason="Samsum dataset is unavailable")
1921
@pytest.mark.skip_missing_tokenizer
2022
@patch('llama_recipes.finetuning.train')
2123
@patch('llama_recipes.finetuning.AutoTokenizer')
24+
@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
25+
@patch("llama_recipes.finetuning.AutoProcessor")
26+
@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
2227
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
2328
@patch('llama_recipes.finetuning.optim.AdamW')
2429
@patch('llama_recipes.finetuning.StepLR')
25-
def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
30+
def test_samsum_dataset(
31+
step_lr,
32+
optimizer,
33+
get_model,
34+
get_mmodel,
35+
processor,
36+
get_config,
37+
tokenizer,
38+
train,
39+
mocker,
40+
setup_tokenizer,
41+
llama_version,
42+
):
2643
from llama_recipes.finetuning import main
2744

2845
setup_tokenizer(tokenizer)
2946
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
47+
get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
48+
get_config.return_value = Config()
3049

3150
BATCH_SIZE = 8
3251
kwargs = {
@@ -59,9 +78,6 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
5978
assert "input_ids" in batch.keys()
6079
assert "attention_mask" in batch.keys()
6180

62-
assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
63-
assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
64-
6581
assert batch["input_ids"][0][0] == token.bos_token_id
6682
assert batch["labels"][0][-1] == token.eos_token_id
6783
assert batch["input_ids"][0][-1] == token.eos_token_id

0 commit comments

Comments
 (0)