Is your feature request related to a problem? Please describe.
Currently, the upstream PixArt trainer does this:
```python
transformer = get_peft_model(transformer, lora_config)

if args.mixed_precision == "fp16":
    # only upcast trainable parameters (LoRA) into fp32
    cast_training_params(transformer, dtype=torch.float32)

transformer.print_trainable_parameters()

# 10. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_ = accelerator.unwrap_model(transformer)
            lora_state_dict = get_peft_model_state_dict(transformer_, adapter_name="default")
            StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "transformer_lora"), lora_state_dict)
            # save weights in peft format to be able to load them back
            transformer_.save_pretrained(output_dir)

            for _, model in enumerate(models):
                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()
```
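Note how the hook above has to borrow `StableDiffusionPipeline.save_lora_weights` just to get a diffusers-format file out. For contrast, here is a rough sketch of what I would rather be able to write; `PixArtSigmaPipeline.save_lora_weights` and its `transformer_lora_layers` argument do not exist today and are only the shape of the API being requested, modeled on the LoRA mix-ins other pipelines already expose:

```python
def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
        transformer_ = accelerator.unwrap_model(transformer)
        lora_state_dict = get_peft_model_state_dict(transformer_, adapter_name="default")

        # Proposed (not yet existing) API: let the PixArt pipeline own the LoRA
        # serialization format instead of borrowing StableDiffusionPipeline's classmethod.
        PixArtSigmaPipeline.save_lora_weights(
            output_dir, transformer_lora_layers=lora_state_dict
        )

        for _, model in enumerate(models):
            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()
```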
Describe the solution you'd like.

I would like PixArtSigmaPipeline to gain the necessary LoRA mix-ins and inference support instead.
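To make the request concrete, a minimal sketch of the inference side; `load_lora_weights` on PixArtSigmaPipeline is the proposed method and does not exist yet, and the checkpoint id and LoRA path are illustrative:

```python
import torch
from diffusers import PixArtSigmaPipeline

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Proposed API: load the transformer LoRA through the pipeline itself,
# the same way SDXL/SD3 pipelines do via their LoRA loader mix-ins.
pipe.load_lora_weights("output_dir/transformer_lora")

image = pipe("a small cactus with a happy face in the Sahara desert").images[0]
```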
Describe alternatives you've considered.
I have considered avoiding the pipeline methods and using workarounds like the one upstream uses, but I would prefer a consistent user experience across pipelines.
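For reference, the kind of workaround I mean looks roughly like the sketch below, attaching the PEFT-format adapter to the transformer directly and bypassing the pipeline methods; it assumes the adapter was saved with `transformer_.save_pretrained(output_dir)` as in the trainer above, and the paths are illustrative:

```python
import torch
from diffusers import PixArtSigmaPipeline
from peft import PeftModel

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
)

# Wrap the transformer with the PEFT adapter and merge it for inference.
pipe.transformer = PeftModel.from_pretrained(pipe.transformer, "output_dir")
pipe.transformer = pipe.transformer.merge_and_unload()

pipe.to("cuda")
image = pipe("a small cactus with a happy face in the Sahara desert").images[0]
```

This works, but it means PixArt users learn a different loading path than every other diffusers pipeline, which is exactly the inconsistency I would like to avoid.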