
Commit df5714a

Merge branch 'main' into flux-remote-encode
2 parents 77772ef + cb1b8b2 commit df5714a

File tree

19 files changed: +1949 -89 lines changed

.github/workflows/pr_tests_gpu.yml

Lines changed: 44 additions & 0 deletions

@@ -28,7 +28,51 @@ env:
   PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
 
 jobs:
+  check_code_quality:
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.8"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[quality]
+      - name: Check quality
+        run: make quality
+      - name: Check if failure
+        if: ${{ failure() }}
+        run: |
+          echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
+
+  check_repository_consistency:
+    needs: check_code_quality
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.8"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[quality]
+      - name: Check repo consistency
+        run: |
+          python utils/check_copies.py
+          python utils/check_dummies.py
+          python utils/check_support_list.py
+          make deps_table_check_updated
+      - name: Check if failure
+        if: ${{ failure() }}
+        run: |
+          echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
+
   setup_torch_cuda_pipeline_matrix:
+    needs: [check_code_quality, check_repository_consistency]
     name: Setup Torch Pipelines CUDA Slow Tests Matrix
     runs-on:
       group: aws-general-8-plus
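
The two jobs added here gate the GPU pipeline matrix on code quality and repository consistency. For local debugging, a rough Python equivalent of the repo-consistency steps is sketched below; it assumes the working directory is the root of a diffusers checkout with the `[quality]` extras installed.

```python
# Sketch of running the same consistency checks the new CI job runs,
# assuming the current working directory is the diffusers repo root.
import subprocess

checks = [
    ["python", "utils/check_copies.py"],
    ["python", "utils/check_dummies.py"],
    ["python", "utils/check_support_list.py"],
    ["make", "deps_table_check_updated"],
]
for cmd in checks:
    subprocess.run(cmd, check=True)  # check=True fails fast, like the CI step
```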

docs/source/en/api/pipelines/ltx_video.md

Lines changed: 6 additions & 0 deletions

@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
   - all
   - __call__
 
+## LTXConditionPipeline
+
+[[autodoc]] LTXConditionPipeline
+  - all
+  - __call__
+
 ## LTXPipelineOutput
 
 [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
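
The new entry only registers the pipeline's autodoc directives. As a rough illustration of what a call might look like (not part of this diff), a text-to-video invocation could resemble the sketch below; the checkpoint id, resolution, and frame count are assumptions for the example.

```python
# Illustrative sketch only: the "Lightricks/LTX-Video-0.9.5" checkpoint id and the
# generation settings are assumptions, not taken from this commit.
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="A ship sailing through a stormy sea at sunset",
    width=704,
    height=480,
    num_frames=121,
    num_inference_steps=40,
).frames[0]
export_to_video(video, "ship.mp4", fps=24)
```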

examples/research_projects/pytorch_xla/inference/flux/README.md

Lines changed: 8 additions & 7 deletions

@@ -1,8 +1,6 @@
 # Generating images using Flux and PyTorch/XLA
 
-The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.
-
-It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
+The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
 
 ## Create TPU
 
@@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
 python3 -c "import torch; import torch_xla;"
 ```
 
-Install dependencies
+Clone the diffusers repo and install dependencies
 
 ```bash
+git clone https://github.com/huggingface/diffusers.git
+cd diffusers
 pip install transformers accelerate sentencepiece structlog
-pushd ../../..
 pip install .
-popd
+cd examples/research_projects/pytorch_xla/inference/flux/
 ```
 
 ## Run the inference job
 
 ### Authenticate
 
-Run the following command to authenticate your token in order to download Flux weights.
+**Gated Model**
+
+As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
 
 ```bash
 huggingface-cli login
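
If a CLI login is inconvenient (for example inside a notebook), the same authentication can be done programmatically; a small sketch, assuming only that `huggingface_hub` is installed (it ships with diffusers' dependencies):

```python
# Programmatic alternative to `huggingface-cli login` (sketch).
from huggingface_hub import login

login()  # prompts for an access token from an account that has accepted the FLUX.1 [dev] gate
```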

scripts/convert_ltx_to_diffusers.py

Lines changed: 89 additions & 15 deletions

@@ -74,17 +74,39 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
     "last_scale_shift_table": "scale_shift_table",
 }
 
+VAE_095_RENAME_DICT = {
+    # decoder
+    "up_blocks.0": "mid_block",
+    "up_blocks.1": "up_blocks.0.upsamplers.0",
+    "up_blocks.2": "up_blocks.0",
+    "up_blocks.3": "up_blocks.1.upsamplers.0",
+    "up_blocks.4": "up_blocks.1",
+    "up_blocks.5": "up_blocks.2.upsamplers.0",
+    "up_blocks.6": "up_blocks.2",
+    "up_blocks.7": "up_blocks.3.upsamplers.0",
+    "up_blocks.8": "up_blocks.3",
+    # encoder
+    "down_blocks.0": "down_blocks.0",
+    "down_blocks.1": "down_blocks.0.downsamplers.0",
+    "down_blocks.2": "down_blocks.1",
+    "down_blocks.3": "down_blocks.1.downsamplers.0",
+    "down_blocks.4": "down_blocks.2",
+    "down_blocks.5": "down_blocks.2.downsamplers.0",
+    "down_blocks.6": "down_blocks.3",
+    "down_blocks.7": "down_blocks.3.downsamplers.0",
+    "down_blocks.8": "mid_block",
+    # common
+    "last_time_embedder": "time_embedder",
+    "last_scale_shift_table": "scale_shift_table",
+}
+
 VAE_SPECIAL_KEYS_REMAP = {
     "per_channel_statistics.channel": remove_keys_,
     "per_channel_statistics.mean-of-means": remove_keys_,
     "per_channel_statistics.mean-of-stds": remove_keys_,
     "model.diffusion_model": remove_keys_,
 }
 
-VAE_091_SPECIAL_KEYS_REMAP = {
-    "timestep_scale_multiplier": remove_keys_,
-}
-
 
 def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
 def convert_transformer(
     ckpt_path: str,
     dtype: torch.dtype,
+    version: str = "0.9.0",
 ):
     PREFIX_KEY = "model.diffusion_model."
 
     original_state_dict = get_state_dict(load_file(ckpt_path))
+    config = {}
+    if version == "0.9.5":
+        config["_use_causal_rope_fix"] = True
     with init_empty_weights():
-        transformer = LTXVideoTransformer3DModel()
+        transformer = LTXVideoTransformer3DModel(**config)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
             "out_channels": 3,
             "latent_channels": 128,
             "block_out_channels": (128, 256, 512, 512),
+            "down_block_types": (
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+            ),
             "decoder_block_out_channels": (128, 256, 512, 512),
             "layers_per_block": (4, 3, 3, 3, 4),
             "decoder_layers_per_block": (4, 3, 3, 3, 4),
             "spatio_temporal_scaling": (True, True, True, False),
             "decoder_spatio_temporal_scaling": (True, True, True, False),
             "decoder_inject_noise": (False, False, False, False, False),
+            "downsample_type": ("conv", "conv", "conv", "conv"),
             "upsample_residual": (False, False, False, False),
             "upsample_factor": (1, 1, 1, 1),
             "patch_size": 4,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
             "out_channels": 3,
             "latent_channels": 128,
             "block_out_channels": (128, 256, 512, 512),
+            "down_block_types": (
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+            ),
             "decoder_block_out_channels": (256, 512, 1024),
             "layers_per_block": (4, 3, 3, 3, 4),
             "decoder_layers_per_block": (5, 6, 7, 8),
             "spatio_temporal_scaling": (True, True, True, False),
             "decoder_spatio_temporal_scaling": (True, True, True),
             "decoder_inject_noise": (True, True, True, False),
+            "downsample_type": ("conv", "conv", "conv", "conv"),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
             "timestep_conditioning": True,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
             "decoder_causal": False,
         }
         VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
-        VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
+    elif version == "0.9.5":
+        config = {
+            "in_channels": 3,
+            "out_channels": 3,
+            "latent_channels": 128,
+            "block_out_channels": (128, 256, 512, 1024, 2048),
+            "down_block_types": (
+                "LTXVideo095DownBlock3D",
+                "LTXVideo095DownBlock3D",
+                "LTXVideo095DownBlock3D",
+                "LTXVideo095DownBlock3D",
+            ),
+            "decoder_block_out_channels": (256, 512, 1024),
+            "layers_per_block": (4, 6, 6, 2, 2),
+            "decoder_layers_per_block": (5, 5, 5, 5),
+            "spatio_temporal_scaling": (True, True, True, True),
+            "decoder_spatio_temporal_scaling": (True, True, True),
+            "decoder_inject_noise": (False, False, False, False),
+            "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
+            "upsample_residual": (True, True, True),
+            "upsample_factor": (2, 2, 2),
+            "timestep_conditioning": True,
+            "patch_size": 4,
+            "patch_size_t": 1,
+            "resnet_norm_eps": 1e-6,
+            "scaling_factor": 1.0,
+            "encoder_causal": True,
+            "decoder_causal": False,
+            "spatial_compression_ratio": 32,
+            "temporal_compression_ratio": 8,
+        }
+        VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
     return config
 
 
@@ -223,7 +294,7 @@ def get_args():
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
     parser.add_argument(
-        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
+        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
     )
     return parser.parse_args()
 
@@ -277,14 +348,17 @@ def get_args():
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
 
-    scheduler = FlowMatchEulerDiscreteScheduler(
-        use_dynamic_shifting=True,
-        base_shift=0.95,
-        max_shift=2.05,
-        base_image_seq_len=1024,
-        max_image_seq_len=4096,
-        shift_terminal=0.1,
-    )
+    if args.version == "0.9.5":
+        scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
+    else:
+        scheduler = FlowMatchEulerDiscreteScheduler(
+            use_dynamic_shifting=True,
+            base_shift=0.95,
+            max_shift=2.05,
+            base_image_seq_len=1024,
+            max_image_seq_len=4096,
+            shift_terminal=0.1,
+        )
 
     pipe = LTXPipeline(
         scheduler=scheduler,
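
After conversion, the script assembles a diffusers-format pipeline and saves it at `--output_path`, which can then be loaded directly. A minimal sketch under that assumption; the local path is a placeholder and the flag for the original checkpoint is not shown in this diff, so the conversion command is abbreviated:

```python
# Sketch: load the output of
#   python scripts/convert_ltx_to_diffusers.py --output_path ./ltx-095-diffusers --version 0.9.5 ...
# (remaining flags, such as the path to the original checkpoint, are omitted because they are
# not shown in this diff). The local path is a placeholder.
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("./ltx-095-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```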

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -402,6 +402,7 @@
         "LDMTextToImagePipeline",
         "LEditsPPPipelineStableDiffusion",
         "LEditsPPPipelineStableDiffusionXL",
+        "LTXConditionPipeline",
         "LTXImageToVideoPipeline",
         "LTXPipeline",
         "Lumina2Pipeline",
@@ -947,6 +948,7 @@
         LDMTextToImagePipeline,
         LEditsPPPipelineStableDiffusion,
         LEditsPPPipelineStableDiffusionXL,
+        LTXConditionPipeline,
         LTXImageToVideoPipeline,
         LTXPipeline,
         Lumina2Pipeline,

src/diffusers/hooks/group_offloading.py

Lines changed: 23 additions & 7 deletions

@@ -83,7 +83,10 @@ def onload_(self):
 
         with context:
             for group_module in self.modules:
-                group_module.to(self.onload_device, non_blocking=self.non_blocking)
+                for param in group_module.parameters():
+                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                for buffer in group_module.buffers():
+                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
             if self.parameters is not None:
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
@@ -98,6 +101,12 @@ def offload_(self):
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
+            if self.parameters is not None:
+                for param in self.parameters:
+                    param.data = self.cpu_param_dict[param]
+            if self.buffers is not None:
+                for buffer in self.buffers:
+                    buffer.data = self.cpu_param_dict[buffer]
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
@@ -387,9 +396,7 @@ def _apply_group_offloading_block_level(
     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
     cpu_param_dict = None
     if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+        cpu_param_dict = _get_pinned_cpu_param_dict(module)
 
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
@@ -486,9 +493,7 @@ def _apply_group_offloading_leaf_level(
     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
     cpu_param_dict = None
     if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+        cpu_param_dict = _get_pinned_cpu_param_dict(module)
 
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
@@ -604,6 +609,17 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
+def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
+    cpu_param_dict = {}
+    for param in module.parameters():
+        param.data = param.data.cpu().pin_memory()
+        cpu_param_dict[param] = param.data
+    for buffer in module.buffers():
+        buffer.data = buffer.data.cpu().pin_memory()
+        cpu_param_dict[buffer] = buffer.data
+    return cpu_param_dict
+
+
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]:
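
The refactored `_get_pinned_cpu_param_dict` pins buffers as well as parameters into page-locked CPU memory; non-blocking host-to-device copies on a side stream only overlap with compute when the source tensors are pinned, and previously only parameters were placed in the pinned `cpu_param_dict`. A minimal sketch of the stream-based path that exercises this code, assuming the public `apply_group_offloading` helper that wraps the `_apply_group_offloading_*` functions above; the checkpoint id is only an example:

```python
# Sketch; assumes apply_group_offloading is importable from diffusers.hooks and wraps the
# private _apply_group_offloading_* helpers modified above. use_stream=True takes the
# pinned-memory path, which now covers module buffers as well as parameters.
import torch
from diffusers import LTXVideoTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
```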
