Commit cb0baf8

Merge branch 'main' into xformers_flux

2 parents: ca45902 + efb7a29

18 files changed (+856, -34 lines)

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -495,6 +495,7 @@
         "LTXImageToVideoPipeline",
         "LTXLatentUpsamplePipeline",
         "LTXPipeline",
+        "LucyEditPipeline",
         "Lumina2Pipeline",
         "Lumina2Text2ImgPipeline",
         "LuminaPipeline",
@@ -1149,6 +1150,7 @@
         LTXImageToVideoPipeline,
         LTXLatentUpsamplePipeline,
         LTXPipeline,
+        LucyEditPipeline,
         Lumina2Pipeline,
         Lumina2Text2ImgPipeline,
         LuminaPipeline,
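
Both hunks register the new LucyEditPipeline export: once in the lazy _import_structure map and once under TYPE_CHECKING, so runtime imports and static type checkers stay in sync. A minimal smoke test of the export; the checkpoint id below is a placeholder assumption, not something this commit specifies:

import torch

from diffusers import LucyEditPipeline  # newly exported by this commit

# Hypothetical checkpoint id for illustration only; substitute a real one.
pipe = LucyEditPipeline.from_pretrained(
    "<org>/<lucy-edit-checkpoint>", torch_dtype=torch.bfloat16
)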

src/diffusers/models/attention.py

Lines changed: 1 addition & 1 deletion
@@ -674,7 +674,7 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
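
The only change here is the return annotation: the joint block already returns a pair of tensors (the updated context stream and the image stream), and the signature now says so. A toy sketch of the pattern, not the actual diffusers class:

from typing import Tuple

import torch
import torch.nn as nn


class ToyJointBlock(nn.Module):
    # Stand-in for a dual-stream transformer block: each stream is updated
    # and both are handed back as a pair, which is exactly what the new
    # `-> Tuple[torch.Tensor, torch.Tensor]` annotation documents.
    def __init__(self, dim: int):
        super().__init__()
        self.img_proj = nn.Linear(dim, dim)
        self.txt_proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.txt_proj(encoder_hidden_states), self.img_proj(hidden_states)

Callers unpack the result (encoder_hidden_states, hidden_states = block(...)), so making the tuple explicit turns an implicit convention into a checkable contract.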

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 5 additions & 4 deletions
@@ -1052,7 +1052,7 @@ def __init__(
             is_residual=is_residual,
         )

-        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+        self.spatial_compression_ratio = scale_factor_spatial

         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
         # to perform decoding of a single video latent at a time.
@@ -1145,12 +1145,13 @@ def clear_cache(self):
     def _encode(self, x: torch.Tensor):
         _, _, num_frame, height, width = x.shape

-        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
-            return self.tiled_encode(x)
-
         self.clear_cache()
         if self.config.patch_size is not None:
             x = patchify(x, patch_size=self.config.patch_size)
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
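
Two fixes here. The first stops deriving the spatial compression ratio from temperal_downsample, a temporal setting whose length need not match the spatial factor; the ratio now comes straight from scale_factor_spatial. A hypothetical illustration of how the two can diverge (the values are made up, not taken from this commit):

# Hypothetical config values for illustration only.
temperal_downsample = [False, True, True]  # per-stage temporal downsampling flags
scale_factor_spatial = 16                  # the VAE's actual spatial compression

old_ratio = 2 ** len(temperal_downsample)  # 8: tied to the *number* of stages
new_ratio = scale_factor_spatial           # 16: read directly from the config

The second fix moves the tiled-encode early return below clear_cache and patchify, so the tiled path now starts from a cleared cache and receives the patchified tensor.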

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 5 additions & 5 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.


-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -92,7 +92,7 @@ def pe_selection_index_based_on_dim(self, h, w):

         return selected_indices

-    def forward(self, latent):
+    def forward(self, latent) -> torch.Tensor:
         batch_size, num_channels, height, width = latent.size()
         latent = latent.view(
             batch_size,
@@ -173,7 +173,7 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         residual = hidden_states
         attention_kwargs = attention_kwargs or {}

@@ -242,7 +242,7 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         residual_context = encoder_hidden_states
         attention_kwargs = attention_kwargs or {}
@@ -472,7 +472,7 @@ def forward(
         timestep: torch.LongTensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
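
The final hunk tightens the top-level return annotation: with return_dict=False, diffusers models return a one-element tuple rather than a bare tensor, so Union[Tuple[torch.Tensor], Transformer2DModelOutput] matches what forward actually hands back. The convention, sketched with a placeholder tensor:

import torch

from diffusers.models.modeling_outputs import Transformer2DModelOutput

sample = torch.randn(1, 4, 64, 64)  # placeholder for the model's output

return_dict = False
if not return_dict:
    result = (sample,)  # a one-element tuple, hence Tuple[torch.Tensor]
else:
    result = Transformer2DModelOutput(sample=sample)

The same correction is applied to the remaining transformer files below.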

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 2 additions & 2 deletions
@@ -122,7 +122,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
         attention_kwargs = attention_kwargs or {}

@@ -441,7 +441,7 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

src/diffusers/models/transformers/consisid_transformer_3d.py

Lines changed: 2 additions & 2 deletions
@@ -315,7 +315,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)

         # norm & modulate
@@ -691,7 +691,7 @@ def forward(
         id_cond: Optional[torch.Tensor] = None,
         id_vit_hidden: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

src/diffusers/models/transformers/lumina_nextdit2d.py

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -124,7 +124,7 @@ def forward(
         encoder_mask: torch.Tensor,
         temb: torch.Tensor,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Perform a forward pass through the LuminaNextDiTBlock.

@@ -297,7 +297,7 @@ def forward(
         image_rotary_emb: torch.Tensor,
         cross_attention_kwargs: Dict[str, Any] = None,
         return_dict=True,
-    ) -> torch.Tensor:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         Forward pass of LuminaNextDiT.


src/diffusers/models/transformers/transformer_bria.py

Lines changed: 2 additions & 2 deletions
@@ -472,7 +472,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_len = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

@@ -588,7 +588,7 @@ def forward(
         return_dict: bool = True,
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         The [`BriaTransformer2DModel`] forward method.


src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.


-from typing import Dict, Union
+from typing import Dict, Tuple, Union

 import torch
 import torch.nn as nn
@@ -79,7 +79,7 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         emb: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)

         # norm & modulate
@@ -293,7 +293,7 @@ def forward(
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
         return_dict: bool = True,
-    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         The [`CogView3PlusTransformer2DModel`] forward method.


src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 2 additions & 2 deletions
@@ -494,7 +494,7 @@ def forward(
         ] = None,
         attention_mask: Optional[Dict[str, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Timestep conditioning
         (
             norm_hidden_states,
@@ -717,7 +717,7 @@ def forward(
         image_rotary_emb: Optional[
             Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
         ] = None,
-    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
