@@ -3,6 +3,7 @@
 
 import pytest
 from dataclasses import dataclass
+from contextlib import nullcontext
 from unittest.mock import patch
 
 @dataclass
@@ -19,32 +20,39 @@ class Config:
         "eval": 34,
     },
     "fake_llama": {
-        "train": 48,
-        "eval": 34,
+        "train": 50,
+        "eval": 21,
     }
 }
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
 @patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_packing(
     step_lr,
     optimizer,
     get_model,
+    get_mmodel,
+    processor,
     get_config,
     tokenizer,
     train,
     setup_tokenizer,
+    setup_processor,
     llama_version,
     model_type,
 ):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
+    setup_processor(processor)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
 
     kwargs = {
@@ -56,48 +64,73 @@ def test_packing(
         "batching_strategy": "packing",
     }
 
-    main(**kwargs)
+    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
 
-    assert train.call_count == 1
+    with c:
+        main(**kwargs)
+
+    if model_type == "llama":
+        assert train.call_count == 1
 
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+        args, kwargs = train.call_args
+        train_dataloader = args[1]
+        eval_dataloader = args[2]
 
-    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
-    # assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
-    # print(f"{len(eval_dataloader)=}")
+        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
+        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
 
-    # batch = next(iter(train_dataloader))
+        batch = next(iter(train_dataloader))
 
-    # assert "labels" in batch.keys()
-    # assert "input_ids" in batch.keys()
-    # assert "attention_mask" in batch.keys()
+        assert "labels" in batch.keys()
+        assert "input_ids" in batch.keys()
+        assert "attention_mask" in batch.keys()
 
-    # # assert batch["labels"][0].size(0) == 4096
-    # # assert batch["input_ids"][0].size(0) == 4096
-    # # assert batch["attention_mask"][0].size(0) == 4096
-    # print(batch["labels"][0].size(0))
-    # print(batch["input_ids"][0].size(0))
-    # print(batch["attention_mask"][0].size(0))
-
+        assert batch["labels"][0].size(0) == 4096
+        assert batch["input_ids"][0].size(0) == 4096
+        assert batch["attention_mask"][0].size(0) == 4096
 
 
+@patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.setup')
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
-def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
+def test_distributed_packing(
+    dist,
+    is_initialized,
+    fsdp,
+    setup,
+    step_lr,
+    optimizer,
+    get_model,
+    get_mmodel,
+    processor,
+    get_config,
+    tokenizer,
+    train,
+    cuda_is_available,
+    setup_tokenizer,
+    setup_processor,
+    llama_version,
+    model_type,
+):
     import os
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
+    setup_processor(processor)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
+    cuda_is_available.return_value = False
 
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
@@ -120,13 +153,17 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     dist.get_rank.return_value = rank
     dist.get_world_size.return_value = 2
 
-    main(**kwargs)
+    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
+
+    with c:
+        main(**kwargs)
 
-    assert train.call_count == 1
+    if model_type == "llama":
+        assert train.call_count == 1
 
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+        args, kwargs = train.call_args
+        train_dataloader = args[1]
+        eval_dataloader = args[2]
 
-    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] // 2
-    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] // 2
+        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] // 2
+        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] // 2