@@ -2,9 +2,10 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

 import pytest
-from dataclasses import dataclass
 from contextlib import nullcontext
+from dataclasses import dataclass
+from datasets import Dataset
 from unittest.mock import patch


 @dataclass
@@ -12,19 +13,23 @@ class Config:

 EXPECTED_SAMPLE_NUMBER = {
     "meta-llama/Llama-2-7b-hf": {
-        "train": 96,
-        "eval": 42,
+        "train": 4,
+        "eval": 37,
     },
     "meta-llama/Meta-Llama-3.1-8B-Instruct": {
-        "train": 79,
-        "eval": 34,
+        "train": 3,
+        "eval": 30,
     },
     "fake_llama": {
-        "train": 50,
-        "eval": 21,
+        "train": 2,
+        "eval": 17,
     }
 }

+fake_samsum_dataset = 2048 * [{'id': '420',
+    'dialogue': "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!\nMario: I'm going to save the princess.\nLuigi: I'm going to help Mario.",
+    'summary': 'Mario and Luigi are going to save the princess.'}]
+
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
@@ -34,7 +39,9 @@ class Config:
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
 def test_packing(
+    datasets,
     step_lr,
     optimizer,
     get_model,
@@ -55,6 +62,8 @@ def test_packing(
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
+
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)

     kwargs = {
         "model_name": llama_version,
@@ -106,7 +115,9 @@ def test_packing(
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
 def test_distributed_packing(
+    datasets,
     dist,
     is_initialized,
     fsdp,
@@ -137,6 +148,8 @@ def test_distributed_packing(
     cuda_is_available.return_value = False
     cuda_is_bf16_supported.return_value = False

+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
+
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['RANK'] = f'{rank}'
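
For reference, a minimal standalone sketch (not part of the diff) of what the new mock wiring does: patching llama_recipes.datasets.samsum_dataset.datasets replaces the real Hugging Face download, so load_dataset returns the in-memory Dataset built from fake_samsum_dataset via Dataset.from_list. The MagicMock stand-in, the dataset-name argument, and the assertions below are illustrative only.

    from unittest.mock import MagicMock
    from datasets import Dataset

    # Same record shape as fake_samsum_dataset above (2048 copies of one sample).
    fake_samsum_dataset = 2048 * [{'id': '420',
        'dialogue': "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!",
        'summary': 'Mario and Luigi are going to save the princess.'}]

    # Stand-in for the patched `llama_recipes.datasets.samsum_dataset.datasets` module.
    datasets = MagicMock()
    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)

    # Any call now returns the fake data; the arguments are ignored by the mock,
    # so no network access or local dataset cache is needed.
    ds = datasets.load_dataset("samsum", split="train")
    assert len(ds) == 2048
    assert set(ds.column_names) == {'id', 'dialogue', 'summary'}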