diff --git a/lib_layerdiffusion/models.py b/lib_layerdiffusion/models.py index 19396f7..f192acd 100644 --- a/lib_layerdiffusion/models.py +++ b/lib_layerdiffusion/models.py @@ -305,7 +305,8 @@ def estimate_augmented(self, pixel, latent): in gpu vram with dimensions higher than 4, we move it to cpu, call torch.median() and then move the result back to gpu. ''' - median = torch.median(result.cpu(), dim=0).values + result_cpu = result.cpu().float() # Convert to float32 on CPU + median = torch.median(result_cpu, dim=0).values median = median.to(device=self.load_device, dtype=self.dtype) else: median = torch.median(result, dim=0).values