  # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

  import os
+ from contextlib import nullcontext
+ from dataclasses import dataclass
  from unittest.mock import patch

  import pytest
  from torch.utils.data.sampler import BatchSampler


+ @dataclass
+ class Config:
+     model_type: str = "llama"
+
  def get_fake_dataset():
-     return [
+     return 8192 * [
          {
              "input_ids": [1],
              "attention_mask": [1],
@@ -28,28 +34,49 @@ def get_fake_dataset():
  @patch("llama_recipes.finetuning.torch.cuda.is_available")
  @patch("llama_recipes.finetuning.train")
+ @patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+ @patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
  @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+ @patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
  @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
  @patch("llama_recipes.finetuning.get_preprocessed_dataset")
+ @patch("llama_recipes.finetuning.generate_peft_config")
+ @patch("llama_recipes.finetuning.get_peft_model")
  @patch("llama_recipes.finetuning.optim.AdamW")
  @patch("llama_recipes.finetuning.StepLR")
  @pytest.mark.parametrize("cuda_is_available", [True, False])
- def test_finetuning_no_validation(
+ @pytest.mark.parametrize("run_validation", [True, False])
+ @pytest.mark.parametrize("use_peft", [True, False])
+ def test_finetuning(
      step_lr,
      optimizer,
+     get_peft_model,
+     gen_peft_config,
      get_dataset,
      tokenizer,
+     get_config,
      get_model,
+     get_processor,
+     get_mmodel,
      train,
      cuda,
      cuda_is_available,
+     run_validation,
+     use_peft,
+     model_type,
  ):
-     kwargs = {"run_validation": False}
+     kwargs = {
+         "run_validation": run_validation,
+         "use_peft": use_peft,
+         "batching_strategy": "packing" if model_type == "llama" else "padding",
+     }

      get_dataset.return_value = get_fake_dataset()
      cuda.return_value = cuda_is_available

      get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+     get_config.return_value = Config(model_type=model_type)

      main(**kwargs)

@@ -60,115 +87,99 @@ def test_finetuning_no_validation(
      eval_dataloader = args[2]

      assert isinstance(train_dataloader, DataLoader)
-     assert eval_dataloader is None
-
-     if cuda_is_available:
-         assert get_model.return_value.to.call_count == 1
-         assert get_model.return_value.to.call_args.args[0] == "cuda"
+     if run_validation:
+         assert isinstance(eval_dataloader, DataLoader)
      else:
-         assert get_model.return_value.to.call_count == 0
-
-
- @patch("llama_recipes.finetuning.torch.cuda.is_available")
- @patch("llama_recipes.finetuning.train")
- @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
- @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
- @patch("llama_recipes.finetuning.get_preprocessed_dataset")
- @patch("llama_recipes.finetuning.optim.AdamW")
- @patch("llama_recipes.finetuning.StepLR")
- @pytest.mark.parametrize("cuda_is_available", [True, False])
- def test_finetuning_with_validation(
-     step_lr,
-     optimizer,
-     get_dataset,
-     tokenizer,
-     get_model,
-     train,
-     cuda,
-     cuda_is_available,
- ):
-     kwargs = {"run_validation": True}
-
-     get_dataset.return_value = get_fake_dataset()
-     cuda.return_value = cuda_is_available
-
-     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
-
-     main(**kwargs)
-
-     assert train.call_count == 1
-
-     args, kwargs = train.call_args
-     train_dataloader = args[1]
-     eval_dataloader = args[2]
-     assert isinstance(train_dataloader, DataLoader)
-     assert isinstance(eval_dataloader, DataLoader)
+         assert eval_dataloader is None

-     if cuda_is_available:
-         assert get_model.return_value.to.call_count == 1
-         assert get_model.return_value.to.call_args.args[0] == "cuda"
+     if use_peft:
+         assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+         model = get_peft_model
+     elif model_type == "llama":
+         model = get_model
      else:
-         assert get_model.return_value.to.call_count == 0
-
-
- @patch("llama_recipes.finetuning.torch.cuda.is_available")
- @patch("llama_recipes.finetuning.train")
- @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
- @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
- @patch("llama_recipes.finetuning.get_preprocessed_dataset")
- @patch("llama_recipes.finetuning.generate_peft_config")
- @patch("llama_recipes.finetuning.get_peft_model")
- @patch("llama_recipes.finetuning.optim.AdamW")
- @patch("llama_recipes.finetuning.StepLR")
- @pytest.mark.parametrize("cuda_is_available", [True, False])
- def test_finetuning_peft_lora(
-     step_lr,
-     optimizer,
-     get_peft_model,
-     gen_peft_config,
-     get_dataset,
-     tokenizer,
-     get_model,
-     train,
-     cuda,
-     cuda_is_available,
- ):
-     kwargs = {"use_peft": True}
-
-     get_dataset.return_value = get_fake_dataset()
-     cuda.return_value = cuda_is_available
-
-     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
-
-     main(**kwargs)
+         model = get_mmodel

      if cuda_is_available:
-         assert get_peft_model.return_value.to.call_count == 1
-         assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
+         assert model.return_value.to.call_count == 1
+         assert model.return_value.to.call_args.args[0] == "cuda"
      else:
-         assert get_peft_model.return_value.to.call_count == 0
-
-     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+         assert model.return_value.to.call_count == 0
+
+
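The `model.return_value.to.call_count` assertions in test_finetuning above lean on unittest.mock's call recording: patching from_pretrained swaps it for a MagicMock, calling that mock yields its return_value (another mock), and every .to() call on that object is tracked. A minimal, self-contained sketch of the mechanism, kept separate from the diff; the checkpoint name is a placeholder, not something used in this change:

from unittest.mock import MagicMock

# Stand-in for a patched from_pretrained: calling the mock returns .return_value,
# which records the .to() calls that the tests later assert on.
get_model = MagicMock()
model = get_model("org/placeholder-checkpoint")  # placeholder id, purely illustrative
model.to("cuda")

assert get_model.return_value is model
assert model.to.call_count == 1
assert model.to.call_args.args[0] == "cuda"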
+ # @patch("llama_recipes.finetuning.torch.cuda.is_available")
+ # @patch("llama_recipes.finetuning.train")
+ # @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+ # @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+ # @patch("llama_recipes.finetuning.get_preprocessed_dataset")
+ # @patch("llama_recipes.finetuning.generate_peft_config")
+ # @patch("llama_recipes.finetuning.get_peft_model")
+ # @patch("llama_recipes.finetuning.optim.AdamW")
+ # @patch("llama_recipes.finetuning.StepLR")
+ # @pytest.mark.parametrize("cuda_is_available", [True, False])
+ # def test_finetuning_peft_lora(
+ #     step_lr,
+ #     optimizer,
+ #     get_peft_model,
+ #     gen_peft_config,
+ #     get_dataset,
+ #     tokenizer,
+ #     get_model,
+ #     train,
+ #     cuda,
+ #     cuda_is_available,
+ # ):
+ #     kwargs = {"use_peft": True}
+
+ #     get_dataset.return_value = get_fake_dataset()
+ #     cuda.return_value = cuda_is_available
+
+ #     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
+ #     main(**kwargs)
+
+ #     if cuda_is_available:
+ #         assert get_peft_model.return_value.to.call_count == 1
+ #         assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
+ #     else:
+ #         assert get_peft_model.return_value.to.call_count == 0
+
+


  @patch("llama_recipes.finetuning.get_peft_model")
  @patch("llama_recipes.finetuning.setup")
  @patch("llama_recipes.finetuning.train")
+ @patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+ @patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
  @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+ @patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
  @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
  @patch("llama_recipes.finetuning.get_preprocessed_dataset")
  def test_finetuning_peft_llama_adapter(
-     get_dataset, tokenizer, get_model, train, setup, get_peft_model
+     get_dataset,
+     tokenizer,
+     get_config,
+     get_model,
+     get_processor,
+     get_mmodel,
+     train,
+     setup,
+     get_peft_model,
+     model_type,
  ):
      kwargs = {
          "use_peft": True,
          "peft_method": "llama_adapter",
          "enable_fsdp": True,
+         "batching_strategy": "packing" if model_type == "llama" else "padding",
      }

      get_dataset.return_value = get_fake_dataset()

      get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+     get_config.return_value = Config(model_type=model_type)

      os.environ["RANK"] = "0"
      os.environ["LOCAL_RANK"] = "0"
@@ -195,20 +206,38 @@ def test_finetuning_peft_llama_adapter(
  @patch("llama_recipes.finetuning.train")
+ @patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+ @patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
  @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+ @patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
  @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
  @patch("llama_recipes.finetuning.get_preprocessed_dataset")
  @patch("llama_recipes.finetuning.get_peft_model")
  @patch("llama_recipes.finetuning.StepLR")
  def test_finetuning_weight_decay(
-     step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
+     step_lr,
+     get_peft_model,
+     get_dataset,
+     tokenizer,
+     get_config,
+     get_model,
+     get_processor,
+     get_mmodel,
+     train,
+     model_type,
  ):
-     kwargs = {"weight_decay": 0.01}
+     kwargs = {
+         "weight_decay": 0.01,
+         "batching_strategy": "packing" if model_type == "llama" else "padding",
+     }

      get_dataset.return_value = get_fake_dataset()

-     get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
-     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+     model = get_model if model_type == "llama" else get_mmodel
+     model.return_value.parameters.return_value = [torch.ones(1, 1)]
+     model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
+     get_config.return_value = Config(model_type=model_type)

      main(**kwargs)

@@ -224,28 +253,49 @@ def test_finetuning_weight_decay(
  @patch("llama_recipes.finetuning.train")
+ @patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+ @patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
  @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+ @patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
  @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
  @patch("llama_recipes.finetuning.get_preprocessed_dataset")
  @patch("llama_recipes.finetuning.optim.AdamW")
  @patch("llama_recipes.finetuning.StepLR")
  def test_batching_strategy(
-     step_lr, optimizer, get_dataset, tokenizer, get_model, train
+     step_lr,
+     optimizer,
+     get_dataset,
+     tokenizer,
+     get_config,
+     get_model,
+     get_processor,
+     get_mmodel,
+     train,
+     model_type,
  ):
-     kwargs = {"batching_strategy": "packing"}
+     kwargs = {
+         "batching_strategy": "packing",
+     }

      get_dataset.return_value = get_fake_dataset()

-     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+     model = get_model if model_type == "llama" else get_mmodel
+     model.return_value.get_input_embeddings.return_value.weight.shape = [0]

-     main(**kwargs)
+     get_config.return_value = Config(model_type=model_type)

-     assert train.call_count == 1
+     c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
+
+     with c:
+         main(**kwargs)

-     args, kwargs = train.call_args
-     train_dataloader, eval_dataloader = args[1:3]
-     assert isinstance(train_dataloader.batch_sampler, BatchSampler)
-     assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+     assert train.call_count == (1 if model_type == "llama" else 0)
+
+     if model_type == "llama":
+         args, kwargs = train.call_args
+         train_dataloader, eval_dataloader = args[1:3]
+         assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+         assert isinstance(eval_dataloader.batch_sampler, BatchSampler)

      kwargs["batching_strategy"] = "padding"
      train.reset_mock()
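test_batching_strategy above pairs contextlib.nullcontext with pytest.raises so a single parametrized test body covers both the accepted case and the one expected to raise. A standalone sketch of that idiom, using a hypothetical validator in place of the real check inside finetuning.main():

from contextlib import nullcontext

import pytest


def set_batching_strategy(strategy):
    # Hypothetical validator, standing in for the strategy check in main().
    if strategy not in ("packing", "padding"):
        raise ValueError(f"Unknown batching strategy: {strategy}")
    return strategy


@pytest.mark.parametrize("strategy", ["packing", "bogus"])
def test_strategy_is_validated(strategy):
    # nullcontext() is a no-op context manager, so the same `with` block either
    # simply runs the call or asserts that ValueError is raised.
    ctx = nullcontext() if strategy == "packing" else pytest.raises(ValueError)
    with ctx:
        set_batching_strategy(strategy)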