@@ -110,13 +110,62 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
110
110
assert get_peft_model .return_value .print_trainable_parameters .call_count == 1
111
111
112
112
113
- @patch ('llama_recipes.finetuning.train' )
114
- @patch ('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained' )
115
- @patch ('llama_recipes.finetuning.AutoTokenizer.from_pretrained' )
116
- @patch ('llama_recipes.finetuning.get_preprocessed_dataset' )
117
- @patch ('llama_recipes.finetuning.get_peft_model' )
118
- @patch ('llama_recipes.finetuning.StepLR' )
119
- def test_finetuning_weight_decay (step_lr , get_peft_model , get_dataset , tokenizer , get_model , train , mocker ):
113
+ @patch ("llama_recipes.finetuning.get_peft_model" )
114
+ @patch ("llama_recipes.finetuning.setup" )
115
+ @patch ("llama_recipes.finetuning.train" )
116
+ @patch ("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained" )
117
+ @patch ("llama_recipes.finetuning.AutoTokenizer.from_pretrained" )
118
+ @patch ("llama_recipes.finetuning.get_preprocessed_dataset" )
119
+ def test_finetuning_peft_llama_adapter (
120
+ get_dataset , tokenizer , get_model , train , setup , get_peft_model , mocker
121
+ ):
122
+ kwargs = {
123
+ "use_peft" : True ,
124
+ "peft_method" : "llama_adapter" ,
125
+ "enable_fsdp" : True ,
126
+ }
127
+
128
+ get_dataset .return_value = get_fake_dataset ()
129
+
130
+ model = mocker .MagicMock (name = "Model" )
131
+ model .parameters .return_value = [torch .ones (1 , 1 )]
132
+ model .get_input_embeddings .return_value .weight .shape = [0 ]
133
+
134
+ get_model .return_value = model
135
+
136
+ os .environ ["RANK" ] = "0"
137
+ os .environ ["LOCAL_RANK" ] = "0"
138
+ os .environ ["WORLD_SIZE" ] = "1"
139
+ os .environ ["MASTER_ADDR" ] = "localhost"
140
+ os .environ ["MASTER_PORT" ] = "12345"
141
+
142
+ with pytest .raises (
143
+ RuntimeError ,
144
+ match = "Llama_adapter is currently not supported in combination with FSDP" ,
145
+ ):
146
+ main (** kwargs )
147
+
148
+ GET_ME_OUT = "Get me out of here"
149
+ get_peft_model .side_effect = RuntimeError (GET_ME_OUT )
150
+
151
+ kwargs ["enable_fsdp" ] = False
152
+
153
+ with pytest .raises (
154
+ RuntimeError ,
155
+ match = GET_ME_OUT ,
156
+ ):
157
+ main (** kwargs )
158
+
159
+
160
+ @patch ("llama_recipes.finetuning.train" )
161
+ @patch ("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained" )
162
+ @patch ("llama_recipes.finetuning.AutoTokenizer.from_pretrained" )
163
+ @patch ("llama_recipes.finetuning.get_preprocessed_dataset" )
164
+ @patch ("llama_recipes.finetuning.get_peft_model" )
165
+ @patch ("llama_recipes.finetuning.StepLR" )
166
+ def test_finetuning_weight_decay (
167
+ step_lr , get_peft_model , get_dataset , tokenizer , get_model , train , mocker
168
+ ):
120
169
kwargs = {"weight_decay" : 0.01 }
121
170
122
171
get_dataset .return_value = get_fake_dataset ()
0 commit comments