@@ -199,3 +199,50 @@ def test_fft_schedule_free_adamw(self, temp_dir):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
+
+    @with_temp_dir
+    def test_came_pytorch(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "came_pytorch",
+                "adam_beta3": 0.9999,
+                "adam_epsilon2": 1e-16,
+                "max_steps": 5,
+                "lr_scheduler": "cosine",
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)