Skip to content

Commit 7ca0b66

Browse files
committed
Fix unit tests
1 parent: 853f928 · commit: 7ca0b66

File tree

1 file changed

+94
-55
lines changed

1 file changed

+94
-55
lines changed

tests/test_finetuning.py

Lines changed: 94 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -1,40 +1,56 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
33

4-
import pytest
5-
from pytest import approx
4+
import os
65
from unittest.mock import patch
76

7+
import pytest
8+
89
import torch
10+
from llama_recipes.data.sampler import LengthBasedBatchSampler
11+
12+
from llama_recipes.finetuning import main
13+
from pytest import approx
914
from torch.optim import AdamW
1015
from torch.utils.data.dataloader import DataLoader
1116
from torch.utils.data.sampler import BatchSampler
1217

13-
from llama_recipes.finetuning import main
14-
from llama_recipes.data.sampler import LengthBasedBatchSampler
15-
1618

1719
def get_fake_dataset():
18-
return [{
19-
"input_ids":[1],
20-
"attention_mask":[1],
21-
"labels":[1],
22-
}]
23-
24-
@patch('llama_recipes.finetuning.torch.cuda.is_available')
25-
@patch('llama_recipes.finetuning.train')
26-
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
27-
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
28-
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
29-
@patch('llama_recipes.finetuning.optim.AdamW')
30-
@patch('llama_recipes.finetuning.StepLR')
20+
return [
21+
{
22+
"input_ids": [1],
23+
"attention_mask": [1],
24+
"labels": [1],
25+
}
26+
]
27+
28+
29+
@patch("llama_recipes.finetuning.torch.cuda.is_available")
30+
@patch("llama_recipes.finetuning.train")
31+
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
32+
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
33+
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
34+
@patch("llama_recipes.finetuning.optim.AdamW")
35+
@patch("llama_recipes.finetuning.StepLR")
3136
@pytest.mark.parametrize("cuda_is_available", [True, False])
32-
def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
37+
def test_finetuning_no_validation(
38+
step_lr,
39+
optimizer,
40+
get_dataset,
41+
tokenizer,
42+
get_model,
43+
train,
44+
cuda,
45+
cuda_is_available,
46+
):
3347
kwargs = {"run_validation": False}
3448

3549
get_dataset.return_value = get_fake_dataset()
3650
cuda.return_value = cuda_is_available
3751

52+
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
53+
3854
main(**kwargs)
3955

4056
assert train.call_count == 1
@@ -53,20 +69,31 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
5369
assert get_model.return_value.to.call_count == 0
5470

5571

56-
@patch('llama_recipes.finetuning.torch.cuda.is_available')
57-
@patch('llama_recipes.finetuning.train')
58-
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
59-
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
60-
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
61-
@patch('llama_recipes.finetuning.optim.AdamW')
62-
@patch('llama_recipes.finetuning.StepLR')
72+
@patch("llama_recipes.finetuning.torch.cuda.is_available")
73+
@patch("llama_recipes.finetuning.train")
74+
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
75+
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
76+
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
77+
@patch("llama_recipes.finetuning.optim.AdamW")
78+
@patch("llama_recipes.finetuning.StepLR")
6379
@pytest.mark.parametrize("cuda_is_available", [True, False])
64-
def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
80+
def test_finetuning_with_validation(
81+
step_lr,
82+
optimizer,
83+
get_dataset,
84+
tokenizer,
85+
get_model,
86+
train,
87+
cuda,
88+
cuda_is_available,
89+
):
6590
kwargs = {"run_validation": True}
6691

6792
get_dataset.return_value = get_fake_dataset()
6893
cuda.return_value = cuda_is_available
6994

95+
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
96+
7097
main(**kwargs)
7198

7299
assert train.call_count == 1
@@ -83,22 +110,36 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
83110
else:
84111
assert get_model.return_value.to.call_count == 0
85112

86-
@patch('llama_recipes.finetuning.torch.cuda.is_available')
87-
@patch('llama_recipes.finetuning.train')
88-
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
89-
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
90-
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
91-
@patch('llama_recipes.finetuning.generate_peft_config')
92-
@patch('llama_recipes.finetuning.get_peft_model')
93-
@patch('llama_recipes.finetuning.optim.AdamW')
94-
@patch('llama_recipes.finetuning.StepLR')
113+
114+
@patch("llama_recipes.finetuning.torch.cuda.is_available")
115+
@patch("llama_recipes.finetuning.train")
116+
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
117+
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
118+
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
119+
@patch("llama_recipes.finetuning.generate_peft_config")
120+
@patch("llama_recipes.finetuning.get_peft_model")
121+
@patch("llama_recipes.finetuning.optim.AdamW")
122+
@patch("llama_recipes.finetuning.StepLR")
95123
@pytest.mark.parametrize("cuda_is_available", [True, False])
96-
def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
124+
def test_finetuning_peft_lora(
125+
step_lr,
126+
optimizer,
127+
get_peft_model,
128+
gen_peft_config,
129+
get_dataset,
130+
tokenizer,
131+
get_model,
132+
train,
133+
cuda,
134+
cuda_is_available,
135+
):
97136
kwargs = {"use_peft": True}
98137

99138
get_dataset.return_value = get_fake_dataset()
100139
cuda.return_value = cuda_is_available
101140

141+
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
142+
102143
main(**kwargs)
103144

104145
if cuda_is_available:
@@ -117,7 +158,7 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
117158
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
118159
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
119160
def test_finetuning_peft_llama_adapter(
120-
get_dataset, tokenizer, get_model, train, setup, get_peft_model, mocker
161+
get_dataset, tokenizer, get_model, train, setup, get_peft_model
121162
):
122163
kwargs = {
123164
"use_peft": True,
@@ -127,11 +168,7 @@ def test_finetuning_peft_llama_adapter(
127168

128169
get_dataset.return_value = get_fake_dataset()
129170

130-
model = mocker.MagicMock(name="Model")
131-
model.parameters.return_value = [torch.ones(1, 1)]
132-
model.get_input_embeddings.return_value.weight.shape = [0]
133-
134-
get_model.return_value = model
171+
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
135172

136173
os.environ["RANK"] = "0"
137174
os.environ["LOCAL_RANK"] = "0"
@@ -164,16 +201,14 @@ def test_finetuning_peft_llama_adapter(
164201
@patch("llama_recipes.finetuning.get_peft_model")
165202
@patch("llama_recipes.finetuning.StepLR")
166203
def test_finetuning_weight_decay(
167-
step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker
204+
step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
168205
):
169206
kwargs = {"weight_decay": 0.01}
170207

171208
get_dataset.return_value = get_fake_dataset()
172209

173-
model = mocker.MagicMock(name="Model")
174-
model.parameters.return_value = [torch.ones(1,1)]
175-
176-
get_model.return_value = model
210+
get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
211+
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
177212

178213
main(**kwargs)
179214

@@ -188,17 +223,21 @@ def test_finetuning_weight_decay(
188223
assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
189224

190225

191-
@patch('llama_recipes.finetuning.train')
192-
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
193-
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
194-
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
195-
@patch('llama_recipes.finetuning.optim.AdamW')
196-
@patch('llama_recipes.finetuning.StepLR')
197-
def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
226+
@patch("llama_recipes.finetuning.train")
227+
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
228+
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
229+
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
230+
@patch("llama_recipes.finetuning.optim.AdamW")
231+
@patch("llama_recipes.finetuning.StepLR")
232+
def test_batching_strategy(
233+
step_lr, optimizer, get_dataset, tokenizer, get_model, train
234+
):
198235
kwargs = {"batching_strategy": "packing"}
199236

200237
get_dataset.return_value = get_fake_dataset()
201238

239+
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
240+
202241
main(**kwargs)
203242

204243
assert train.call_count == 1

0 commit comments

Comments
 (0)