
Commit f5b69b0

resolve conflicts.

Merge commit f5b69b0 · 2 parents: 41ea4c8 + 3be6706

19 files changed · +1945 -93 lines changed


.github/workflows/pr_tests_gpu.yml

Lines changed: 44 additions & 0 deletions
@@ -28,7 +28,51 @@ env:
   PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
 
 jobs:
+  check_code_quality:
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.8"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[quality]
+      - name: Check quality
+        run: make quality
+      - name: Check if failure
+        if: ${{ failure() }}
+        run: |
+          echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
+
+  check_repository_consistency:
+    needs: check_code_quality
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.8"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[quality]
+      - name: Check repo consistency
+        run: |
+          python utils/check_copies.py
+          python utils/check_dummies.py
+          python utils/check_support_list.py
+          make deps_table_check_updated
+      - name: Check if failure
+        if: ${{ failure() }}
+        run: |
+          echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
+
   setup_torch_cuda_pipeline_matrix:
+    needs: [check_code_quality, check_repository_consistency]
     name: Setup Torch Pipelines CUDA Slow Tests Matrix
     runs-on:
       group: aws-general-8-plus

docs/source/en/api/pipelines/ltx_video.md

Lines changed: 6 additions & 0 deletions
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
   - all
   - __call__
 
+## LTXConditionPipeline
+
+[[autodoc]] LTXConditionPipeline
+  - all
+  - __call__
+
 ## LTXPipelineOutput
 
 [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput

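The new docs entry above only adds autodoc directives, so as a point of reference, here is a minimal usage sketch for the newly exported `LTXConditionPipeline`. It is not part of this commit; the checkpoint id and call parameters below are assumptions chosen to mirror the existing LTX examples.

```python
# Hypothetical sketch, not from this commit: drive LTXConditionPipeline end to end.
# The checkpoint id and the generation parameters are assumptions.
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

pipe = LTXConditionPipeline.from_pretrained(
    "Lightricks/LTX-Video-0.9.5",  # assumed 0.9.5 checkpoint id
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

video = pipe(
    prompt="A ship sailing through a stormy sea at night",
    negative_prompt="worst quality, blurry, jittery",
    width=704,
    height=480,
    num_frames=121,
    num_inference_steps=40,
).frames[0]
export_to_video(video, "ship.mp4", fps=24)
```
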
examples/research_projects/pytorch_xla/inference/flux/README.md

Lines changed: 8 additions & 7 deletions
@@ -1,8 +1,6 @@
 # Generating images using Flux and PyTorch/XLA
 
-The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.
-
-It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
+The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
 
 ## Create TPU
 
@@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
 python3 -c "import torch; import torch_xla;"
 ```
 
-Install dependencies
+Clone the diffusers repo and install dependencies
 
 ```bash
+git clone https://github.com/huggingface/diffusers.git
+cd diffusers
 pip install transformers accelerate sentencepiece structlog
-pushd ../../..
 pip install .
-popd
+cd examples/research_projects/pytorch_xla/inference/flux/
 ```
 
 ## Run the inference job
 
 ### Authenticate
 
-Run the following command to authenticate your token in order to download Flux weights.
+**Gated Model**
+
+As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
 
 ```bash
 huggingface-cli login

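For orientation only, the steps above boil down to something like the following. This is a hypothetical minimal sketch rather than the `flux_inference` script itself: the prompt, step count, and guidance value are placeholders, and a real TPU run would also use the custom flash-attention block sizes mentioned above.

```python
# Hypothetical minimal sketch of Flux inference on a TPU via PyTorch/XLA.
# Not the flux_inference script; parameters are placeholders.
import torch
import torch_xla.core.xla_model as xm
from diffusers import FluxPipeline

device = xm.xla_device()  # the TPU device exposed by torch_xla

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.to(device)

image = pipe(
    prompt="A photo of a ship sailing at sunset",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_tpu_sample.png")
```
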
scripts/convert_ltx_to_diffusers.py

Lines changed: 89 additions & 15 deletions
@@ -74,17 +74,39 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
     "last_scale_shift_table": "scale_shift_table",
 }
 
+VAE_095_RENAME_DICT = {
+    # decoder
+    "up_blocks.0": "mid_block",
+    "up_blocks.1": "up_blocks.0.upsamplers.0",
+    "up_blocks.2": "up_blocks.0",
+    "up_blocks.3": "up_blocks.1.upsamplers.0",
+    "up_blocks.4": "up_blocks.1",
+    "up_blocks.5": "up_blocks.2.upsamplers.0",
+    "up_blocks.6": "up_blocks.2",
+    "up_blocks.7": "up_blocks.3.upsamplers.0",
+    "up_blocks.8": "up_blocks.3",
+    # encoder
+    "down_blocks.0": "down_blocks.0",
+    "down_blocks.1": "down_blocks.0.downsamplers.0",
+    "down_blocks.2": "down_blocks.1",
+    "down_blocks.3": "down_blocks.1.downsamplers.0",
+    "down_blocks.4": "down_blocks.2",
+    "down_blocks.5": "down_blocks.2.downsamplers.0",
+    "down_blocks.6": "down_blocks.3",
+    "down_blocks.7": "down_blocks.3.downsamplers.0",
+    "down_blocks.8": "mid_block",
+    # common
+    "last_time_embedder": "time_embedder",
+    "last_scale_shift_table": "scale_shift_table",
+}
+
 VAE_SPECIAL_KEYS_REMAP = {
     "per_channel_statistics.channel": remove_keys_,
     "per_channel_statistics.mean-of-means": remove_keys_,
     "per_channel_statistics.mean-of-stds": remove_keys_,
     "model.diffusion_model": remove_keys_,
 }
 
-VAE_091_SPECIAL_KEYS_REMAP = {
-    "timestep_scale_multiplier": remove_keys_,
-}
-
 
 def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
 def convert_transformer(
     ckpt_path: str,
     dtype: torch.dtype,
+    version: str = "0.9.0",
 ):
     PREFIX_KEY = "model.diffusion_model."
 
     original_state_dict = get_state_dict(load_file(ckpt_path))
+    config = {}
+    if version == "0.9.5":
+        config["_use_causal_rope_fix"] = True
     with init_empty_weights():
-        transformer = LTXVideoTransformer3DModel()
+        transformer = LTXVideoTransformer3DModel(**config)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
             "out_channels": 3,
             "latent_channels": 128,
             "block_out_channels": (128, 256, 512, 512),
+            "down_block_types": (
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+            ),
             "decoder_block_out_channels": (128, 256, 512, 512),
             "layers_per_block": (4, 3, 3, 3, 4),
             "decoder_layers_per_block": (4, 3, 3, 3, 4),
             "spatio_temporal_scaling": (True, True, True, False),
             "decoder_spatio_temporal_scaling": (True, True, True, False),
             "decoder_inject_noise": (False, False, False, False, False),
+            "downsample_type": ("conv", "conv", "conv", "conv"),
             "upsample_residual": (False, False, False, False),
             "upsample_factor": (1, 1, 1, 1),
             "patch_size": 4,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
             "out_channels": 3,
             "latent_channels": 128,
             "block_out_channels": (128, 256, 512, 512),
+            "down_block_types": (
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+                "LTXVideoDownBlock3D",
+            ),
             "decoder_block_out_channels": (256, 512, 1024),
             "layers_per_block": (4, 3, 3, 3, 4),
             "decoder_layers_per_block": (5, 6, 7, 8),
             "spatio_temporal_scaling": (True, True, True, False),
             "decoder_spatio_temporal_scaling": (True, True, True),
             "decoder_inject_noise": (True, True, True, False),
+            "downsample_type": ("conv", "conv", "conv", "conv"),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
             "timestep_conditioning": True,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
             "decoder_causal": False,
         }
         VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
-        VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
+    elif version == "0.9.5":
+        config = {
+            "in_channels": 3,
+            "out_channels": 3,
+            "latent_channels": 128,
+            "block_out_channels": (128, 256, 512, 1024, 2048),
+            "down_block_types": (
+                "LTXVideo095DownBlock3D",
+                "LTXVideo095DownBlock3D",
+                "LTXVideo095DownBlock3D",
+                "LTXVideo095DownBlock3D",
+            ),
+            "decoder_block_out_channels": (256, 512, 1024),
+            "layers_per_block": (4, 6, 6, 2, 2),
+            "decoder_layers_per_block": (5, 5, 5, 5),
+            "spatio_temporal_scaling": (True, True, True, True),
+            "decoder_spatio_temporal_scaling": (True, True, True),
+            "decoder_inject_noise": (False, False, False, False),
+            "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
+            "upsample_residual": (True, True, True),
+            "upsample_factor": (2, 2, 2),
+            "timestep_conditioning": True,
+            "patch_size": 4,
+            "patch_size_t": 1,
+            "resnet_norm_eps": 1e-6,
+            "scaling_factor": 1.0,
+            "encoder_causal": True,
+            "decoder_causal": False,
+            "spatial_compression_ratio": 32,
+            "temporal_compression_ratio": 8,
+        }
+        VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
     return config
 
 
@@ -223,7 +294,7 @@ def get_args():
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
     parser.add_argument(
-        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
+        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
     )
     return parser.parse_args()
 
@@ -277,14 +348,17 @@ def get_args():
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
 
-    scheduler = FlowMatchEulerDiscreteScheduler(
-        use_dynamic_shifting=True,
-        base_shift=0.95,
-        max_shift=2.05,
-        base_image_seq_len=1024,
-        max_image_seq_len=4096,
-        shift_terminal=0.1,
-    )
+    if args.version == "0.9.5":
+        scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
+    else:
+        scheduler = FlowMatchEulerDiscreteScheduler(
+            use_dynamic_shifting=True,
+            base_shift=0.95,
+            max_shift=2.05,
+            base_image_seq_len=1024,
+            max_image_seq_len=4096,
+            shift_terminal=0.1,
+        )
 
     pipe = LTXPipeline(
         scheduler=scheduler,

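After conversion, the directory written to `--output_path` should load like any other diffusers checkpoint. A hedged sketch (the local path is a placeholder for whatever was passed to the script; bfloat16 is just an example dtype):

```python
# Hypothetical follow-up: load the pipeline saved by convert_ltx_to_diffusers.py.
# "path/to/converted-ltx" stands in for the --output_path given to the script.
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("path/to/converted-ltx", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```
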
src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -402,6 +402,7 @@
         "LDMTextToImagePipeline",
         "LEditsPPPipelineStableDiffusion",
         "LEditsPPPipelineStableDiffusionXL",
+        "LTXConditionPipeline",
         "LTXImageToVideoPipeline",
         "LTXPipeline",
         "Lumina2Pipeline",
@@ -947,6 +948,7 @@
         LDMTextToImagePipeline,
         LEditsPPPipelineStableDiffusion,
         LEditsPPPipelineStableDiffusionXL,
+        LTXConditionPipeline,
         LTXImageToVideoPipeline,
         LTXPipeline,
         Lumina2Pipeline,

src/diffusers/hooks/group_offloading.py

Lines changed: 19 additions & 11 deletions
@@ -98,12 +98,13 @@ def onload_(self):
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
                     if self.record_stream:
                         buffer.data.record_stream(current_stream)
-            
+
         if self.parameters is not None:
             for param in self.parameters:
                 param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
                 if self.record_stream:
                     param.data.record_stream(current_stream)
+
         if self.buffers is not None:
             for buffer in self.buffers:
                 buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
@@ -198,6 +199,13 @@ def __init__(self):
         self._layer_execution_tracker_module_names = set()
 
     def initialize_hook(self, module):
+        def make_execution_order_update_callback(current_name, current_submodule):
+            def callback():
+                logger.debug(f"Adding {current_name} to the execution order")
+                self.execution_order.append((current_name, current_submodule))
+
+            return callback
+
         # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
         # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
         # layers are executed during the forward pass.
@@ -209,14 +217,8 @@ def initialize_hook(self, module):
             group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
 
             if group_offloading_hook is not None:
-
-                def make_execution_order_update_callback(current_name, current_submodule):
-                    def callback():
-                        logger.debug(f"Adding {current_name} to the execution order")
-                        self.execution_order.append((current_name, current_submodule))
-
-                    return callback
-
+                # For the first forward pass, we have to load in a blocking manner
+                group_offloading_hook.group.non_blocking = False
                 layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
                 registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
                 self._layer_execution_tracker_module_names.add(name)
@@ -246,15 +248,21 @@ def post_forward(self, module, output):
         # Remove the layer execution tracker hooks from the submodules
         base_module_registry = module._diffusers_hook
         registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
+        group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
 
         for i in range(num_executed):
             registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
 
         # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
         base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
 
-        # Apply lazy prefetching by setting required attributes
-        group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
+        # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
+        # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
+        # see the benefits of prefetching.
+        for hook in group_offloading_hooks:
+            hook.group.non_blocking = True
+
+        # Set required attributes for prefetching
         if num_executed > 0:
             base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
             base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group

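To make the intent of the `non_blocking` toggling above concrete, here is a small self-contained sketch of the same idea, not diffusers code: the first forward pass loads each group synchronously while the execution order is recorded, and later passes switch to non-blocking transfers so the recorded order can be used for prefetching.

```python
# Toy illustration of the blocking-first, non-blocking-later pattern (not diffusers code).
class ToyOffloadGroup:
    def __init__(self, name):
        self.name = name
        self.non_blocking = False  # first pass: load synchronously

    def onload(self):
        mode = "non-blocking" if self.non_blocking else "blocking"
        print(f"onloading {self.name} ({mode})")


execution_order = []


def run_forward(groups, first_pass):
    for group in groups:
        group.onload()
        if first_pass:
            # Mirrors the layer execution tracker: record the order in which
            # groups are used during the first (blocking) pass.
            execution_order.append(group.name)


groups = [ToyOffloadGroup("blocks.0"), ToyOffloadGroup("blocks.1")]
run_forward(groups, first_pass=True)

# After the first pass, re-enable non-blocking transfers so later passes can
# prefetch the next group (per execution_order) while the current one computes.
for group in groups:
    group.non_blocking = True
run_forward(groups, first_pass=False)
```
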