diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index cf07e45b0c5c..af04ae4b93cf 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -165,6 +165,14 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` if self.interpolate: + # upsample_nearest_nhwc also fails when the number of output elements is large + # https://github.com/pytorch/pytorch/issues/141831 + scale_factor = ( + 2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])]) + ) + if hidden_states.numel() * scale_factor > pow(2, 31): + hidden_states = hidden_states.contiguous() + if output_size is None: hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") else: