
Commit 09e7bd2

hkunzhe and bubbliiiing authored
Support low_cpu_mem_usage=True for the text encoder of Wan2.1 (#146)
Co-authored-by: bubbliiiing <3323290568@qq.com>

1 parent: ae4f418 · commit: 09e7bd2

18 files changed: +102 additions, -29 deletions
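
Every hunk below makes the same change: instead of loading the text encoder and then casting the whole module with .to(weight_dtype), the scripts pass low_cpu_mem_usage=True and torch_dtype=weight_dtype directly to from_pretrained. The following is a minimal before/after sketch of that pattern, assuming WanT5EncoderModel.from_pretrained follows the usual Hugging Face semantics for these two arguments; the import path, checkpoint directory, and config path are placeholders for illustration, and model_name, config, and weight_dtype are normally set up earlier in each script.

import os

import torch
from omegaconf import OmegaConf

# NOTE: the import path below is an assumption for illustration; use the
# actual location of WanT5EncoderModel in this repository.
from videox_fun.models import WanT5EncoderModel

# Placeholders standing in for values prepared earlier in each script.
model_name = "models/Diffusion_Transformer/Wan2.1-T2V-1.3B"   # assumed checkpoint dir
config = OmegaConf.load("config/wan2.1/wan_civitai.yaml")      # assumed config path
weight_dtype = torch.bfloat16

# Before this commit: the model is fully instantiated and loaded in the
# checkpoint's precision, and only afterwards cast to weight_dtype.
# text_encoder = WanT5EncoderModel.from_pretrained(
#     os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
#     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
# ).to(weight_dtype)

# After this commit: low_cpu_mem_usage=True asks from_pretrained to skip the
# redundant random-weight initialization and avoid holding an extra full copy
# of the state dict in CPU RAM, and torch_dtype loads the parameters directly
# in the requested precision, so the trailing .to(weight_dtype) cast goes away.
text_encoder = WanT5EncoderModel.from_pretrained(
    os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
    additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
)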

comfyui/wan2_1/nodes.py

Lines changed: 3 additions & 1 deletion
@@ -164,7 +164,9 @@ def loadmodel(self, GPU_memory_mode, model, precision, config):
 text_encoder = WanT5EncoderModel.from_pretrained(
     os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-).to(weight_dtype)
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+)
 pbar.update(1)

 if transformer.config.in_channels != vae.config.latent_channels:

comfyui/wan2_1_fun/nodes.py

Lines changed: 3 additions & 1 deletion
@@ -163,7 +163,9 @@ def loadmodel(self, GPU_memory_mode, model_type, model, precision, config):
 text_encoder = WanT5EncoderModel.from_pretrained(
     os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-).to(weight_dtype)
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+)
 pbar.update(1)

 if transformer.config.in_channels != vae.config.latent_channels:

examples/wan2.1/predict_i2v.py

Lines changed: 3 additions & 1 deletion
@@ -135,7 +135,9 @@
 text_encoder = WanT5EncoderModel.from_pretrained(
     os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-).to(weight_dtype)
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+)
 text_encoder = text_encoder.eval()

 # Get Clip Image Encoder

examples/wan2.1/predict_t2v.py

Lines changed: 3 additions & 1 deletion
@@ -129,7 +129,9 @@
 text_encoder = WanT5EncoderModel.from_pretrained(
     os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-).to(weight_dtype)
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+)

 # Get Scheduler
 Choosen_Scheduler = scheduler_dict = {

examples/wan2.1_fun/predict_i2v.py

Lines changed: 3 additions & 1 deletion
@@ -135,7 +135,9 @@
 text_encoder = WanT5EncoderModel.from_pretrained(
     os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-).to(weight_dtype)
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+)
 text_encoder = text_encoder.eval()

 # Get Clip Image Encoder

examples/wan2.1_fun/predict_t2v.py

Lines changed: 3 additions & 1 deletion
@@ -130,7 +130,9 @@
 text_encoder = WanT5EncoderModel.from_pretrained(
     os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-).to(weight_dtype)
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+)
 text_encoder = text_encoder.eval()

 if transformer.config.in_channels != vae.config.latent_channels:

examples/wan2.1_fun/predict_v2v_control.py

Lines changed: 3 additions & 1 deletion
@@ -142,7 +142,9 @@
 text_encoder = WanT5EncoderModel.from_pretrained(
     os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-).to(weight_dtype)
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+)
 text_encoder = text_encoder.eval()

 # Get Clip Image Encoder

scripts/wan2.1/train.py

Lines changed: 3 additions & 1 deletion
@@ -842,7 +842,9 @@ def deepspeed_zero_init_disabled_context_manager():
 text_encoder = WanT5EncoderModel.from_pretrained(
     os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-).to(weight_dtype)
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+)
 text_encoder = text_encoder.eval()
 # Get Vae
 vae = AutoencoderKLWan.from_pretrained(

scripts/wan2.1/train_lora.py

Lines changed: 3 additions & 1 deletion
@@ -841,7 +841,9 @@ def deepspeed_zero_init_disabled_context_manager():
 text_encoder = WanT5EncoderModel.from_pretrained(
     os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-).to(weight_dtype)
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+)
 # Get Vae
 vae = AutoencoderKLWan.from_pretrained(
     os.path.join(args.pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')),

scripts/wan2.1/train_reward_lora.py

Lines changed: 3 additions & 1 deletion
@@ -864,7 +864,9 @@ def deepspeed_zero_init_disabled_context_manager():
 text_encoder = WanT5EncoderModel.from_pretrained(
     os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-).to(weight_dtype)
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+)
 text_encoder = text_encoder.eval()
 # Get Vae
 vae = AutoencoderKLWan.from_pretrained(
