
Commit cb1b6b4

Commit message: up

1 parent 987766e

File tree

1 file changed (+0, -16 lines)

examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 0 additions & 16 deletions

@@ -469,12 +469,6 @@ def parse_args(input_args=None):
         default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
-    parser.add_argument(
-        "--guidance_scale",
-        type=float,
-        default=0.0,
-        help="Qwen image is a guidance distilled model",
-    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -1431,10 +1425,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         sigma = sigma.unsqueeze(-1)
         return sigma
 
-    guidance = None
-    if unwrap_model(transformer).config.guidance_embeds:
-        guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
-
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
 
@@ -1482,11 +1472,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
                 noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
 
-                # handle guidance
-                if guidance is not None:
-                    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
-                    guidance = guidance.expand(model_input.shape[0])
-
                 # Predict the noise residual
                 img_shapes = [
                     (1, args.resolution // vae_scale_factor // 2, args.resolution // vae_scale_factor // 2)
@@ -1505,7 +1490,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     encoder_hidden_states=prompt_embeds,
                     encoder_hidden_states_mask=prompt_embeds_mask,
                     timestep=timesteps / 1000,
-                    guidance=guidance,
                     img_shapes=img_shapes,
                     txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                     return_dict=False,
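
For context, the removed lines followed the guidance-embedding pattern used by guidance-distilled transformers such as Flux, where a scalar guidance value is turned into a per-sample tensor and fed to the model. A minimal sketch of that pattern, assuming a transformer whose config exposes a boolean guidance_embeds field; the helper name maybe_build_guidance is hypothetical, not part of the script:

import torch

def maybe_build_guidance(transformer, guidance_scale, batch_size, device):
    # Hypothetical helper: only guidance-distilled models (config.guidance_embeds
    # set to True) expect a guidance tensor; everything else gets None.
    if not getattr(transformer.config, "guidance_embeds", False):
        return None
    # One scalar broadcast across the batch, as in the removed code.
    guidance = torch.tensor([guidance_scale], device=device)
    return guidance.expand(batch_size)

When guidance_embeds is absent or False the helper returns None and the guidance keyword can be dropped from the transformer call entirely, which matches the state of the script after this commit.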
