2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -96,6 +96,7 @@
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
_import_structure["unets.unet_seva"] = ["SevaUnet"]

if is_flax_available():
_import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
@@ -177,6 +178,7 @@
I2VGenXLUNet,
Kandinsky3UNet,
MotionAdapter,
SevaUnet,
StableCascadeUNet,
UNet1DModel,
UNet2DConditionModel,
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
@@ -30,3 +30,4 @@
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .seva_transformer import SevaMultiviewTransformer
103 changes: 103 additions & 0 deletions src/diffusers/models/transformers/seva_transformer.py
@@ -0,0 +1,103 @@
from typing import List, Optional

import torch
from torch import nn

from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock


class SkipConnect(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x_spatial: torch.Tensor, x_temporal: torch.Tensor) -> torch.Tensor:
        return x_spatial + x_temporal


class SevaMultiviewTransformer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        num_heads: int,
        dim_head: int,
        name: str,
        unflatten_names: Optional[List[str]] = None,
        transformer_depth: int = 1,
        context_dim: int = 1024,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.name = name
        # Use a fresh list instead of a shared mutable default. Blocks whose name appears in
        # `unflatten_names` fold the frame axis into their spatial attention (see `forward`).
        self.unflatten_names = unflatten_names if unflatten_names is not None else []

        self.in_channels = in_channels
        inner_dim = num_heads * dim_head
        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6)
        self.proj_in = nn.Linear(in_channels, inner_dim)
        # One spatial block per temporal block so the zip in `forward` runs for the full depth.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_heads,
                    attention_head_dim=dim_head,
                    dropout=dropout,
                    cross_attention_dim=context_dim,
                )
                for _ in range(transformer_depth)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

        self.time_mixer = SkipConnect()
        time_mix_inner_dim = inner_dim
        self.time_mix_blocks = nn.ModuleList(
            [
                TemporalBasicTransformerBlock(
                    dim=inner_dim,
                    time_mix_inner_dim=time_mix_inner_dim,
                    num_attention_heads=num_heads,
                    attention_head_dim=dim_head,
                    cross_attention_dim=context_dim,
                )
                for _ in range(transformer_depth)
            ]
        )

    def forward(self, x: torch.Tensor, context_emb: torch.Tensor, num_frames: int) -> torch.Tensor:
        # x: (batch * num_frames, channels, height, width)
        # context_emb: (batch * num_frames, seq_len, context_dim)
        assert context_emb.ndim == 3
        _, _, height, width = x.shape

        # The temporal blocks cross-attend to the context of each clip's first frame,
        # repeated once per spatial location to match their token layout.
        time_context = context_emb
        time_context_first_timestep = time_context[::num_frames]
        time_context = torch.repeat_interleave(time_context_first_timestep, height * width, dim=0)

        # Flatten the spatial dimensions: (B, C, H, W) -> (B, H * W, C).
        h = self.norm(x)
        h = h.permute(0, 2, 3, 1).contiguous()
        h = h.view(h.shape[0], -1, h.shape[3])
        h = self.proj_in(h)

        for transformer_block, time_mix_block in zip(self.transformer_blocks, self.time_mix_blocks):
            if self.name in self.unflatten_names:
                # Multiview attention: fold the frame axis into the sequence dimension so the
                # spatial block and its cross-attention context span all frames of a clip at once.
                h = h.view(h.shape[0] // num_frames, h.shape[1] * num_frames, h.shape[2])
                context_emb = context_emb.view(
                    context_emb.shape[0] // num_frames, context_emb.shape[1] * num_frames, context_emb.shape[2]
                )

            h = transformer_block(h, encoder_hidden_states=context_emb)

            if self.name in self.unflatten_names:
                # Restore the per-frame layout expected by the temporal block below.
                h = h.view(h.shape[0] * num_frames, h.shape[1] // num_frames, h.shape[2])
                context_emb = context_emb.view(
                    context_emb.shape[0] * num_frames, context_emb.shape[1] // num_frames, context_emb.shape[2]
                )

            h_mix = time_mix_block(h, encoder_hidden_states=time_context, num_frames=num_frames)
            h = self.time_mixer(x_spatial=h, x_temporal=h_mix)

        h = self.proj_out(h)
        # Back to (B, C, H, W) and add the residual connection.
        h = h.view(h.shape[0], height, width, h.shape[2])
        h = h.permute(0, 3, 1, 2).contiguous()
        out = h + x
        return out
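A quick smoke test for the new block (illustrative only, not part of the diff; the channel, head, and sequence sizes below are assumptions chosen so the shapes line up, and "input_ds2" is a hypothetical block name):

import torch

from diffusers.models.transformers.seva_transformer import SevaMultiviewTransformer

# Two clips of four frames each, flattened into the batch dimension.
batch, num_frames, channels, height, width = 2, 4, 320, 16, 16
block = SevaMultiviewTransformer(
    in_channels=channels,
    num_heads=5,
    dim_head=64,
    name="input_ds2",    # hypothetical block name
    unflatten_names=[],  # no joint multiview attention in this example
    context_dim=1024,
)
x = torch.randn(batch * num_frames, channels, height, width)
context = torch.randn(batch * num_frames, 77, 1024)  # (batch * frames, seq_len, context_dim)
out = block(x, context_emb=context, num_frames=num_frames)
assert out.shape == x.shape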
1 change: 1 addition & 0 deletions src/diffusers/models/unets/__init__.py
@@ -12,6 +12,7 @@
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .unet_stable_cascade import StableCascadeUNet
from .uvit_2d import UVit2DModel
from .unet_seva import SevaUnet


if is_flax_available():