 import argparse
 import copy
 import itertools
+import json
 import logging
 import math
 import os
 
 import numpy as np
 import torch
-import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
-from peft import LoraConfig, set_peft_model_state_dict
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 import diffusers
 from diffusers import (
     AutoencoderKL,
+    BitsAndBytesConfig,
     FlowMatchEulerDiscreteScheduler,
     HiDreamImagePipeline,
     HiDreamImageTransformer2DModel,
@@ -282,6 +283,12 @@ def parse_args(input_args=None):
         default="meta-llama/Meta-Llama-3.1-8B-Instruct",
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--bnb_quantization_config_path",
+        type=str,
+        default=None,
+        help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
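The new `--bnb_quantization_config_path` flag points at a JSON file whose keys are passed straight to `BitsAndBytesConfig` (see the loading hunk further below); `bnb_4bit_compute_dtype` is filled in by the script from the training precision, so it does not need to be in the file. As a minimal, illustrative sketch (the file name and the exact key choices are assumptions, not taken from this commit), such a config file could be produced like this:

# write_bnb_config.py -- illustrative only: writes a 4-bit NF4 config for
# --bnb_quantization_config_path; key names are standard BitsAndBytesConfig kwargs.
import json

config_kwargs = {
    "load_in_4bit": True,            # quantize the transformer weights to 4 bit
    "bnb_4bit_quant_type": "nf4",    # NormalFloat4 quantization
    "bnb_4bit_use_double_quant": True,
    # "bnb_4bit_compute_dtype" is injected by the training script from the mixed-precision setting
}

with open("bnb_nf4.json", "w") as f:
    json.dump(config_kwargs, f, indent=2)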
@@ -1056,6 +1063,14 @@ def main(args):
         args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
     )
 
+    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision, shift=3.0
@@ -1064,20 +1079,30 @@ def main(args):
     text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(
         text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
     )
-
     vae = AutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="vae",
         revision=args.revision,
         variant=args.variant,
     )
+    quantization_config = None
+    if args.bnb_quantization_config_path is not None:
+        with open(args.bnb_quantization_config_path, "r") as f:
+            config_kwargs = json.load(f)
+        config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
+        quantization_config = BitsAndBytesConfig(**config_kwargs)
+
     transformer = HiDreamImageTransformer2DModel.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="transformer",
         revision=args.revision,
         variant=args.variant,
+        quantization_config=quantization_config,
+        torch_dtype=weight_dtype,
         force_inference_output=True,
     )
+    if args.bnb_quantization_config_path is not None:
+        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
 
     # We only train the additional adapter LoRA layers
     transformer.requires_grad_(False)
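Because the transformer is now loaded through bitsandbytes whenever a config is supplied, the later `transformer.to(...)` call in this diff drops the `dtype=weight_dtype` argument: the quantized modules keep their packed storage and only the device is moved. A rough sanity check, an illustrative sketch that is not part of the commit and assumes it runs right after the lines above with a 4-bit config, could look like this:

# Illustrative sketch only: inspect the transformer after quantized loading and
# prepare_model_for_kbit_training. 4-bit bitsandbytes weights are stored as packed
# torch.uint8 tensors; the remaining parameters (e.g. norm layers) are upcast to
# float32 by peft for training stability.
from collections import Counter

dtype_counts = Counter(p.dtype for p in transformer.parameters())
print(dtype_counts)  # expect mostly torch.uint8 entries plus some torch.float32

approx_bytes = sum(p.numel() * p.element_size() for p in transformer.parameters())
print(f"approx parameter storage: {approx_bytes / 2**30:.2f} GiB")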
@@ -1087,14 +1112,6 @@ def main(args):
     text_encoder_three.requires_grad_(False)
     text_encoder_four.requires_grad_(False)
 
-    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
-    # as these weights are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
         # due to pytorch#99272, MPS does not yet support bfloat16.
         raise ValueError(
@@ -1109,7 +1126,7 @@ def main(args):
     text_encoder_three.to(**to_kwargs)
     text_encoder_four.to(**to_kwargs)
     # we never offload the transformer to CPU, so we can just use the accelerator device
-    transformer.to(accelerator.device, dtype=weight_dtype)
+    transformer.to(accelerator.device)
 
     # Initialize a text encoding pipeline and keep it to CPU for now.
     text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(