Skip to content

Commit 853f928

Browse files
committed
Add unit test for config check
1 parent 091d58d commit 853f928

File tree

1 file changed

+56
-7
lines changed

1 file changed

+56
-7
lines changed

tests/test_finetuning.py

Lines changed: 56 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -110,13 +110,62 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
110110
assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
111111

112112

113-
@patch('llama_recipes.finetuning.train')
114-
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
115-
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
116-
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
117-
@patch('llama_recipes.finetuning.get_peft_model')
118-
@patch('llama_recipes.finetuning.StepLR')
119-
def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
113+
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.setup")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
def test_finetuning_peft_llama_adapter(
    get_dataset, tokenizer, get_model, train, setup, get_peft_model, mocker
):
    """Check the llama_adapter PEFT method's FSDP-compatibility guard.

    First run: use_peft + peft_method=llama_adapter + enable_fsdp must raise
    a RuntimeError before any PEFT wrapping happens. Second run: with FSDP
    disabled, execution must reach get_peft_model (proven by a sentinel
    RuntimeError planted as its side_effect).

    Note: mock arguments are injected bottom-up, so their order here is the
    reverse of the @patch decorator stack — keep them in sync.
    """
    # FSDP + llama_adapter is the unsupported combination under test.
    kwargs = {
        "use_peft": True,
        "peft_method": "llama_adapter",
        "enable_fsdp": True,
    }

    get_dataset.return_value = get_fake_dataset()

    model = mocker.MagicMock(name="Model")
    # main() iterates model parameters (e.g. for the optimizer), so the mock
    # must yield at least one real tensor.
    model.parameters.return_value = [torch.ones(1, 1)]
    # Embedding-size check in main() reads weight.shape; [0] keeps it cheap.
    model.get_input_embeddings.return_value.weight.shape = [0]

    get_model.return_value = model

    # Minimal single-process distributed environment so the FSDP code path
    # can be entered without a real launcher (torchrun/mpi).
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    with pytest.raises(
        RuntimeError,
        match="Llama_adapter is currently not supported in combination with FSDP",
    ):
        main(**kwargs)

    # Sentinel: if main() gets as far as get_peft_model, this exact error
    # escapes — proving the guard no longer fires once FSDP is off.
    GET_ME_OUT = "Get me out of here"
    get_peft_model.side_effect = RuntimeError(GET_ME_OUT)

    kwargs["enable_fsdp"] = False

    with pytest.raises(
        RuntimeError,
        match=GET_ME_OUT,
    ):
        main(**kwargs)
159+
160+
@patch("llama_recipes.finetuning.train")
161+
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
162+
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
163+
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
164+
@patch("llama_recipes.finetuning.get_peft_model")
165+
@patch("llama_recipes.finetuning.StepLR")
166+
def test_finetuning_weight_decay(
167+
step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker
168+
):
120169
kwargs = {"weight_decay": 0.01}
121170

122171
get_dataset.return_value = get_fake_dataset()

0 commit comments

Comments (0)