Commit 9aae75c

[Feat] support HSDP for Flux family (vllm-project#1900)
Signed-off-by: Lancer <maruixiang6688@gmail.com>
1 parent 284575a commit 9aae75c

File tree: 7 files changed, +37 −20 lines changed

docs/user_guide/diffusion/parallelism_acceleration.md
Lines changed: 16 additions & 16 deletions

@@ -24,22 +24,22 @@ The following table shows which models are currently supported by parallelism me

 ### ImageGen

-| Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel | Tensor-Parallel | VAE-Patch-Parallel | Expert-Parallel |
-|--------------------------|--------------------------------------|:----------:|:-------:|:------------:|:---------------:|:------------------:|:---------------:|
-| **LongCat-Image** | `meituan-longcat/LongCat-Image` |||||| N/A |
-| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` |||||| N/A |
-| **Ovis-Image** | `OvisAI/Ovis-Image` |||||| N/A |
-| **Qwen-Image** | `Qwen/Qwen-Image` |||||| N/A |
-| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` |||||| N/A |
-| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` |||||| N/A |
-| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` |||||| N/A |
-| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` |||| ✅ (TP=2 only) || N/A |
-| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` |||||| N/A |
-| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` |||||| N/A |
-| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` |||||| N/A |
-| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` |||||| N/A |
-| **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` |||||| |
-| **DreamID-Omni** | `XuGuo699/DreamID-Omni` |||||| N/A |
+| Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel | Tensor-Parallel | VAE-Patch-Parallel | Expert-Parallel | HSDP |
+|--------------------------|--------------------------------------|:----------:|:-------:|:------------:|:---------------:|:------------------:|:---------------:|:----:|
+| **LongCat-Image** | `meituan-longcat/LongCat-Image` |||||| N/A | |
+| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` |||||| N/A | |
+| **Ovis-Image** | `OvisAI/Ovis-Image` |||||| N/A | |
+| **Qwen-Image** | `Qwen/Qwen-Image` |||||| N/A | |
+| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` |||||| N/A | |
+| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` |||||| N/A | |
+| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` |||||| N/A | |
+| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` |||| ✅ (TP=2 only) || N/A | |
+| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` |||||| N/A | |
+| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` |||||| N/A | |
+| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` |||||| N/A | |
+| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` |||||| N/A | |
+| **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` |||||| | |
+| **DreamID-Omni** | `XuGuo699/DreamID-Omni` |||||| N/A | |

 !!! note "TP Limitations for Diffusion Models"
     We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP.

vllm_omni/diffusion/models/flux/flux_transformer.py
Lines changed: 6 additions & 0 deletions

@@ -470,6 +470,12 @@ class FluxTransformer2DModel(nn.Module):
     # -- typically a transformer layer
     # used for torch compile optimizations
     _repeated_blocks = ["FluxTransformerBlock"]
+
+    @staticmethod
+    def _is_transformer_block(name: str, module) -> bool:
+        return ("transformer_blocks" in name or "single_transformer_blocks" in name) and name.split(".")[-1].isdigit()
+
+    _hsdp_shard_conditions = [_is_transformer_block]
     packed_modules_mapping = {
         "to_qkv": ["to_q", "to_k", "to_v"],
         "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],

vllm_omni/diffusion/models/flux/pipeline_flux.py
Lines changed: 2 additions & 2 deletions

@@ -160,10 +160,10 @@ def __init__(
         )
         self.text_encoder = CLIPTextModel.from_pretrained(
             model, subfolder="text_encoder", local_files_only=local_files_only
-        )
+        ).to(self.device)
         self.text_encoder_2 = T5EncoderModel.from_pretrained(
             model, subfolder="text_encoder_2", local_files_only=local_files_only
-        )
+        ).to(self.device)
         self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
             self.device
         )
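The pipeline change above places the text encoders on the execution device explicitly, matching what the VAE already did: only the DiT blocks are sharded, so non-sharded components must be moved by hand. A minimal sketch of the load-then-move pattern, with a plain `nn.Linear` standing in for a text encoder (the device selection is an assumption, not repo code):

```python
import torch
import torch.nn as nn

# Pick an execution device; fall back to CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for from_pretrained(...): parameters materialize on CPU,
# then the chained .to(device) moves them, as in the diff above.
text_encoder = nn.Linear(16, 16).to(device)

print(next(text_encoder.parameters()).device.type)
```

Without the explicit `.to(...)`, a model loaded via `from_pretrained` stays on CPU and the first forward pass fails with a device mismatch against GPU-resident inputs.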

vllm_omni/diffusion/models/flux2/flux2_transformer.py
Lines changed: 6 additions & 0 deletions

@@ -553,6 +553,12 @@ class Flux2Transformer2DModel(nn.Module):

     _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]

+    @staticmethod
+    def _is_transformer_block(name: str, module) -> bool:
+        return ("transformer_blocks" in name or "single_transformer_blocks" in name) and name.split(".")[-1].isdigit()
+
+    _hsdp_shard_conditions = [_is_transformer_block]
+
     def __init__(
         self,
         patch_size: int = 1,

vllm_omni/diffusion/models/flux2/pipeline_flux2.py
Lines changed: 1 addition & 1 deletion

@@ -366,7 +366,7 @@ def __init__(
         )
         self.text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
             model, subfolder="text_encoder", local_files_only=local_files_only
-        )
+        ).to(self._execution_device)
         self.tokenizer = PixtralProcessor.from_pretrained(
             model, subfolder="tokenizer", local_files_only=local_files_only
         )

vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
Lines changed: 5 additions & 0 deletions

@@ -741,6 +741,11 @@ class Flux2Transformer2DModel(nn.Module):

     _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]

+    @staticmethod
+    def _is_transformer_block(name: str, module) -> bool:
+        return ("transformer_blocks" in name or "single_transformer_blocks" in name) and name.split(".")[-1].isdigit()
+
+    _hsdp_shard_conditions = [_is_transformer_block]
     _sp_plan = {
         "": {
             "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True),
vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
Lines changed: 1 addition & 1 deletion

@@ -218,7 +218,7 @@ def __init__(
             model,
             subfolder="text_encoder",
             local_files_only=local_files_only,
-        )
+        ).to(self._execution_device)
         self.tokenizer = Qwen2TokenizerFast.from_pretrained(
             model,
             subfolder="tokenizer",