Commit 1090ccf

Fix test_finetuning
1 parent 9f52006 commit 1090ccf

File tree

3 files changed: +156 −101 lines changed


src/llama_recipes/finetuning.py

Lines changed: 1 addition & 1 deletion
@@ -287,7 +287,7 @@ def main(**kwargs):
             )
             print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
             if len(eval_dataloader) == 0:
-                raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
+                raise ValueError(f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})")
             else:
                 print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
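Note: the new message relies on the self-documenting f-string specifier ({expr=}, Python 3.8+), which renders the expression text together with its value. A minimal standalone sketch (variable names are illustrative, not from the commit):

# Illustration of the f"{expr=}" debug specifier used in the new ValueError message.
eval_batches = []  # hypothetical stand-in for an empty eval_dataloader

message = f"The eval set size is too small. ({len(eval_batches)=})"
print(message)  # prints: The eval set size is too small. (len(eval_batches)=0)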

src/tests/conftest.py

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,11 @@ def llama_version(request):
     return request.param
 
 
+@pytest.fixture(params=["mllama", "llama"])
+def model_type(request):
+    return request.param
+
+
 @pytest.fixture(scope="module")
 def llama_tokenizer(request):
     return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
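Note: a fixture declared with params fans out every test that requests it, one test id per param. A minimal standalone sketch of the mechanism (the test name is illustrative, not from the commit):

import pytest

# Each test that accepts `model_type` is collected once per param,
# e.g. test_model_type_values[mllama] and test_model_type_values[llama].
@pytest.fixture(params=["mllama", "llama"])
def model_type(request):
    return request.param

def test_model_type_values(model_type):
    assert model_type in ("mllama", "llama")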

src/tests/test_finetuning.py

Lines changed: 150 additions & 100 deletions
@@ -2,6 +2,8 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
+from contextlib import nullcontext
+from dataclasses import dataclass
 from unittest.mock import patch
 
 import pytest
@@ -16,8 +18,12 @@
 from torch.utils.data.sampler import BatchSampler
 
 
+@dataclass
+class Config:
+    model_type: str = "llama"
+
 def get_fake_dataset():
-    return [
+    return 8192*[
         {
             "input_ids": [1],
             "attention_mask": [1],
@@ -28,28 +34,49 @@ def get_fake_dataset():
 
 @patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.generate_peft_config")
+@patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.optim.AdamW")
 @patch("llama_recipes.finetuning.StepLR")
 @pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_no_validation(
+@pytest.mark.parametrize("run_validation", [True, False])
+@pytest.mark.parametrize("use_peft", [True, False])
+def test_finetuning(
     step_lr,
     optimizer,
+    get_peft_model,
+    gen_peft_config,
     get_dataset,
     tokenizer,
+    get_config,
     get_model,
+    get_processor,
+    get_mmodel,
     train,
     cuda,
     cuda_is_available,
+    run_validation,
+    use_peft,
+    model_type,
 ):
-    kwargs = {"run_validation": False}
+    kwargs = {
+        "run_validation": run_validation,
+        "use_peft": use_peft,
+        "batching_strategy": "packing" if model_type == "llama" else "padding",
+    }
 
     get_dataset.return_value = get_fake_dataset()
     cuda.return_value = cuda_is_available
 
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
 
     main(**kwargs)
 
@@ -60,115 +87,99 @@ def test_finetuning_no_validation(
     eval_dataloader = args[2]
 
     assert isinstance(train_dataloader, DataLoader)
-    assert eval_dataloader is None
-
-    if cuda_is_available:
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if run_validation:
+        assert isinstance(eval_dataloader, DataLoader)
     else:
-        assert get_model.return_value.to.call_count == 0
-
-
-@patch("llama_recipes.finetuning.torch.cuda.is_available")
-@patch("llama_recipes.finetuning.train")
-@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
-@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
-@patch("llama_recipes.finetuning.get_preprocessed_dataset")
-@patch("llama_recipes.finetuning.optim.AdamW")
-@patch("llama_recipes.finetuning.StepLR")
-@pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_with_validation(
-    step_lr,
-    optimizer,
-    get_dataset,
-    tokenizer,
-    get_model,
-    train,
-    cuda,
-    cuda_is_available,
-):
-    kwargs = {"run_validation": True}
-
-    get_dataset.return_value = get_fake_dataset()
-    cuda.return_value = cuda_is_available
-
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
-
-    main(**kwargs)
-
-    assert train.call_count == 1
-
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
-    assert isinstance(train_dataloader, DataLoader)
-    assert isinstance(eval_dataloader, DataLoader)
+        assert eval_dataloader is None
 
-    if cuda_is_available:
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if use_peft:
+        assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+        model = get_peft_model
+    elif model_type == "llama":
+        model = get_model
     else:
-        assert get_model.return_value.to.call_count == 0
-
-
-@patch("llama_recipes.finetuning.torch.cuda.is_available")
-@patch("llama_recipes.finetuning.train")
-@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
-@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
-@patch("llama_recipes.finetuning.get_preprocessed_dataset")
-@patch("llama_recipes.finetuning.generate_peft_config")
-@patch("llama_recipes.finetuning.get_peft_model")
-@patch("llama_recipes.finetuning.optim.AdamW")
-@patch("llama_recipes.finetuning.StepLR")
-@pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_peft_lora(
-    step_lr,
-    optimizer,
-    get_peft_model,
-    gen_peft_config,
-    get_dataset,
-    tokenizer,
-    get_model,
-    train,
-    cuda,
-    cuda_is_available,
-):
-    kwargs = {"use_peft": True}
-
-    get_dataset.return_value = get_fake_dataset()
-    cuda.return_value = cuda_is_available
-
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
-
-    main(**kwargs)
+        model = get_mmodel
 
     if cuda_is_available:
-        assert get_peft_model.return_value.to.call_count == 1
-        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
+        assert model.return_value.to.call_count == 1
+        assert model.return_value.to.call_args.args[0] == "cuda"
     else:
-        assert get_peft_model.return_value.to.call_count == 0
-
-    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+        assert model.return_value.to.call_count == 0
+
+
+# @patch("llama_recipes.finetuning.torch.cuda.is_available")
+# @patch("llama_recipes.finetuning.train")
+# @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+# @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+# @patch("llama_recipes.finetuning.get_preprocessed_dataset")
+# @patch("llama_recipes.finetuning.generate_peft_config")
+# @patch("llama_recipes.finetuning.get_peft_model")
+# @patch("llama_recipes.finetuning.optim.AdamW")
+# @patch("llama_recipes.finetuning.StepLR")
+# @pytest.mark.parametrize("cuda_is_available", [True, False])
+# def test_finetuning_peft_lora(
+#     step_lr,
+#     optimizer,
+#     get_peft_model,
+#     gen_peft_config,
+#     get_dataset,
+#     tokenizer,
+#     get_model,
+#     train,
+#     cuda,
+#     cuda_is_available,
+# ):
+#     kwargs = {"use_peft": True}
+
+#     get_dataset.return_value = get_fake_dataset()
+#     cuda.return_value = cuda_is_available
+
+#     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
+#     main(**kwargs)
+
+#     if cuda_is_available:
+#         assert get_peft_model.return_value.to.call_count == 1
+#         assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
+#     else:
+#         assert get_peft_model.return_value.to.call_count == 0
+
+
 
 @patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.setup")
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 def test_finetuning_peft_llama_adapter(
-    get_dataset, tokenizer, get_model, train, setup, get_peft_model
+    get_dataset,
+    tokenizer,
+    get_config,
+    get_model,
+    get_processor,
+    get_mmodel,
+    train,
+    setup,
+    get_peft_model,
+    model_type,
 ):
     kwargs = {
         "use_peft": True,
         "peft_method": "llama_adapter",
         "enable_fsdp": True,
+        "batching_strategy": "packing" if model_type == "llama" else "padding",
     }
 
     get_dataset.return_value = get_fake_dataset()
 
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
 
     os.environ["RANK"] = "0"
     os.environ["LOCAL_RANK"] = "0"
@@ -195,20 +206,38 @@ def test_finetuning_peft_llama_adapter(
 
 
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 @patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.StepLR")
 def test_finetuning_weight_decay(
-    step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
+    step_lr,
+    get_peft_model,
+    get_dataset,
+    tokenizer,
+    get_config,
+    get_model,
+    get_processor,
+    get_mmodel,
+    train,
+    model_type,
 ):
-    kwargs = {"weight_decay": 0.01}
+    kwargs = {
+        "weight_decay": 0.01,
+        "batching_strategy": "packing" if model_type == "llama" else "padding",
+    }
 
     get_dataset.return_value = get_fake_dataset()
 
-    get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    model = get_model if model_type == "llama" else get_mmodel
+    model.return_value.parameters.return_value = [torch.ones(1, 1)]
+    model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
+    get_config.return_value = Config(model_type=model_type)
 
     main(**kwargs)
 
@@ -224,28 +253,49 @@ def test_finetuning_weight_decay(
 
 
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 @patch("llama_recipes.finetuning.optim.AdamW")
 @patch("llama_recipes.finetuning.StepLR")
 def test_batching_strategy(
-    step_lr, optimizer, get_dataset, tokenizer, get_model, train
+    step_lr,
+    optimizer,
+    get_dataset,
+    tokenizer,
+    get_config,
+    get_model,
+    get_processor,
+    get_mmodel,
+    train,
+    model_type,
 ):
-    kwargs = {"batching_strategy": "packing"}
+    kwargs = {
+        "batching_strategy": "packing",
+    }
 
     get_dataset.return_value = get_fake_dataset()
 
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    model = get_model if model_type == "llama" else get_mmodel
+    model.return_value.get_input_embeddings.return_value.weight.shape = [0]
 
-    main(**kwargs)
+    get_config.return_value = Config(model_type=model_type)
 
-    assert train.call_count == 1
+    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
+
+    with c:
+        main(**kwargs)
 
-    args, kwargs = train.call_args
-    train_dataloader, eval_dataloader = args[1:3]
-    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
-    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+    assert train.call_count == (1 if model_type == "llama" else 0)
+
+    if model_type == "llama":
+        args, kwargs = train.call_args
+        train_dataloader, eval_dataloader = args[1:3]
+        assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+        assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
 
     kwargs["batching_strategy"] = "padding"
     train.reset_mock()
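Note: two patterns the rewritten tests lean on, sketched standalone with illustrative names (not part of the commit). Stacked @patch decorators hand their mocks to the test bottom-up, which is why step_lr (from the bottom-most StepLR patch) is the first parameter and parametrize/fixture arguments such as model_type come last. And pairing nullcontext() with pytest.raises(ValueError) lets a single parametrized body cover both the case that should succeed and the one that should raise.

from contextlib import nullcontext
from unittest.mock import patch

import pytest


def pack_batches(model_type):
    # Hypothetical helper standing in for main(): in this sketch,
    # "packing" is only accepted for the "llama" model type.
    if model_type != "llama":
        raise ValueError("packing is not supported for this model type")
    return "ok"


@patch("os.path.exists")  # top decorator -> second mock argument
@patch("os.path.isfile")  # bottom decorator -> first mock argument
def test_patch_order(isfile_mock, exists_mock):
    # Mocks arrive in bottom-up decorator order, before any fixtures.
    assert isfile_mock is not exists_mock


@pytest.mark.parametrize("model_type", ["llama", "mllama"])
def test_packing_support(model_type):
    # One body, two expectations: no exception for llama, ValueError otherwise.
    ctx = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
    with ctx:
        assert pack_batches(model_type) == "ok"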
