Skip to content

Commit c5ce24f

Browse files
committed
rewrite memory count without implicitly using dimensions by @ic-synth
1 parent 158a5a8 commit c5ce24f

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ class CogVideoXSafeConv3d(nn.Conv3d):
4141
"""
4242

4343
def forward(self, input: torch.Tensor) -> torch.Tensor:
44-
memory_count = (
45-
(input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
46-
)
44+
memory_count = torch.prod(torch.tensor(input.shape)) * 2 / 1024**3
4745

4846
# Set to 2GB, suitable for CuDNN
4947
if memory_count > 2:

0 commit comments

Comments
 (0)