
Commit 1de6ac3

mock samsum in test_batching
1 parent 26dff88 commit 1de6ac3

File tree

src/tests/test_batching.py

1 file changed: +20 -7 lines changed

src/tests/test_batching.py

Lines changed: 20 additions & 7 deletions
@@ -2,8 +2,9 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import pytest
-from dataclasses import dataclass
 from contextlib import nullcontext
+from dataclasses import dataclass
+from datasets import Dataset
 from unittest.mock import patch
 
 @dataclass
@@ -12,19 +13,23 @@ class Config:
 
 EXPECTED_SAMPLE_NUMBER ={
     "meta-llama/Llama-2-7b-hf": {
-        "train": 96,
-        "eval": 42,
+        "train": 4,
+        "eval": 37,
     },
     "meta-llama/Meta-Llama-3.1-8B-Instruct": {
-        "train": 79,
-        "eval": 34,
+        "train": 3,
+        "eval": 30,
     },
     "fake_llama": {
-        "train": 50,
-        "eval": 21,
+        "train": 2,
+        "eval": 17,
     }
 }
 
+fake_samsum_dataset = 2048*[{'id': '420',
+    'dialogue': "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!\nMario: I'm going to save the princess.\nLuigi: I'm going to help Mario.",
+    'summary': 'Mario and Luigi are going to save the princess.'}]
+
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
@@ -34,7 +39,9 @@ class Config:
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
 def test_packing(
+    datasets,
     step_lr,
     optimizer,
     get_model,
@@ -55,6 +62,8 @@ def test_packing(
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
+
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
 
     kwargs = {
         "model_name": llama_version,
@@ -106,7 +115,9 @@ def test_packing(
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
 def test_distributed_packing(
+    datasets,
     dist,
     is_initialized,
     fsdp,
@@ -137,6 +148,8 @@ def test_distributed_packing(
     cuda_is_available.return_value = False
     cuda_is_bf16_supported.return_value = False
 
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
+
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['RANK'] = f'{rank}'
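
The mocking pattern applied in both tests, sketched standalone for illustration (this sketch is not part of the diff): the `datasets` module as imported inside `llama_recipes.datasets.samsum_dataset` is patched so that its `load_dataset` call returns a small in-memory `Dataset` instead of downloading samsum. It assumes `llama-recipes` and `datasets` are installed and importable, and uses a context-manager `patch` for brevity where the tests use decorators.

    from unittest.mock import patch
    from datasets import Dataset

    # Tiny stand-in for samsum; the field names mirror the real dataset schema.
    fake_samsum_dataset = 4 * [{
        "id": "420",
        "dialogue": "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!",
        "summary": "Mario and Luigi are going to save the princess.",
    }]

    # Patch the `datasets` name inside the samsum loader module so any call it
    # makes to datasets.load_dataset(...) returns the fake data instead of
    # hitting the network. Assumes llama_recipes is importable in this environment.
    with patch("llama_recipes.datasets.samsum_dataset.datasets") as datasets_mock:
        datasets_mock.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
        # ... run the code under test here (e.g. build the samsum dataset) ...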
