From fcd4550b107e18e156badbd89c9ddba7b72292fe Mon Sep 17 00:00:00 2001
From: Rajeev Rao
Date: Fri, 12 Jan 2024 23:37:16 -0800
Subject: [PATCH] Enable ONNX export of GPU loaded SVD/SVD-XT UNet models

* Unpack num_frames scalar if created as a (CPU) tensor in forward path.
  Avoids mixed use of CPU and CUDA tensors, which is unsupported by torch.nn ops.

Signed-off-by: Rajeev Rao
---
 src/diffusers/models/unets/unet_spatio_temporal_condition.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
index 39a8009d5af9..36eb72963ec5 100644
--- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py
+++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -397,6 +397,8 @@ def forward(

         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         batch_size, num_frames = sample.shape[:2]
+        if torch.is_tensor(num_frames):
+            num_frames = num_frames.item()
         timesteps = timesteps.expand(batch_size)

         t_emb = self.time_proj(timesteps)
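
Context for reviewers: the sketch below shows the kind of GPU-side ONNX export this patch is meant to unblock. It is illustrative only; the repository id, dummy tensor shapes, and export options are assumptions, not part of the patch, and are chosen to roughly match SVD-XT defaults.

```python
# Minimal sketch of exporting the SVD-XT UNet loaded on CUDA (assumed setup:
# local access to the stabilityai/stable-video-diffusion-img2vid-xt weights
# and an available GPU; shapes and opset are illustrative).
import torch
from diffusers import UNetSpatioTemporalConditionModel

device = torch.device("cuda")
unet = (
    UNetSpatioTemporalConditionModel.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="unet"
    )
    .to(device)
    .eval()
)

# Dummy inputs: (batch, frames, channels, height, width) latents plus the
# image-conditioning channels, CLIP image embedding, and added time ids.
batch, frames = 1, 25
sample = torch.randn(batch, frames, 8, 72, 128, device=device)
timestep = torch.tensor(999.0, device=device)
encoder_hidden_states = torch.randn(batch, 1, 1024, device=device)
added_time_ids = torch.randn(batch, 3, device=device)

# During tracing, shape values such as num_frames can materialize as CPU
# tensors; the patched forward() unpacks them with .item() so downstream
# torch.nn ops never see a mix of CPU and CUDA tensors.
with torch.no_grad():
    torch.onnx.export(
        unet,
        (sample, timestep, encoder_hidden_states, added_time_ids),
        "svd_xt_unet.onnx",
        input_names=["sample", "timestep", "encoder_hidden_states", "added_time_ids"],
        output_names=["out_sample"],
        dynamic_axes={"sample": {0: "batch"}},
        opset_version=17,
    )
```

One caveat worth noting: calling `.item()` during tracing generally bakes the traced value in as a constant, so the exported graph can be expected to assume the frame count used at export time.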