Skip to content

Commit 19ab04f

Browse files
authored
UniPC Multistep fix tensor dtype/device on order=3 (#7532)
* UniPC UTs iterate solvers on FP16 It wasn't catching errs on order==3. Might be excessive? * UniPC Multistep fix tensor dtype/device on order=3 * UniPC UTs Add v_pred to fp16 test iter For completions sake. Probably overkill?
1 parent 4a34307 commit 19ab04f

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ def multistep_uni_p_bh_update(
576576
if order == 2:
577577
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
578578
else:
579-
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
579+
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
580580
else:
581581
D1s = None
582582

@@ -714,7 +714,7 @@ def multistep_uni_c_bh_update(
714714
if order == 1:
715715
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
716716
else:
717-
rhos_c = torch.linalg.solve(R, b)
717+
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
718718

719719
if self.predict_x0:
720720
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0

tests/schedulers/test_scheduler_unipc.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -229,20 +229,29 @@ def test_full_loop_with_karras_and_v_prediction(self):
229229
assert abs(result_mean.item() - 0.1966) < 1e-3
230230

231231
def test_fp16_support(self):
232-
scheduler_class = self.scheduler_classes[0]
233-
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
234-
scheduler = scheduler_class(**scheduler_config)
232+
for order in [1, 2, 3]:
233+
for solver_type in ["bh1", "bh2"]:
234+
for prediction_type in ["epsilon", "sample", "v_prediction"]:
235+
scheduler_class = self.scheduler_classes[0]
236+
scheduler_config = self.get_scheduler_config(
237+
thresholding=True,
238+
dynamic_thresholding_ratio=0,
239+
prediction_type=prediction_type,
240+
solver_order=order,
241+
solver_type=solver_type,
242+
)
243+
scheduler = scheduler_class(**scheduler_config)
235244

236-
num_inference_steps = 10
237-
model = self.dummy_model()
238-
sample = self.dummy_sample_deter.half()
239-
scheduler.set_timesteps(num_inference_steps)
245+
num_inference_steps = 10
246+
model = self.dummy_model()
247+
sample = self.dummy_sample_deter.half()
248+
scheduler.set_timesteps(num_inference_steps)
240249

241-
for i, t in enumerate(scheduler.timesteps):
242-
residual = model(sample, t)
243-
sample = scheduler.step(residual, t, sample).prev_sample
250+
for i, t in enumerate(scheduler.timesteps):
251+
residual = model(sample, t)
252+
sample = scheduler.step(residual, t, sample).prev_sample
244253

245-
assert sample.dtype == torch.float16
254+
assert sample.dtype == torch.float16
246255

247256
def test_full_loop_with_noise(self):
248257
scheduler_class = self.scheduler_classes[0]

0 commit comments

Comments
 (0)