Dear authors,
Excellent work! While reading the implementation, I found that the code differs slightly from the pseudocode. Take bfnsolver++1 as an example. Is `x_t` in the code the variable $\mu_i$ in the pseudocode? If so, is `noise_pred` the $\hat{x}_i$? And in that case, why does the pseudocode contain an equation relating $\hat{x}_i$ and $\epsilon$?
```python
import torch

def ode_bfnsolver1_update(self, x_s, step, last_drop=False):
    # one first-order solver step: x_s -> x_t
    t = torch.ones_like(x_s, device=x_s.device) * (1 - self.times[step])
    # noise prediction, then the x0 prediction implied by it
    with torch.no_grad():
        noise_pred = self.unet(x_s, t).reshape(x_s.shape)
    alpha_t, sigma_t = self.alpha_t[step], self.sigma_t[step]
    x0_pred = (x_s - sigma_t * noise_pred) / alpha_t
    # clip x0, then re-derive the noise prediction from the clipped x0
    x0_pred = x0_pred.clip(min=-1.0, max=1.0)
    noise_pred = (x_s - x0_pred * alpha_t) / sigma_t
    # schedule values at the current step (s) and the next step (t)
    lambda_t, lambda_s = self.lambda_t[step + 1], self.lambda_t[step]
    alpha_t, alpha_s = self.alpha_t[step + 1], self.alpha_t[step]
    sigma_t, sigma_s = self.sigma_t[step + 1], self.sigma_t[step]
    h = lambda_t - lambda_s
    if last_drop and step == self.num_steps - 1:
        # final step: return the clipped x0 prediction directly
        return x0_pred, x0_pred
    else:
        x_t = (alpha_t / alpha_s) * x_s - (sigma_t * (torch.exp(h) - 1.0)) * noise_pred
        return x_t, x0_pred
```
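A minimal scalar sketch of how the clip-then-recompute lines interact with the update, assuming the schedule satisfies $x_s = \alpha_s \hat{x} + \sigma_s \epsilon$ and $\lambda = \log(\alpha/\sigma)$ (the log-SNR convention of DPM-Solver-style samplers; the snippet does not show how `self.lambda_t` is built, so this is an assumption). Under that assumption, re-deriving `noise_pred` from the clipped `x0_pred` re-imposes the relation between $\hat{x}$ and $\epsilon$, and the noise-form step in the last line then equals a data-prediction step $x_t = (\sigma_t/\sigma_s)\,x_s - \alpha_t (e^{-h} - 1)\,\hat{x}$. The function names and the numeric $\alpha$/$\sigma$ values below are hypothetical and only for illustration.

```python
import math

def lam(alpha, sigma):
    # ASSUMED convention: lambda = log(alpha / sigma)
    return math.log(alpha / sigma)

def clip_and_recompute(x_s, eps_raw, a_s, s_s):
    # same three steps as the snippet: predict x0, clip it, re-derive eps
    x0 = (x_s - s_s * eps_raw) / a_s
    x0 = max(-1.0, min(1.0, x0))
    eps = (x_s - a_s * x0) / s_s
    return x0, eps

def noise_form_update(x_s, eps, a_s, s_s, a_t, s_t):
    # mirrors the snippet: x_t = (alpha_t/alpha_s) x_s - sigma_t (e^h - 1) eps
    h = lam(a_t, s_t) - lam(a_s, s_s)
    return (a_t / a_s) * x_s - s_t * (math.exp(h) - 1.0) * eps

def data_form_update(x_s, x0, a_s, s_s, a_t, s_t):
    # hypothetical data-prediction form: x_t = (sigma_t/sigma_s) x_s - alpha_t (e^-h - 1) x0
    h = lam(a_t, s_t) - lam(a_s, s_s)
    return (s_t / s_s) * x_s - a_t * (math.exp(-h) - 1.0) * x0

# arbitrary illustrative values (alpha grows from step s to step t)
a_s, s_s = 0.3, math.sqrt(0.3 * 0.7)
a_t, s_t = 0.6, math.sqrt(0.6 * 0.4)
x_s, eps_raw = 0.5, -2.0

x0, eps = clip_and_recompute(x_s, eps_raw, a_s, s_s)
# after recomputation, x_s = alpha_s * x0 + sigma_s * eps holds again
assert math.isclose(x_s, a_s * x0 + s_s * eps)

xt_noise = noise_form_update(x_s, eps, a_s, s_s, a_t, s_t)
xt_data = data_form_update(x_s, x0, a_s, s_s, a_t, s_t)
print(xt_noise, xt_data)
```

Here the raw prediction pushes $\hat{x}$ outside $[-1, 1]$, so clipping is active, yet both printed values agree. Under the stated assumption, this suggests the code's noise-form step and a data-prediction step are the same update once $\epsilon$ is re-derived from the clipped $\hat{x}$.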
