Skip to content

Commit db16983

Browse files
committed
add single file to pipelines
1 parent 9f9e016 commit db16983

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

src/diffusers/loaders/single_file_utils.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2212,10 +2212,9 @@ def swap_scale_shift(weight):
22122212

22132213

22142214
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
2215-
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
2216-
2217-
def remove_keys_(key: str, state_dict):
2218-
state_dict.pop(key)
2215+
converted_state_dict = {
2216+
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key
2217+
}
22192218

22202219
TRANSFORMER_KEYS_RENAME_DICT = {
22212220
"model.diffusion_model.": "",
@@ -2225,9 +2224,7 @@ def remove_keys_(key: str, state_dict):
22252224
"k_norm": "norm_k",
22262225
}
22272226

2228-
TRANSFORMER_SPECIAL_KEYS_REMAP = {
2229-
"vae": remove_keys_,
2230-
}
2227+
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
22312228

22322229
for key in list(converted_state_dict.keys()):
22332230
new_key = key
@@ -2245,7 +2242,7 @@ def remove_keys_(key: str, state_dict):
22452242

22462243

22472244
def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
2248-
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
2245+
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}
22492246

22502247
def remove_keys_(key: str, state_dict):
22512248
state_dict.pop(key)
@@ -2287,7 +2284,6 @@ def remove_keys_(key: str, state_dict):
22872284
"per_channel_statistics.channel": remove_keys_,
22882285
"per_channel_statistics.mean-of-means": remove_keys_,
22892286
"per_channel_statistics.mean-of-stds": remove_keys_,
2290-
"model.diffusion_model": remove_keys_,
22912287
}
22922288

22932289
for key in list(converted_state_dict.keys()):

src/diffusers/pipelines/ltx/pipeline_ltx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from transformers import T5EncoderModel, T5TokenizerFast
2121

2222
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
23+
from ...loaders import FromSingleFileMixin
2324
from ...models.autoencoders import AutoencoderKLLTX
2425
from ...models.transformers import LTXTransformer3DModel
2526
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -139,7 +140,7 @@ def retrieve_timesteps(
139140
return timesteps, num_inference_steps
140141

141142

142-
class LTXPipeline(DiffusionPipeline):
143+
class LTXPipeline(DiffusionPipeline, FromSingleFileMixin):
143144
r"""
144145
Pipeline for text-to-video generation.
145146

src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2323
from ...image_processor import PipelineImageInput
24+
from ...loaders import FromSingleFileMixin
2425
from ...models.autoencoders import AutoencoderKLLTX
2526
from ...models.transformers import LTXTransformer3DModel
2627
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -158,7 +159,7 @@ def retrieve_latents(
158159
raise AttributeError("Could not access latents of provided encoder_output")
159160

160161

161-
class LTXImageToVideoPipeline(DiffusionPipeline):
162+
class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin):
162163
r"""
163164
Pipeline for image-to-video generation.
164165

0 commit comments

Comments
 (0)