Skip to content

Commit bb2c64a

Browse files
authored
Add the new SD2 attention params to the VD text unet (#1400)
1 parent 05a36d5 commit bb2c64a

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def get_down_block(
2828
resnet_groups=None,
2929
cross_attention_dim=None,
3030
downsample_padding=None,
31-
dual_cross_attention=None,
31+
dual_cross_attention=False,
32+
use_linear_projection=False,
33+
only_cross_attention=False,
3234
):
3335
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
3436
if down_block_type == "DownBlockFlat":
@@ -58,6 +60,9 @@ def get_down_block(
5860
downsample_padding=downsample_padding,
5961
cross_attention_dim=cross_attention_dim,
6062
attn_num_head_channels=attn_num_head_channels,
63+
dual_cross_attention=dual_cross_attention,
64+
use_linear_projection=use_linear_projection,
65+
only_cross_attention=only_cross_attention,
6166
)
6267
raise ValueError(f"{down_block_type} is not supported.")
6368

@@ -75,7 +80,9 @@ def get_up_block(
7580
attn_num_head_channels,
7681
resnet_groups=None,
7782
cross_attention_dim=None,
78-
dual_cross_attention=None,
83+
dual_cross_attention=False,
84+
use_linear_projection=False,
85+
only_cross_attention=False,
7986
):
8087
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
8188
if up_block_type == "UpBlockFlat":
@@ -105,6 +112,9 @@ def get_up_block(
105112
resnet_groups=resnet_groups,
106113
cross_attention_dim=cross_attention_dim,
107114
attn_num_head_channels=attn_num_head_channels,
115+
dual_cross_attention=dual_cross_attention,
116+
use_linear_projection=use_linear_projection,
117+
only_cross_attention=only_cross_attention,
108118
)
109119
raise ValueError(f"{up_block_type} is not supported.")
110120

0 commit comments

Comments
 (0)