@@ -469,12 +469,6 @@ def parse_args(input_args=None):
469469        default = 1e-4 ,
470470        help = "Initial learning rate (after the potential warmup period) to use." ,
471471    )
472-     parser .add_argument (
473-         "--guidance_scale" ,
474-         type = float ,
475-         default = 0.0 ,
476-         help = "Qwen image is a guidance distilled model" ,
477-     )
478472    parser .add_argument (
479473        "--scale_lr" ,
480474        action = "store_true" ,
@@ -1431,10 +1425,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
14311425            sigma  =  sigma .unsqueeze (- 1 )
14321426        return  sigma 
14331427
1434-     guidance  =  None 
1435-     if  unwrap_model (transformer ).config .guidance_embeds :
1436-         guidance  =  torch .tensor ([args .guidance_scale ], device = accelerator .device )
1437- 
14381428    for  epoch  in  range (first_epoch , args .num_train_epochs ):
14391429        transformer .train ()
14401430
@@ -1482,11 +1472,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
14821472                sigmas  =  get_sigmas (timesteps , n_dim = model_input .ndim , dtype = model_input .dtype )
14831473                noisy_model_input  =  (1.0  -  sigmas ) *  model_input  +  sigmas  *  noise 
14841474
1485-                 # handle guidance 
1486-                 if  guidance  is  not   None :
1487-                     guidance  =  torch .tensor ([args .guidance_scale ], device = accelerator .device )
1488-                     guidance  =  guidance .expand (model_input .shape [0 ])
1489- 
14901475                # Predict the noise residual 
14911476                img_shapes  =  [
14921477                    (1 , args .resolution  //  vae_scale_factor  //  2 , args .resolution  //  vae_scale_factor  //  2 )
@@ -1505,7 +1490,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15051490                    encoder_hidden_states = prompt_embeds ,
15061491                    encoder_hidden_states_mask = prompt_embeds_mask ,
15071492                    timestep = timesteps  /  1000 ,
1508-                     guidance = guidance ,
15091493                    img_shapes = img_shapes ,
15101494                    txt_seq_lens = prompt_embeds_mask .sum (dim = 1 ).tolist (),
15111495                    return_dict = False ,
0 commit comments