Skip to content

Commit c577fcf

Browse files
authored
Merge branch 'main' into main
2 parents ae78a65 + 56f7400 commit c577fcf

File tree

28 files changed

+2103
-154
lines changed

28 files changed

+2103
-154
lines changed

.github/workflows/pr_tests_gpu.yml

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,51 @@ env:
2828
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
2929

3030
jobs:
31+
check_code_quality:
32+
runs-on: ubuntu-22.04
33+
steps:
34+
- uses: actions/checkout@v3
35+
- name: Set up Python
36+
uses: actions/setup-python@v4
37+
with:
38+
python-version: "3.8"
39+
- name: Install dependencies
40+
run: |
41+
python -m pip install --upgrade pip
42+
pip install .[quality]
43+
- name: Check quality
44+
run: make quality
45+
- name: Check if failure
46+
if: ${{ failure() }}
47+
run: |
48+
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
49+
50+
check_repository_consistency:
51+
needs: check_code_quality
52+
runs-on: ubuntu-22.04
53+
steps:
54+
- uses: actions/checkout@v3
55+
- name: Set up Python
56+
uses: actions/setup-python@v4
57+
with:
58+
python-version: "3.8"
59+
- name: Install dependencies
60+
run: |
61+
python -m pip install --upgrade pip
62+
pip install .[quality]
63+
- name: Check repo consistency
64+
run: |
65+
python utils/check_copies.py
66+
python utils/check_dummies.py
67+
python utils/check_support_list.py
68+
make deps_table_check_updated
69+
- name: Check if failure
70+
if: ${{ failure() }}
71+
run: |
72+
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
73+
3174
setup_torch_cuda_pipeline_matrix:
75+
needs: [check_code_quality, check_repository_consistency]
3276
name: Setup Torch Pipelines CUDA Slow Tests Matrix
3377
runs-on:
3478
group: aws-general-8-plus
@@ -133,6 +177,7 @@ jobs:
133177

134178
torch_cuda_tests:
135179
name: Torch CUDA Tests
180+
needs: [check_code_quality, check_repository_consistency]
136181
runs-on:
137182
group: aws-g4dn-2xlarge
138183
container:
@@ -201,7 +246,7 @@ jobs:
201246

202247
run_examples_tests:
203248
name: Examples PyTorch CUDA tests on Ubuntu
204-
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
249+
needs: [check_code_quality, check_repository_consistency]
205250
runs-on:
206251
group: aws-g4dn-2xlarge
207252

@@ -220,6 +265,7 @@ jobs:
220265
- name: Install dependencies
221266
run: |
222267
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
268+
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
223269
python -m uv pip install -e [quality,test,training]
224270
225271
- name: Environment

docs/source/en/api/pipelines/ltx_video.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
196196
- all
197197
- __call__
198198

199+
## LTXConditionPipeline
200+
201+
[[autodoc]] LTXConditionPipeline
202+
- all
203+
- __call__
204+
199205
## LTXPipelineOutput
200206

201207
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput

examples/research_projects/autoencoderkl/train_autoencoderkl.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,7 @@ def main(args):
627627
ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
628628
perceptual_loss = lpips.LPIPS(net="vgg").eval()
629629
discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
630+
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
630631

631632
# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
632633
def unwrap_model(model):
@@ -951,13 +952,20 @@ def load_model_hook(models, input_dir):
951952
logits_fake = discriminator(reconstructions)
952953
disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
953954
disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
954-
disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
955+
d_loss = disc_factor * disc_loss(logits_real, logits_fake)
955956
logs = {
956-
"disc_loss": disc_loss.detach().mean().item(),
957+
"disc_loss": d_loss.detach().mean().item(),
957958
"logits_real": logits_real.detach().mean().item(),
958959
"logits_fake": logits_fake.detach().mean().item(),
959960
"disc_lr": disc_lr_scheduler.get_last_lr()[0],
960961
}
962+
accelerator.backward(d_loss)
963+
if accelerator.sync_gradients:
964+
params_to_clip = discriminator.parameters()
965+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
966+
disc_optimizer.step()
967+
disc_lr_scheduler.step()
968+
disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
961969
# Checks if the accelerator has performed an optimization step behind the scenes
962970
if accelerator.sync_gradients:
963971
progress_bar.update(1)

examples/research_projects/pytorch_xla/inference/flux/README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# Generating images using Flux and PyTorch/XLA
22

3-
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.
4-
5-
It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
3+
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
64

75
## Create TPU
86

@@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
2321
python3 -c "import torch; import torch_xla;"
2422
```
2523

26-
Install dependencies
24+
Clone the diffusers repo and install dependencies
2725

2826
```bash
27+
git clone https://github.com/huggingface/diffusers.git
28+
cd diffusers
2929
pip install transformers accelerate sentencepiece structlog
30-
pushd ../../..
3130
pip install .
32-
popd
31+
cd examples/research_projects/pytorch_xla/inference/flux/
3332
```
3433

3534
## Run the inference job
3635

3736
### Authenticate
3837

39-
Run the following command to authenticate your token in order to download Flux weights.
38+
**Gated Model**
39+
40+
As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
4041

4142
```bash
4243
huggingface-cli login

scripts/convert_ltx_to_diffusers.py

Lines changed: 89 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,39 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
7474
"last_scale_shift_table": "scale_shift_table",
7575
}
7676

77+
VAE_095_RENAME_DICT = {
78+
# decoder
79+
"up_blocks.0": "mid_block",
80+
"up_blocks.1": "up_blocks.0.upsamplers.0",
81+
"up_blocks.2": "up_blocks.0",
82+
"up_blocks.3": "up_blocks.1.upsamplers.0",
83+
"up_blocks.4": "up_blocks.1",
84+
"up_blocks.5": "up_blocks.2.upsamplers.0",
85+
"up_blocks.6": "up_blocks.2",
86+
"up_blocks.7": "up_blocks.3.upsamplers.0",
87+
"up_blocks.8": "up_blocks.3",
88+
# encoder
89+
"down_blocks.0": "down_blocks.0",
90+
"down_blocks.1": "down_blocks.0.downsamplers.0",
91+
"down_blocks.2": "down_blocks.1",
92+
"down_blocks.3": "down_blocks.1.downsamplers.0",
93+
"down_blocks.4": "down_blocks.2",
94+
"down_blocks.5": "down_blocks.2.downsamplers.0",
95+
"down_blocks.6": "down_blocks.3",
96+
"down_blocks.7": "down_blocks.3.downsamplers.0",
97+
"down_blocks.8": "mid_block",
98+
# common
99+
"last_time_embedder": "time_embedder",
100+
"last_scale_shift_table": "scale_shift_table",
101+
}
102+
77103
VAE_SPECIAL_KEYS_REMAP = {
78104
"per_channel_statistics.channel": remove_keys_,
79105
"per_channel_statistics.mean-of-means": remove_keys_,
80106
"per_channel_statistics.mean-of-stds": remove_keys_,
81107
"model.diffusion_model": remove_keys_,
82108
}
83109

84-
VAE_091_SPECIAL_KEYS_REMAP = {
85-
"timestep_scale_multiplier": remove_keys_,
86-
}
87-
88110

89111
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
90112
state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
104126
def convert_transformer(
105127
ckpt_path: str,
106128
dtype: torch.dtype,
129+
version: str = "0.9.0",
107130
):
108131
PREFIX_KEY = "model.diffusion_model."
109132

110133
original_state_dict = get_state_dict(load_file(ckpt_path))
134+
config = {}
135+
if version == "0.9.5":
136+
config["_use_causal_rope_fix"] = True
111137
with init_empty_weights():
112-
transformer = LTXVideoTransformer3DModel()
138+
transformer = LTXVideoTransformer3DModel(**config)
113139

114140
for key in list(original_state_dict.keys()):
115141
new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
161187
"out_channels": 3,
162188
"latent_channels": 128,
163189
"block_out_channels": (128, 256, 512, 512),
190+
"down_block_types": (
191+
"LTXVideoDownBlock3D",
192+
"LTXVideoDownBlock3D",
193+
"LTXVideoDownBlock3D",
194+
"LTXVideoDownBlock3D",
195+
),
164196
"decoder_block_out_channels": (128, 256, 512, 512),
165197
"layers_per_block": (4, 3, 3, 3, 4),
166198
"decoder_layers_per_block": (4, 3, 3, 3, 4),
167199
"spatio_temporal_scaling": (True, True, True, False),
168200
"decoder_spatio_temporal_scaling": (True, True, True, False),
169201
"decoder_inject_noise": (False, False, False, False, False),
202+
"downsample_type": ("conv", "conv", "conv", "conv"),
170203
"upsample_residual": (False, False, False, False),
171204
"upsample_factor": (1, 1, 1, 1),
172205
"patch_size": 4,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
183216
"out_channels": 3,
184217
"latent_channels": 128,
185218
"block_out_channels": (128, 256, 512, 512),
219+
"down_block_types": (
220+
"LTXVideoDownBlock3D",
221+
"LTXVideoDownBlock3D",
222+
"LTXVideoDownBlock3D",
223+
"LTXVideoDownBlock3D",
224+
),
186225
"decoder_block_out_channels": (256, 512, 1024),
187226
"layers_per_block": (4, 3, 3, 3, 4),
188227
"decoder_layers_per_block": (5, 6, 7, 8),
189228
"spatio_temporal_scaling": (True, True, True, False),
190229
"decoder_spatio_temporal_scaling": (True, True, True),
191230
"decoder_inject_noise": (True, True, True, False),
231+
"downsample_type": ("conv", "conv", "conv", "conv"),
192232
"upsample_residual": (True, True, True),
193233
"upsample_factor": (2, 2, 2),
194234
"timestep_conditioning": True,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
200240
"decoder_causal": False,
201241
}
202242
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
203-
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
243+
elif version == "0.9.5":
244+
config = {
245+
"in_channels": 3,
246+
"out_channels": 3,
247+
"latent_channels": 128,
248+
"block_out_channels": (128, 256, 512, 1024, 2048),
249+
"down_block_types": (
250+
"LTXVideo095DownBlock3D",
251+
"LTXVideo095DownBlock3D",
252+
"LTXVideo095DownBlock3D",
253+
"LTXVideo095DownBlock3D",
254+
),
255+
"decoder_block_out_channels": (256, 512, 1024),
256+
"layers_per_block": (4, 6, 6, 2, 2),
257+
"decoder_layers_per_block": (5, 5, 5, 5),
258+
"spatio_temporal_scaling": (True, True, True, True),
259+
"decoder_spatio_temporal_scaling": (True, True, True),
260+
"decoder_inject_noise": (False, False, False, False),
261+
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
262+
"upsample_residual": (True, True, True),
263+
"upsample_factor": (2, 2, 2),
264+
"timestep_conditioning": True,
265+
"patch_size": 4,
266+
"patch_size_t": 1,
267+
"resnet_norm_eps": 1e-6,
268+
"scaling_factor": 1.0,
269+
"encoder_causal": True,
270+
"decoder_causal": False,
271+
"spatial_compression_ratio": 32,
272+
"temporal_compression_ratio": 8,
273+
}
274+
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
204275
return config
205276

206277

@@ -223,7 +294,7 @@ def get_args():
223294
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
224295
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
225296
parser.add_argument(
226-
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
297+
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
227298
)
228299
return parser.parse_args()
229300

@@ -277,14 +348,17 @@ def get_args():
277348
for param in text_encoder.parameters():
278349
param.data = param.data.contiguous()
279350

280-
scheduler = FlowMatchEulerDiscreteScheduler(
281-
use_dynamic_shifting=True,
282-
base_shift=0.95,
283-
max_shift=2.05,
284-
base_image_seq_len=1024,
285-
max_image_seq_len=4096,
286-
shift_terminal=0.1,
287-
)
351+
if args.version == "0.9.5":
352+
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
353+
else:
354+
scheduler = FlowMatchEulerDiscreteScheduler(
355+
use_dynamic_shifting=True,
356+
base_shift=0.95,
357+
max_shift=2.05,
358+
base_image_seq_len=1024,
359+
max_image_seq_len=4096,
360+
shift_terminal=0.1,
361+
)
288362

289363
pipe = LTXPipeline(
290364
scheduler=scheduler,

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@
402402
"LDMTextToImagePipeline",
403403
"LEditsPPPipelineStableDiffusion",
404404
"LEditsPPPipelineStableDiffusionXL",
405+
"LTXConditionPipeline",
405406
"LTXImageToVideoPipeline",
406407
"LTXPipeline",
407408
"Lumina2Pipeline",
@@ -947,6 +948,7 @@
947948
LDMTextToImagePipeline,
948949
LEditsPPPipelineStableDiffusion,
949950
LEditsPPPipelineStableDiffusionXL,
951+
LTXConditionPipeline,
950952
LTXImageToVideoPipeline,
951953
LTXPipeline,
952954
Lumina2Pipeline,

0 commit comments

Comments
 (0)