Hi, thanks for your great contributions to the community! I've been using DiffSynth for a long time. While reading the code in loss.py, I got confused by the direct distillation algorithm.
Code:

```python
def DirectDistillLoss(pipe: BasePipeline, **inputs):
    pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
    pipe.scheduler.training = True
    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    # Run the full sampling loop, starting from inputs["latents"]
    for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
        inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
    # Compare the final denoised latents against the pre-encoded target latents
    loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
    return loss
```
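
For context, this is roughly how I imagine the loss being used in a training step (just my own sketch based on the code above; the optimizer setup and the exact contents of `inputs` are my assumptions, not taken from the repo):

```python
import torch

def train_step(pipe, optimizer, inputs):
    # inputs is assumed to already carry "latents" (random noise),
    # "input_latents" (the encoded target image) and "num_inference_steps".
    loss = DirectDistillLoss(pipe, **inputs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```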
The loss function, as I understand it, in pseudo code:
```python
seed = xxx
with torch.no_grad():
    # teacher: many steps with CFG
    image_1 = pipe(prompt, steps=50, seed=seed, cfg=4)
# student: few steps, no CFG
image_2 = pipe(prompt, steps=4, seed=seed, cfg=1)
loss = torch.nn.functional.mse_loss(image_1, image_2)
```
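
If I read the training code correctly, the number of student steps would be controlled by `inputs["num_inference_steps"]`, so a 4-step setting might look like this (my assumption, not taken from the repo):

```python
# Assumed way to train a 4-step student: limit the sampling loop in the loss.
inputs["num_inference_steps"] = 4
loss = DirectDistillLoss(pipe, **inputs)
```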
So after the QwenImageUnit_InputImageEmbedder stage, inputs["latents"] holds random noise and inputs["input_latents"] holds the encoded image latents. Why can this be considered a student-teacher model? Why do you treat random noise as the student-model side instead of a few-step inference result? And how can we set the NFE (a few steps like 4 or 8) during training? Is it simply inputs["num_inference_steps"], as in the sketch above, or something else? Thank you for your time, and thanks again for all your work on this project!