File tree Expand file tree Collapse file tree 2 files changed +14
-2
lines changed
examples/research_projects/diffusion_dpo Expand file tree Collapse file tree 2 files changed +14
-2
lines changed Original file line number Diff line number Diff line change @@ -414,6 +414,12 @@ def parse_args(input_args=None):
414414 default = 4 ,
415415 help = ("The dimension of the LoRA update matrices." ),
416416 )
417+ parser .add_argument (
418+ "--tracker_name" ,
419+ type = str ,
420+ default = "diffusion-dpo-lora" ,
421+ help = ("The name of the tracker to report results to." ),
422+ )
417423
418424 if input_args is not None :
419425 args = parser .parse_args (input_args )
@@ -726,7 +732,7 @@ def collate_fn(examples):
726732 # We need to initialize the trackers we use, and also store our configuration.
727733 # The trackers initializes automatically on the main process.
728734 if accelerator .is_main_process :
729- accelerator .init_trackers ("diffusion-dpo-lora" , config = vars (args ))
735+ accelerator .init_trackers (args . tracker_name , config = vars (args ))
730736
731737 # Train!
732738 total_batch_size = args .train_batch_size * accelerator .num_processes * args .gradient_accumulation_steps
Original file line number Diff line number Diff line change @@ -429,6 +429,12 @@ def parse_args(input_args=None):
429429 default = 4 ,
430430 help = ("The dimension of the LoRA update matrices." ),
431431 )
432+ parser .add_argument (
433+ "--tracker_name" ,
434+ type = str ,
435+ default = "diffusion-dpo-lora-sdxl" ,
436+ help = ("The name of the tracker to report results to." ),
437+ )
432438
433439 if input_args is not None :
434440 args = parser .parse_args (input_args )
@@ -821,7 +827,7 @@ def collate_fn(examples):
821827 # We need to initialize the trackers we use, and also store our configuration.
822828 # The trackers initializes automatically on the main process.
823829 if accelerator .is_main_process :
824- accelerator .init_trackers ("diffusion-dpo-lora-sdxl" , config = vars (args ))
830+ accelerator .init_trackers (args . tracker_name , config = vars (args ))
825831
826832 # Train!
827833 total_batch_size = args .train_batch_size * accelerator .num_processes * args .gradient_accumulation_steps
You can’t perform that action at this time.
0 commit comments