diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index b45666749234..5008b9901a39 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -745,11 +745,7 @@ def multistep_uni_p_bh_update( if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) # (B, K) - # for order 2, we use a simplified version - if order == 2: - rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) - else: - rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) else: D1s = None @@ -883,11 +879,7 @@ def multistep_uni_c_bh_update( else: D1s = None - # for order 1, we use a simplified version - if order == 1: - rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) - else: - rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0