
Inconsistent interpretation of model output between self-conditioning step and prediction step #15


Description

@jsternabsci

I think there is a bug in the interface self-conditioning logic in rin_pytorch.py.

The model output is interpreted differently during the self-conditioning step than during the prediction step.

Currently we have (pseudocode):

self_cond = x0_to_target_modification(model_output)           # Treat the model prediction as x0 and convert it to x0, eps, or v
pred = self.model(..., self_cond, ...)                        # Self-condition on prediction for x0, eps, or v and predict x0, eps, or v
target = x0_to_target_modification(x0)
loss = F.mse_loss(pred, target)

In the current implementation, the interface prediction is interpreted as x0 during self-conditioning, but as the target (x0, eps, or v) at the prediction step.
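For concreteness, here is roughly what the x0_to_target_modification helper in the pseudocode above could look like. This is only a sketch: the function and argument names are mine, not part of rin_pytorch, and it assumes the usual convention noised_img = alpha * x0 + sigma * eps. safe_div is the same helper used in the code excerpt further down.

def x0_to_target_modification(x0, noised_img, alpha, sigma, objective):
    # Map an x0-space tensor to the chosen objective space.
    if objective == 'x0':
        return x0
    elif objective == 'eps':
        return safe_div(noised_img - alpha * x0, sigma)    # eps = (x_t - alpha * x0) / sigma
    elif objective == 'v':
        eps = safe_div(noised_img - alpha * x0, sigma)
        return alpha * eps - sigma * x0                    # v-parameterization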

I see two ways to do interface self-conditioning consistently. We could either:

  • make the model always predict x0 (a rough sketch of this option is given below). Then we would have
self_cond = model_output                                      # Where self_cond is a prediction for x0
pred = self.model(..., self_cond, ...)                        # Self-condition on x0 prediction and predict x0
pred = x0_to_target_modification(pred)                        # Convert the x0 prediction into a prediction for x0, eps, or v
target = x0_to_target_modification(x0)
loss = F.mse_loss(pred, target)

or

  • make the model always predict what it is intended to predict (x0, eps, or v); a sketch of this option applied to the current code is shown after the excerpt below. Then we would have
self_cond = model_output                                      # Where self_cond is a prediction for the target (x0, eps, or v)
pred = self.model(..., self_cond, ...)                        # Self-condition on prediction for x0, eps, or v and predict x0, eps, or v
target = x0_to_target_modification(x0)
loss = F.mse_loss(pred, target)

In contrast to the current implementation, in both of my proposals the interpretation of the interface prediction is the same in the self-conditioning step and the prediction step. Would you agree that there is an inconsistency here and that either of these proposals resolves it?
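To make the first proposal concrete, here is a rough, untested sketch of what the training step could look like if the network always predicts x0 and the objective only enters through the loss. It assumes the same surrounding context as the code excerpted below (random, F, reduce, alpha, sigma, noised_img, times); img denotes the clean image, and x0_to_target_modification is the hypothetical conversion sketched above, not an existing rin_pytorch function.

self_cond, self_latents = None, None

if random() < self.train_prob_self_cond:
    with torch.no_grad():
        x0_pred, self_latents = self.model(noised_img, times, return_latents = True)
        self_latents = self_latents.detach()

        # The self-conditioning signal is always an x0 estimate, regardless of objective
        self_cond = x0_pred.clamp(-1., 1.).detach()

# The network output is always interpreted as x0
x0_pred = self.model(noised_img, times, self_cond, self_latents)

# Convert both the prediction and the ground truth to the objective space for the loss
pred   = x0_to_target_modification(x0_pred, noised_img, alpha, sigma, self.objective)
target = x0_to_target_modification(img, noised_img, alpha, sigma, self.objective)

loss = F.mse_loss(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b', 'mean')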

Here is the current code:

if random() < self.train_prob_self_cond:
    with torch.no_grad():
        model_output, self_latents = self.model(noised_img, times, return_latents = True)
        self_latents = self_latents.detach()

        if self.objective == 'x0':
            self_cond = model_output

        elif self.objective == 'eps':
            self_cond = safe_div(noised_img - sigma * model_output, alpha)

        elif self.objective == 'v':
            self_cond = alpha * noised_img - sigma * model_output

        self_cond.clamp_(-1., 1.)
        self_cond = self_cond.detach()

# predict and take gradient step

pred = self.model(noised_img, times, self_cond, self_latents)

...

loss = F.mse_loss(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b', 'mean')
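For comparison, under the second proposal the training-time self-conditioning block above would simply feed the model output back in the same space the model is trained to predict. Something along these lines (untested sketch, keeping the names from the excerpt):

if random() < self.train_prob_self_cond:
    with torch.no_grad():
        model_output, self_latents = self.model(noised_img, times, return_latents = True)
        self_latents = self_latents.detach()

        # Self-condition on the prediction in the same space as the training objective
        self_cond = model_output.detach()

        # Clamping to [-1, 1] only makes sense when the objective is x0
        if self.objective == 'x0':
            self_cond = self_cond.clamp(-1., 1.)

pred = self.model(noised_img, times, self_cond, self_latents)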
