diff --git a/wanvideo/schedulers/fm_solvers_unipc.py b/wanvideo/schedulers/fm_solvers_unipc.py index 14e854fe..cefd6a50 100644 --- a/wanvideo/schedulers/fm_solvers_unipc.py +++ b/wanvideo/schedulers/fm_solvers_unipc.py @@ -16,6 +16,28 @@ if is_scipy_available(): import scipy.stats +# === sm_120 hotfix: robust wrapper for torch.linalg.solve === +def _safe_solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Robust replacement for torch.linalg.solve on recent GPUs (sm_120): + - casts A and b to float32 + - makes both contiguous and reshapes a 1D right-hand side b to (n, 1) + - retries the solve on CPU if the first attempt raises any exception (e.g. CUSOLVER_STATUS_INTERNAL_ERROR from the CUDA solver), then moves the result back to A's device + Returns the solution cast to b's original dtype; for a 1D b the trailing dimension is squeezed out again. + """ + dev = A.device + A32 = A.to(torch.float32).contiguous() + b_is_vec = (b.ndim == 1) + b32 = b.to(torch.float32).contiguous() + if b32.ndim == 1: + b32 = b32.unsqueeze(-1) + try: + x = torch.linalg.solve(A32, b32) + except Exception: + x = torch.linalg.solve(A32.cpu(), b32.cpu()).to(dev) + if b_is_vec: + x = x.squeeze(-1) + return x.to(b.dtype) class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ @@ -494,8 +516,7 @@ def multistep_uni_p_bh_update( if order == 2: rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) else: - rhos_p = torch.linalg.solve(R[:-1, :-1], - b[:-1]).to(device).to(x.dtype) + rhos_p = _safe_solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) else: D1s = None @@ -640,7 +661,7 @@ def multistep_uni_c_bh_update( if order == 1: rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) else: - rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + rhos_c = _safe_solve(R, b).to(device).to(x.dtype) if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0