Skip to content

Commit 3dcb466

Browse files
authored
Patch WanTimeTextImageEmbedding forward only with fp8 (#327)
1 parent b8bf0fc commit 3dcb466

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

finetrainers/patches/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBa
1717
if parallel_backend.tensor_parallel_enabled:
1818
patch.patch_apply_rotary_emb_for_tp_compatibility()
1919

20-
if args.model_name == ModelType.WAN:
20+
if args.model_name == ModelType.WAN and "transformer" in args.layerwise_upcasting_modules:
2121
from .models.wan import patch
2222

2323
patch.patch_time_text_image_embedding_forward()

0 commit comments

Comments
 (0)