Skip to content

Commit 558c64e

Browse files
committed
update
1 parent 5f898a1 commit 558c64e

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

src/diffusers/models/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@
5151
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
5252
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
5353
_import_structure["embeddings"] = ["ImageProjection"]
54+
_import_structure["layerwise_upcasting_utils"] = [
55+
"LayerwiseUpcastingGranularity",
56+
"apply_layerwise_upcasting",
57+
"apply_layerwise_upcasting_hook",
58+
]
5459
_import_structure["modeling_utils"] = ["ModelMixin"]
5560
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
5661
_import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]

src/diffusers/models/modeling_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def enable_layerwise_upcasting(
321321
storage_dtype: torch.dtype = torch.float8_e4m3fn,
322322
compute_dtype: Optional[torch.dtype] = None,
323323
granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER,
324+
skip_modules_pattern: Optional[List[str]] = None,
324325
) -> None:
325326
r"""
326327
Activates layerwise upcasting for the current model.
@@ -364,7 +365,8 @@ def enable_layerwise_upcasting(
364365
[`~LayerwiseUpcastingGranularity`] for more information.
365366
"""
366367

367-
skip_modules_pattern = []
368+
if skip_modules_pattern is None:
369+
skip_modules_pattern = []
368370
if self._keep_in_fp32_modules is not None:
369371
skip_modules_pattern.extend(self._keep_in_fp32_modules)
370372
if self._always_upcast_modules is not None:

src/diffusers/pipelines/latte/pipeline_latte.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ def __call__(
836836
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
837837
progress_bar.update()
838838

839-
if not output_type == "latents":
839+
if not output_type == "latent":
840840
video = self.decode_latents(latents, video_length, decode_chunk_size=14)
841841
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
842842
else:

0 commit comments

Comments
 (0)