I think there is a bug in the interface self-conditioning in rin_pytorch.py.
The model output is interpreted differently during the self-conditioning stage compared to the prediction stage.
Currently we have (pseudocode):

```python
self_cond = target_to_x0_modification(model_output)  # Interpret the model output as the target (x0, eps, or v) and convert it to an x0 estimate
pred = self.model(..., self_cond, ...)               # Self-condition on the x0 estimate; the model again predicts x0, eps, or v
target = x0_to_target_modification(x0)               # Convert x0 into the training target (x0, eps, or v)
loss = F.mse_loss(pred, target)
```
In the current implementation, the interface prediction is interpreted as x0 during self-conditioning (the model output is converted to an x0 estimate and clamped before being fed back), but as the target (x0, eps, or v) at the prediction step.
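To make the conversion concrete, here is a scalar numeric check (plain Python with made-up illustrative numbers, not the repo's tensors) that the `eps` and `v` branches of the quoted code both map a target-space model output to an x0 estimate, assuming the usual parameterization `x_t = alpha * x0 + sigma * eps` with `alpha**2 + sigma**2 == 1`:

```python
# Illustrative scalars (not from the repo); alpha^2 + sigma^2 = 1
alpha, sigma = 0.8, 0.6
x0, eps = 0.5, -0.3

noised = alpha * x0 + sigma * eps  # forward process: x_t = alpha*x0 + sigma*eps

# 'eps' branch: safe_div(noised_img - sigma * model_output, alpha)
model_output_eps = eps  # pretend the model predicted the noise exactly
self_cond_from_eps = (noised - sigma * model_output_eps) / alpha

# 'v' branch: alpha * noised_img - sigma * model_output, with v = alpha*eps - sigma*x0
model_output_v = alpha * eps - sigma * x0
self_cond_from_v = alpha * noised - sigma * model_output_v

# both recover x0 (up to floating-point error), so self_cond is an x0 estimate
```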
I see two ways that we could do interface self-conditioning that would be consistent.
We could either:
- make the model always predict x0. Then we would have:

```python
self_cond = model_output                # Where self_cond is a prediction for x0
pred = self.model(..., self_cond, ...)  # Self-condition on the x0 prediction and predict x0
pred = x0_to_target_modification(pred)  # Convert the x0 prediction into a prediction for x0, eps, or v
target = x0_to_target_modification(x0)
loss = F.mse_loss(pred, target)
```
or
- make the model always predict what it is intended to predict (x0, eps, or v). Then we would have:

```python
self_cond = model_output                # Where self_cond is a prediction for the target (x0, eps, or v)
pred = self.model(..., self_cond, ...)  # Self-condition on the target-space prediction and predict x0, eps, or v
target = x0_to_target_modification(x0)
loss = F.mse_loss(pred, target)
```
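For reference, `x0_to_target_modification` in my sketches is not an existing function in the repo; a minimal scalar version could look like the following (the name, signature, and helper structure are hypothetical), assuming `x_t = alpha*x0 + sigma*eps` and the v-parameterization `v = alpha*eps - sigma*x0`:

```python
def x0_to_target_modification(x0_pred, noised, alpha, sigma, objective):
    # Hypothetical helper: convert an x0-space value into the target space.
    if objective == 'x0':
        return x0_pred
    # implied noise estimate: eps = (x_t - alpha * x0) / sigma
    eps_pred = (noised - alpha * x0_pred) / sigma
    if objective == 'eps':
        return eps_pred
    if objective == 'v':
        return alpha * eps_pred - sigma * x0_pred
    raise ValueError(f'unknown objective: {objective}')

# sanity check with illustrative scalars (alpha^2 + sigma^2 = 1)
alpha, sigma = 0.8, 0.6
x0, eps = 0.5, -0.3
noised = alpha * x0 + sigma * eps
eps_back = x0_to_target_modification(x0, noised, alpha, sigma, 'eps')
v_back = x0_to_target_modification(x0, noised, alpha, sigma, 'v')
```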
In contrast to the current implementation, both of my proposals interpret the interface prediction the same way in the self-conditioning step and the prediction step. Would you agree that there is an inconsistency here, and that either proposal resolves it?
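One side note on the second proposal: the `clamp_(-1., 1.)` in the current code (quoted below) presumably assumes an x0-space quantity. If `self_cond` stays in target space, `eps` and `v` values routinely leave `[-1, 1]`, so the clamp would need revisiting. A quick scalar check with illustrative numbers:

```python
# Illustrative scalars: x0 lies in [-1, 1], but eps and v need not
alpha, sigma = 0.6, 0.8
x0, eps = 0.9, -1.7            # eps is a standard-normal sample, unbounded
v = alpha * eps - sigma * x0   # v-parameterization

# here both eps and v fall outside [-1, 1], so clamping them to that range
# would distort the self-conditioning signal
```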
Here is the current code:

```python
if random() < self.train_prob_self_cond:
    with torch.no_grad():
        model_output, self_latents = self.model(noised_img, times, return_latents = True)
        self_latents = self_latents.detach()

        if self.objective == 'x0':
            self_cond = model_output
        elif self.objective == 'eps':
            self_cond = safe_div(noised_img - sigma * model_output, alpha)
        elif self.objective == 'v':
            self_cond = alpha * noised_img - sigma * model_output

        self_cond.clamp_(-1., 1.)
        self_cond = self_cond.detach()

# predict and take gradient step
pred = self.model(noised_img, times, self_cond, self_latents)
...
loss = F.mse_loss(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b', 'mean')
```
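To see the mismatch end to end, here is a scalar toy replication of the branch above for the `eps` objective (everything here is a stand-in: `model` is a stub "perfect" noise predictor, the numbers are illustrative, and the conversion and clamp follow the quoted formulas):

```python
# Illustrative scalars with alpha^2 + sigma^2 = 1
alpha, sigma = 0.8, 0.6
x0, eps = 0.5, -0.3
noised_img = alpha * x0 + sigma * eps

def model(x, self_cond=None):
    # stub standing in for self.model: a perfect eps-predictor
    return eps

# self-conditioning pass, as in the quoted 'eps' branch
model_output = model(noised_img)
self_cond = (noised_img - sigma * model_output) / alpha  # eps -> x0 estimate
self_cond = max(-1.0, min(1.0, self_cond))               # clamp_(-1., 1.)

# second pass: the model is still supervised in eps space
pred = model(noised_img, self_cond)
target = eps

# even with a perfect model, self_cond (x0 space, 0.5) and pred (eps space, -0.3)
# are different quantities: the interface input and output live in different spaces
```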