diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py
index a8f1396aae52..97d472aa106c 100644
--- a/src/diffusers/models/transformers/transformer_cosmos.py
+++ b/src/diffusers/models/transformers/transformer_cosmos.py
@@ -18,9 +18,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchvision import transforms
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import is_torchvision_available
 from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..embeddings import Timesteps
@@ -29,6 +29,10 @@
 from ..normalization import RMSNorm
 
 
+if is_torchvision_available():
+    from torchvision import transforms
+
+
 class CosmosPatchEmbed(nn.Module):
     def __init__(
         self, in_channels: int, out_channels: int, patch_size: Tuple[int, int, int], bias: bool = True
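
For context, the hunks above turn torchvision into a soft dependency: it is imported only when installed, so importing `transformer_cosmos` itself no longer requires it. A minimal sketch of that pattern, assuming only that `is_torchvision_available()` returns a bool (as the diff implies); the `resize_frames` helper and its error message are hypothetical, for illustration:

```python
import torch

from diffusers.utils import is_torchvision_available

# Import torchvision lazily so this module can be imported without it installed.
if is_torchvision_available():
    from torchvision import transforms
else:
    transforms = None  # sentinel checked by torchvision-dependent code paths


def resize_frames(frames: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Hypothetical helper: resize a (..., H, W) tensor, failing only when called."""
    if transforms is None:
        raise ImportError("torchvision is required for resizing; run `pip install torchvision`.")
    return transforms.Resize((height, width))(frames)
```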