diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index 4e3902ae6dbe..f25430050ce5 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools from typing import Dict, Optional, Tuple, Union import torch @@ -94,7 +95,7 @@ def forward( sample = self.conv_in(sample) - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module):