
Details on direct distillation #1181

@Kaihua1203

Description

Hi, thanks for your great contributions to the community! I’ve been using DiffSynth for a long time. While reading the code in loss.py, I got confused about the algorithm used for direct distillation.

Code:

def DirectDistillLoss(pipe: BasePipeline, **inputs):
    # Configure the scheduler for the requested number of inference steps, in training mode
    pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
    pipe.scheduler.training = True
    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    # Roll out the full denoising loop, starting from inputs["latents"]
    for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
        inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
    # Compare the final denoised latents against inputs["input_latents"]
    loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
    return loss
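
For context, here is how I currently read the inputs that reach this loss; this is my own simplified sketch, and the shapes as well as the random stand-in for the encoded image are guesses, not actual DiffSynth code:

import torch

# My understanding (assumption) of what feeds DirectDistillLoss after the
# QwenImageUnit_InputImageEmbedder stage. The shape and the random tensor
# standing in for the encoded image are placeholders, not the real pipeline.
latent_shape = (1, 16, 64, 64)

input_latents = torch.randn(latent_shape)  # stand-in for the VAE-encoded training image
latents = torch.randn(latent_shape)        # starting point of the denoising loop: pure noise

inputs = {
    "latents": latents,
    "input_latents": input_latents,
    "num_inference_steps": 4,  # the few-step budget I would like to train with
}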

The loss function and pseudo code:

$$ \mathcal{L}_{\text{direct}} = \text{MSE}(\boldsymbol{z}_{\text{student}}, \boldsymbol{z}_{\text{teacher}}) $$

seed = xxx  # same seed / starting noise for both runs
with torch.no_grad():
    image_1 = pipe(prompt, steps=50, seed=seed, cfg=4)  # teacher: many steps, CFG on
image_2 = pipe(prompt, steps=4, seed=seed, cfg=1)       # student: few steps, no CFG
loss = torch.nn.functional.mse_loss(image_1, image_2)
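
To check that I am not misreading the idea, here is a self-contained toy sketch of what I would call teacher-student direct distillation. All names here (ToyDenoiser, rollout) are mine and only stand in for the real backbone and scheduler:

import torch
import torch.nn as nn

# Toy denoiser standing in for the diffusion backbone (hypothetical, not DiffSynth's model_fn).
class ToyDenoiser(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.net = nn.Linear(dim + 1, dim)

    def forward(self, x, t):
        # Concatenate the timestep as an extra feature and predict the update direction.
        t_feat = t.expand(x.shape[0], 1)
        return self.net(torch.cat([x, t_feat], dim=-1))

def rollout(model, z, num_steps):
    # Naive Euler-style denoising loop; a stand-in for the real scheduler / pipe.step.
    timesteps = torch.linspace(1.0, 0.0, num_steps + 1)[:-1]
    for t in timesteps:
        pred = model(z, t.view(1, 1))
        z = z - pred / num_steps
    return z

teacher = ToyDenoiser()
student = ToyDenoiser()
student.load_state_dict(teacher.state_dict())        # initialize the student from the teacher

noise = torch.randn(4, 16)                            # shared starting noise (shared "seed")
with torch.no_grad():
    target = rollout(teacher, noise, num_steps=50)    # teacher: many steps, no gradients

pred = rollout(student, noise, num_steps=4)           # student: few steps (NFE = 4)
loss = torch.nn.functional.mse_loss(pred, target)     # distillation objective
loss.backward()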

So after the QwenImageUnit_InputImageEmbedder stage, inputs["latents"] is random noise and inputs["input_latents"] holds the encoded image latents. My questions are:

1. Why can this setup be considered a student-teacher (distillation) scheme?
2. Why is the rollout that starts from random noise treated as the student output, instead of comparing a few-step inference result against a many-step teacher result?
3. How can the NFE (a small step count such as 4 or 8) be set during training? My guess is sketched below.

Thank you for your time, and thanks again for all your work on this project!
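
P.S. For question 3, my guess (just an assumption, based on the loss reading inputs["num_inference_steps"] in its first line) is that the few-step budget is controlled through that same key, something like:

# Hypothetical: would passing 4 here make the training rollout a 4-step (NFE = 4) loop?
inputs["num_inference_steps"] = 4

Is that the intended way, or is there a dedicated training argument for it?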
