Commit 76b0838

Feature(backend): Add user toggle to run encoder models on CPU (#8777)

Authored by lstein, Copilot, and JPPhoto
1 parent b7d7cd0 · commit 76b0838

Commit message:

* feature(backend) Add user toggle to run encoder models on CPU
  Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

  Add frontend UI for CPU-only model execution toggle
  Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* chore(frontend): remove package lock file created by npm

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
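
The core idea of the change: encoder-capable model configs gain an optional cpu_only flag, and the encoding invocations stop assuming a single global torch device, instead following whatever device the loaded model actually lives on. As a rough sketch of how such a flag could feed into device selection (the loader-side logic is in files not shown in this excerpt, so the helper name and structure below are illustrative assumptions, not the commit's actual code):

import torch

def choose_execution_device(cpu_only: bool | None) -> torch.device:
    # Hypothetical helper: honor a per-model CPU-only override before falling
    # back to the usual accelerator detection.
    if cpu_only:
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# e.g. device = choose_execution_device(model_config.cpu_only)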

File tree

27 files changed (+1793, -721 lines)

invokeai/app/invocations/cogview4_text_encoder.py

Lines changed: 3 additions & 2 deletions

@@ -10,7 +10,6 @@
     CogView4ConditioningInfo,
     ConditioningFieldData,
 )
-from invokeai.backend.util.devices import TorchDevice

 # The CogView4 GLM Text Encoder max sequence length set based on the default in diffusers.
 COGVIEW4_GLM_MAX_SEQ_LEN = 1024
@@ -37,6 +36,8 @@ class CogView4TextEncoderInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CogView4ConditioningOutput:
         glm_embeds = self._glm_encode(context, max_seq_len=COGVIEW4_GLM_MAX_SEQ_LEN)
+        # Move embeddings to CPU for storage to save VRAM
+        glm_embeds = glm_embeds.detach().to("cpu")
         conditioning_data = ConditioningFieldData(conditionings=[CogView4ConditioningInfo(glm_embeds=glm_embeds)])
         conditioning_name = context.conditioning.save(conditioning_data)
         return CogView4ConditioningOutput.build(conditioning_name)
@@ -85,7 +86,7 @@ def _glm_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Ten
         )
         text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
         prompt_embeds = glm_text_encoder(
-            text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
+            text_input_ids.to(glm_text_encoder.device), output_hidden_states=True
         ).hidden_states[-2]

         assert isinstance(prompt_embeds, torch.Tensor)
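
The pattern above recurs throughout the commit: instead of sending token IDs to TorchDevice.choose_torch_device(), the invocation asks the loaded encoder where it lives and moves the inputs there, so a CPU-only encoder receives CPU tensors. A minimal standalone illustration of that pattern using a generic Hugging Face text encoder (the model id here is a placeholder, not what InvokeAI loads):

import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder model; InvokeAI loads its encoders through its own model manager.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder = AutoModel.from_pretrained("bert-base-uncased")  # stays on CPU unless moved

tokens = tokenizer(["a photo of a cat"], return_tensors="pt")
with torch.no_grad():
    # Follow the model's device instead of assuming a global one.
    hidden = encoder(input_ids=tokens["input_ids"].to(encoder.device)).last_hidden_state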

invokeai/app/invocations/compel.py

Lines changed: 2 additions & 2 deletions

@@ -103,7 +103,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             textual_inversion_manager=ti_manager,
             dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
-            device=TorchDevice.choose_torch_device(),
+            device=text_encoder.device,  # Use the device the model is actually on
             split_long_text_mode=SplitLongTextMode.SENTENCES,
         )

@@ -212,7 +212,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             truncate_long_prompts=False,  # TODO:
             returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
             requires_pooled=get_pooled,
-            device=TorchDevice.choose_torch_device(),
+            device=text_encoder.device,  # Use the device the model is actually on
             split_long_text_mode=SplitLongTextMode.SENTENCES,
         )

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 6 additions & 0 deletions

@@ -58,6 +58,12 @@ def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
         # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
         t5_embeddings = self._t5_encode(context)
         clip_embeddings = self._clip_encode(context)
+
+        # Move embeddings to CPU for storage to save VRAM
+        # They will be moved to the appropriate device when used by the denoiser
+        t5_embeddings = t5_embeddings.detach().to("cpu")
+        clip_embeddings = clip_embeddings.detach().to("cpu")
+
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )
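
The detach-and-offload step above is what keeps cached conditioning from pinning VRAM: detach() drops any autograd references and .to("cpu") parks the saved tensors in system RAM until the denoiser pulls them back onto its device. A minimal illustration (the tensor shape is made up for the example):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
embeds = torch.randn(1, 512, 4096, device=device)  # stand-in for an encoder output

# Safe to cache or serialize without holding GPU memory or an autograd graph.
embeds = embeds.detach().to("cpu")

# Later, the consumer moves it to wherever it needs to run:
embeds_on_device = embeds.to(device)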

invokeai/app/invocations/sd3_text_encoder.py

Lines changed: 11 additions & 3 deletions

@@ -21,7 +21,6 @@
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
-from invokeai.backend.util.devices import TorchDevice

 # The SD3 T5 Max Sequence Length set based on the default in diffusers.
 SD3_T5_MAX_SEQ_LEN = 256
@@ -69,6 +68,15 @@ def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
         if self.t5_encoder is not None:
             t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)

+        # Move all embeddings to CPU for storage to save VRAM
+        # They will be moved to the appropriate device when used by the denoiser
+        clip_l_embeddings = clip_l_embeddings.detach().to("cpu")
+        clip_l_pooled_embeddings = clip_l_pooled_embeddings.detach().to("cpu")
+        clip_g_embeddings = clip_g_embeddings.detach().to("cpu")
+        clip_g_pooled_embeddings = clip_g_pooled_embeddings.detach().to("cpu")
+        if t5_embeddings is not None:
+            t5_embeddings = t5_embeddings.detach().to("cpu")
+
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SD3ConditioningInfo(
@@ -117,7 +125,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tens
                 f" {max_seq_len} tokens: {removed_text}"
             )

-        prompt_embeds = t5_text_encoder(text_input_ids.to(TorchDevice.choose_torch_device()))[0]
+        prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]

         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
@@ -180,7 +188,7 @@ def _clip_encode(
                 f" {tokenizer_max_length} tokens: {removed_text}"
             )
         prompt_embeds = clip_text_encoder(
-            input_ids=text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
+            input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
         )
         pooled_prompt_embeds = prompt_embeds[0]
         prompt_embeds = prompt_embeds.hidden_states[-2]

invokeai/app/invocations/z_image_text_encoder.py

Lines changed: 5 additions & 1 deletion

@@ -57,6 +57,8 @@ class ZImageTextEncoderInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
         prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
+        # Move embeddings to CPU for storage to save VRAM
+        prompt_embeds = prompt_embeds.detach().to("cpu")
         conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
         conditioning_name = context.conditioning.save(conditioning_data)
         return ZImageConditioningOutput(
@@ -69,7 +71,6 @@ def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.
         Based on the ZImagePipeline._encode_prompt method from diffusers.
         """
         prompt = self.prompt
-        device = TorchDevice.choose_torch_device()

         text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
         tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
@@ -78,6 +79,9 @@ def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.
             (_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
             (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())

+            # Use the device that the text_encoder is actually on
+            device = text_encoder.device
+
             # Apply LoRA models to the text encoder
             lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
             exit_stack.enter_context(
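
Because device now comes from the loaded encoder rather than a global setting, downstream choices such as the LoRA dtype follow it too. TorchDevice.choose_bfloat16_safe_dtype is InvokeAI's own utility and its exact logic is not shown in this diff; the sketch below is a plausible guess at what a bfloat16-safe chooser might look like, not the real implementation:

import torch

def bfloat16_safe_dtype(device: torch.device) -> torch.dtype:
    # Hypothetical logic: bfloat16 is generally usable on CPU and on CUDA
    # devices that report bf16 support; fall back to float32 elsewhere.
    if device.type == "cpu":
        return torch.bfloat16
    if device.type == "cuda" and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32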

invokeai/app/services/model_records/model_records_base.py

Lines changed: 1 addition & 0 deletions

@@ -88,6 +88,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
     default_settings: Optional[MainModelDefaultSettings | LoraModelDefaultSettings | ControlAdapterDefaultSettings] = (
         Field(description="Default settings for this model", default=None)
     )
+    cpu_only: Optional[bool] = Field(description="Whether this model should run on CPU only", default=None)

     # Checkpoint-specific changes
     # TODO(MM2): Should we expose these? Feels footgun-y...
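
The new field is Optional with a None default, matching the other entries of ModelRecordChanges: an unset field can mean "leave the stored value untouched" when a record is patched, while True/False explicitly set or clear the flag. A small self-contained sketch of that tri-state update pattern (the classes and helper below are illustrative, not InvokeAI's actual record code):

from typing import Optional
from pydantic import BaseModel, Field

class RecordChanges(BaseModel):
    # None = leave the stored value untouched; True/False = explicitly set it.
    cpu_only: Optional[bool] = Field(default=None, description="Whether this model should run on CPU only")

def apply_changes(stored: dict, changes: RecordChanges) -> dict:
    # exclude_none=True keeps still-None fields out of the update.
    stored.update(changes.model_dump(exclude_none=True))
    return stored

# apply_changes({"name": "t5", "cpu_only": False}, RecordChanges(cpu_only=True))
# -> {"name": "t5", "cpu_only": True}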

invokeai/backend/flux/modules/conditioner.py

Lines changed: 4 additions & 3 deletions

@@ -3,8 +3,6 @@
 from torch import Tensor, nn
 from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast

-from invokeai.backend.util.devices import TorchDevice
-

 class HFEncoder(nn.Module):
     def __init__(
@@ -33,8 +31,11 @@ def forward(self, text: list[str]) -> Tensor:
             return_tensors="pt",
         )

+        # Move inputs to the same device as the model to support cpu_only models
+        model_device = next(self.hf_module.parameters()).device
+
         outputs = self.hf_module(
-            input_ids=batch_encoding["input_ids"].to(TorchDevice.choose_torch_device()),
+            input_ids=batch_encoding["input_ids"].to(model_device),
             attention_mask=None,
             output_hidden_states=False,
         )
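
Here the device is read from the wrapped module's parameters rather than from a .device attribute, a lookup that works for any nn.Module, including ones that do not expose a device property. A compact illustration of that lookup (the module below is a toy stand-in for self.hf_module):

import torch
from torch import nn

module = nn.Linear(8, 8)  # toy stand-in for the wrapped text encoder

# nn.Module has no .device attribute; infer the device from the parameters.
model_device = next(module.parameters()).device

x = torch.randn(1, 8)
y = module(x.to(model_device))  # inputs follow the module's device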

invokeai/backend/model_manager/configs/clip_embed.py

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
     format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
+    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:

invokeai/backend/model_manager/configs/clip_vision.py

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision)
     format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
+    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:

invokeai/backend/model_manager/configs/llava_onevision.py

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base):

     type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision)
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
+    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
