@@ -1,4 +1,6 @@
 import imageio, os, torch, warnings, torchvision, argparse, json
+from ..utils import ModelConfig
+from ..models.utils import load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 from PIL import Image
 import pandas as pd
@@ -424,7 +426,53 @@ def transfer_data_to_device(self, data, device):
             if isinstance(data[key], torch.Tensor):
                 data[key] = data[key].to(device)
         return data
-
+
+
+    def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
+        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
+        model_configs = []
+        if model_paths is not None:
+            model_paths = json.loads(model_paths)
+            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
+        if model_id_with_origin_paths is not None:
+            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
+            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
+        return model_configs
+
+
+    def switch_pipe_to_training_mode(
+        self,
+        pipe,
+        trainable_models,
+        lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
+        enable_fp8_training=False,
+    ):
+        # Scheduler
+        pipe.scheduler.set_timesteps(1000, training=True)
+
+        # Freeze untrainable models
+        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+
+        # Enable FP8 if the pipeline supports it
+        if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
+            pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
+
+        # Add LoRA to the base models
+        if lora_base_model is not None:
+            model = self.add_lora_to_model(
+                getattr(pipe, lora_base_model),
+                target_modules=lora_target_modules.split(","),
+                lora_rank=lora_rank,
+                upcast_dtype=pipe.torch_dtype,
+            )
+            if lora_checkpoint is not None:
+                state_dict = load_state_dict(lora_checkpoint)
+                state_dict = self.mapping_lora_state_dict(state_dict)
+                load_result = model.load_state_dict(state_dict, strict=False)
+                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
+                if len(load_result[1]) > 0:
+                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
+            setattr(pipe, lora_base_model, model)
 
 
 class ModelLogger:
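The new `parse_model_configs` helper expects `model_paths` as a JSON list of local checkpoint paths and `model_id_with_origin_paths` as a comma-separated list of `model_id:origin_file_pattern` pairs, typically passed straight from the command line. Below is a minimal, standalone sketch of that parsing; the file names and model ID are placeholders, not values from this commit.

import json, torch

# Placeholder inputs in the two accepted formats (illustrative only).
model_paths = '["models/dit.safetensors", "models/vae.safetensors"]'
model_id_with_origin_paths = "Org/SomeModel:transformer/*.safetensors,Org/SomeModel:text_encoder/*.safetensors"

# model_paths becomes ModelConfig(path=..., offload_dtype=...) entries.
local_paths = json.loads(model_paths)

# model_id_with_origin_paths becomes ModelConfig(model_id=..., origin_file_pattern=..., offload_dtype=...) entries.
id_pattern_pairs = [item.split(":") for item in model_id_with_origin_paths.split(",")]

# With enable_fp8_training=True, every config carries offload_dtype=torch.float8_e4m3fn.
offload_dtype = torch.float8_e4m3fn
print(local_paths, id_pattern_pairs, offload_dtype)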
@@ -472,14 +520,26 @@ def launch_training_task(
     dataset: torch.utils.data.Dataset,
     model: DiffusionTrainingModule,
     model_logger: ModelLogger,
-    optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler.LRScheduler,
+    learning_rate: float = 1e-5,
+    weight_decay: float = 1e-2,
     num_workers: int = 8,
     save_steps: int = None,
     num_epochs: int = 1,
     gradient_accumulation_steps: int = 1,
     find_unused_parameters: bool = False,
+    args=None,
 ):
+    if args is not None:
+        learning_rate = args.learning_rate
+        weight_decay = args.weight_decay
+        num_workers = args.dataset_num_workers
+        save_steps = args.save_steps
+        num_epochs = args.num_epochs
+        gradient_accumulation_steps = args.gradient_accumulation_steps
+        find_unused_parameters = args.find_unused_parameters
+
+    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
+    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
     dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
     accelerator = Accelerator(
         gradient_accumulation_steps=gradient_accumulation_steps,
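Callers no longer pass an optimizer and scheduler; `launch_training_task` now builds AdamW over `model.trainable_modules()` plus a constant LR schedule internally, and an optional `args` namespace replaces the keyword values with the corresponding fields (`args.learning_rate`, `args.weight_decay`, `args.dataset_num_workers`, `args.save_steps`, `args.num_epochs`, `args.gradient_accumulation_steps`, `args.find_unused_parameters`). A hedged sketch of the two call patterns, assuming `dataset`, `model`, `model_logger`, and a parsed `args` exist elsewhere:

# Explicit keywords feed the internally created AdamW / ConstantLR.
launch_training_task(
    dataset, model, model_logger,
    learning_rate=1e-4,
    weight_decay=0.01,
    gradient_accumulation_steps=4,
)

# Or let the parsed CLI namespace drive everything.
launch_training_task(dataset, model, model_logger, args=args)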
@@ -509,8 +569,12 @@ def launch_data_process_task(
     model: DiffusionTrainingModule,
     model_logger: ModelLogger,
     num_workers: int = 8,
+    args=None,
 ):
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
+    if args is not None:
+        num_workers = args.dataset_num_workers
+
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
     accelerator = Accelerator()
     model, dataloader = accelerator.prepare(model, dataloader)
 
@@ -520,7 +584,7 @@ def launch_data_process_task(
         folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
         os.makedirs(folder, exist_ok=True)
         save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
-        data = model(data)
+        data = model(data, return_inputs=True)
         torch.save(data, save_path)
 
 
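Each data-processing worker now iterates in a fixed order (the dataloader is no longer shuffled), calls the module with `return_inputs=True`, and saves the result to `<output_path>/<process_index>/<data_id>.pth`. A hedged sketch of reading the cached files back, assuming `output_path` points at the same directory used by `model_logger`:

import os, torch

output_path = "./cache"  # placeholder for model_logger.output_path
for process_dir in sorted(os.listdir(output_path)):
    for file_name in sorted(os.listdir(os.path.join(output_path, process_dir))):
        cached = torch.load(os.path.join(output_path, process_dir, file_name), map_location="cpu")
        # `cached` is whatever model(data, return_inputs=True) returned for that sample.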
@@ -623,4 +687,5 @@ def qwen_image_parser():
     parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
     parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
     parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
+    parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
     return parser
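The new `--task` flag (default "sft") appears intended to let one entry script choose between the training and data-processing paths above. A hedged dispatch sketch; the "data_process" task name and the `dataset`/`model`/`model_logger` objects are assumptions, not part of this commit:

args = qwen_image_parser().parse_args()

if args.task == "sft":
    launch_training_task(dataset, model, model_logger, args=args)
elif args.task == "data_process":  # assumed task name, not defined in this diff
    launch_data_process_task(dataset, model, model_logger, args=args)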