Skip to content

Commit f925211

Browse files
authored
enable FSDP for qwen vl (#184)
* enable FSDP for qwen vl
* fix
1 parent d7032d0 commit f925211

File tree

12 files changed

+27
-34
lines changed

12 files changed

+27
-34
lines changed

diffsynth_engine/models/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def unload_loras(self):
5757
def get_tp_plan(self):
5858
raise NotImplementedError(f"{self.__class__.__name__} does not support TP")
5959

60-
def get_fsdp_modules(self):
60+
def get_fsdp_module_cls(self):
6161
raise NotImplementedError(f"{self.__class__.__name__} does not support FSDP")
6262

6363

diffsynth_engine/models/flux/flux_dit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -515,5 +515,5 @@ def compile_repeated_blocks(self, *args, **kwargs):
515515
for block in self.blocks:
516516
block.compile(*args, **kwargs)
517517

518-
def get_fsdp_modules(self):
519-
return ["blocks", "single_blocks"]
518+
def get_fsdp_module_cls(self):
519+
return {FluxDoubleTransformerBlock, FluxSingleTransformerBlock}

diffsynth_engine/models/qwen_image/qwen2_5_vl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,8 @@ def _unmask_unattended(
942942

943943

944944
class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):
945+
_supports_parallelization = True
946+
945947
def __init__(
946948
self,
947949
vision_config: Qwen2_5_VLVisionConfig,
@@ -1173,6 +1175,9 @@ def get_rope_index(
11731175

11741176
return position_ids, mrope_position_deltas
11751177

1178+
def get_fsdp_module_cls(self):
1179+
return {Qwen2_5_VisionBlock, Qwen2_5_VLDecoderLayer}
1180+
11761181
def forward(
11771182
self,
11781183
input_ids: Optional[torch.LongTensor] = None,

diffsynth_engine/models/qwen_image/qwen_image_dit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,5 +544,5 @@ def compile_repeated_blocks(self, *args, **kwargs):
544544
for block in self.transformer_blocks:
545545
block.compile(*args, **kwargs)
546546

547-
def get_fsdp_modules(self):
548-
return ["transformer_blocks"]
547+
def get_fsdp_module_cls(self):
548+
return {QwenImageTransformerBlock}

diffsynth_engine/models/wan/wan_audio_encoder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,6 @@ def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
223223

224224
class Wav2Vec2Model(PreTrainedModel):
225225
converter = Wav2Vec2StateDictConverter()
226-
_supports_parallelization = False
227226

228227
def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
229228
super().__init__()

diffsynth_engine/models/wan/wan_dit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,5 +502,5 @@ def compile_repeated_blocks(self, *args, **kwargs):
502502
for block in self.single_blocks:
503503
block.compile(*args, **kwargs)
504504

505-
def get_fsdp_modules(self):
506-
return ["blocks"]
505+
def get_fsdp_module_cls(self):
506+
return {DiTBlock}

diffsynth_engine/pipelines/flux_image.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dic
143143
layer_id, layer_type = name.split("_", 1)
144144
layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
145145
rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])
146-
146+
147147
lora_args = {}
148148
lora_args["alpha"] = param
149149
lora_args["up"] = lora_state_dict[origin_key.replace(".alpha", ".lora_up.weight")]
@@ -517,7 +517,7 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi
517517
if config.use_fbcache:
518518
dit = FluxDiTFBCache.from_state_dict(
519519
state_dicts.model,
520-
device=init_device,
520+
device=("cpu" if config.use_fsdp else init_device),
521521
dtype=config.model_dtype,
522522
in_channel=config.control_type.get_in_channel(),
523523
attn_kwargs=attn_kwargs,
@@ -526,7 +526,7 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi
526526
else:
527527
dit = FluxDiT.from_state_dict(
528528
state_dicts.model,
529-
device=init_device,
529+
device=("cpu" if config.use_fsdp else init_device),
530530
dtype=config.model_dtype,
531531
in_channel=config.control_type.get_in_channel(),
532532
attn_kwargs=attn_kwargs,

diffsynth_engine/pipelines/qwen_image.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
237237
state_dicts.encoder,
238238
vision_config=vision_config,
239239
config=text_config,
240-
device=init_device,
240+
device=("cpu" if config.use_fsdp else init_device),
241241
dtype=config.encoder_dtype,
242242
)
243243
with open(QWEN_IMAGE_VAE_CONFIG_FILE, "r", encoding="utf-8") as f:
@@ -257,15 +257,15 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
257257
if config.use_fbcache:
258258
dit = QwenImageDiTFBCache.from_state_dict(
259259
state_dicts.model,
260-
device=init_device,
260+
device=("cpu" if config.use_fsdp else init_device),
261261
dtype=config.model_dtype,
262262
attn_kwargs=attn_kwargs,
263263
relative_l1_threshold=config.fbcache_relative_l1_threshold,
264264
)
265265
else:
266266
dit = QwenImageDiT.from_state_dict(
267267
state_dicts.model,
268-
device=init_device,
268+
device=("cpu" if config.use_fsdp else init_device),
269269
dtype=config.model_dtype,
270270
attn_kwargs=attn_kwargs,
271271
)

diffsynth_engine/pipelines/sdxl_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def from_pretrained(cls, model_path_or_config: SDXLPipelineConfig) -> "SDXLImage
181181

182182
@classmethod
183183
def from_state_dict(cls, state_dicts: SDXLStateDicts, config: SDXLPipelineConfig) -> "SDXLImagePipeline":
184-
init_device = "cpu" if config.offload_mode else config.device
184+
init_device = "cpu" if config.offload_mode is not None else config.device
185185
tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
186186
tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
187187
with LoRAContext():

diffsynth_engine/pipelines/wan_s2v.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ def _from_state_dict(
664664
dit = WanS2VDiT.from_state_dict(
665665
state_dicts.model,
666666
config=model_config,
667-
device=init_device,
667+
device=("cpu" if config.use_fsdp else init_device),
668668
dtype=config.model_dtype,
669669
attn_kwargs=attn_kwargs,
670670
)

0 commit comments

Comments (0)