 from tqdm.auto import tqdm

 import diffusers
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, CogView4ControlPipeline, CogView4Transformer2DModel
+from diffusers import (
+    AutoencoderKL,
+    FlowMatchEulerDiscreteScheduler,
+    CogView4ControlPipeline,
+    CogView4Transformer2DModel,
+)
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
     compute_density_for_timestep_sampling,
@@ -787,7 +792,7 @@ def main(args):

     # enable image inputs
     with torch.no_grad():
-        patch_size = cogview4_transformer.config.patch_size
+        patch_size = cogview4_transformer.config.patch_size
         initial_input_channels = cogview4_transformer.config.in_channels * patch_size**2
         new_linear = torch.nn.Linear(
             cogview4_transformer.patch_embed.proj.in_features * 2,
@@ -803,7 +808,9 @@ def main(args):
         cogview4_transformer.patch_embed.proj = new_linear

         assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0)
-        cogview4_transformer.register_to_config(in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels)
+        cogview4_transformer.register_to_config(
+            in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels
+        )

     if args.only_target_transformer_blocks:
         cogview4_transformer.patch_embed.proj.requires_grad_(True)
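
Note: the two hunks above widen patch_embed.proj so the transformer can take the control latents concatenated along the channel dimension, with the new input columns zero-initialized so the pretrained mapping is unchanged at step 0 (the assert checks exactly this; the Flux control script uses the same pattern). A minimal sketch of that zero-init expansion, with toy layer sizes assumed purely for illustration:

import torch

proj = torch.nn.Linear(64, 128)  # stands in for cogview4_transformer.patch_embed.proj
new_linear = torch.nn.Linear(proj.in_features * 2, proj.out_features, bias=proj.bias is not None)
new_linear.weight.data.zero_()  # zero everything, then restore the pretrained half
new_linear.weight.data[:, : proj.in_features] = proj.weight.data
if proj.bias is not None:
    new_linear.bias.data = proj.bias.data.clone()

x = torch.randn(2, proj.in_features)
control = torch.randn(2, proj.in_features)
# The zeroed columns make the widened layer match the original on the image half alone.
assert torch.allclose(new_linear(torch.cat([x, control], dim=1)), proj(x), atol=1e-6)
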
@@ -1050,34 +1057,41 @@ def load_model_hook(models, input_dir):
                 )

                 # Add noise according to CogView4
-                #FIXME: The issue of variable-length training has not been resolved, here it is still extended to the longest one.
+                # FIXME: Variable-length training is not yet supported; batches are still padded to the longest sequence.
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
                 sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device)
                 captions = batch["captions"]
-                image_seq_lens = torch.tensor(pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size**2, dtype=pixel_latents.dtype, device=pixel_latents.device)  # H * W / VAE patch_size
+                image_seq_lens = torch.tensor(
+                    pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size**2,
+                    dtype=pixel_latents.dtype,
+                    device=pixel_latents.device,
+                )  # latent H * W / patch_size**2
                 mu = torch.sqrt(image_seq_lens / 256)
                 mu = mu * 0.75 + 0.25
-                scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(dtype=pixel_latents.dtype, device=pixel_latents.device)
+                scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(
+                    dtype=pixel_latents.dtype, device=pixel_latents.device
+                )
                 scale_factors = scale_factors.view(len(batch["captions"]), 1, 1, 1)
                 noisy_model_input = (1.0 - scale_factors) * pixel_latents + scale_factors * noise
                 concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)
                 text_encoding_pipeline = text_encoding_pipeline.to("cuda")

                 with torch.no_grad():
-                    prompt_embeds, pooled_prompt_embeds, = text_encoding_pipeline.encode_prompt(
-                        captions, ""
-                    )
+                    (
+                        prompt_embeds,
+                        pooled_prompt_embeds,
+                    ) = text_encoding_pipeline.encode_prompt(captions, "")
                 original_size = (args.resolution, args.resolution)
                 original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device)

-                target_size = (args.resolution,args.resolution)
+                target_size = (args.resolution, args.resolution)
                 target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device)

                 target_size = target_size.repeat(len(batch["captions"]), 1)
                 original_size = original_size.repeat(len(batch["captions"]), 1)

-                #TODO: Should a parameter be set here for passing? This is not present in Flux.
+                # TODO: Should this be exposed as a configurable argument? Flux has no equivalent.
                 crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
                 crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1)
                 # Predict.
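
Note: the mu / scale_factors lines above implement CogView4's resolution-dependent timestep shift: mu grows with the image token count, which pushes the interpolation weight toward the noise end for larger images. A small numeric sketch, under assumed sizes (1024 px image, 8x VAE downsample, patch_size 2; none of these values come from the diff):

import torch

latent_h = latent_w = 128  # assumed: 1024 px image through an 8x VAE downsample
patch_size = 2             # assumed transformer patch size
image_seq_len = latent_h * latent_w // patch_size**2  # 4096 tokens
mu = torch.sqrt(torch.tensor(image_seq_len / 256.0))  # sqrt(16) = 4.0
mu = mu * 0.75 + 0.25                                 # 3.25

sigmas = torch.tensor([0.25, 0.50, 0.75])  # example sampled noise levels
scale = mu / (mu + (1 / sigmas - 1))       # ~[0.52, 0.76, 0.91]
# At mu = 1 the shift is the identity (scale == sigmas); larger images yield a
# larger mu and push (1 - s) * latents + s * noise toward the noise end.
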
@@ -1108,7 +1122,9 @@ def load_model_hook(models, input_dir):
                 target = noise - pixel_latents

                 weighting = weighting.view(len(batch["captions"]), 1, 1, 1)
-                loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1)
+                loss = torch.mean(
+                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
+                )
                 loss = loss.mean()
                 accelerator.backward(loss)

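
Note: the loss above is a per-sample weighted MSE against the flow-matching velocity target noise - pixel_latents, reduced over all non-batch dimensions before the batch mean. A self-contained sketch with toy shapes; the real weighting is computed elsewhere in the script (the imports point at the training_utils helpers), so a constant stand-in is assumed here:

import torch

B, C, H, W = 2, 16, 32, 32        # toy shapes, not the script's real sizes
model_pred = torch.randn(B, C, H, W)
noise = torch.randn(B, C, H, W)
pixel_latents = torch.randn(B, C, H, W)
weighting = torch.ones(B)         # stand-in for the computed loss weighting

target = noise - pixel_latents    # velocity target for flow matching
weighting = weighting.view(B, 1, 1, 1)
per_sample = torch.mean(
    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(B, -1), 1
)
loss = per_sample.mean()          # scalar batch average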