@@ -21,7 +21,6 @@
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
-from invokeai.backend.util.devices import TorchDevice

 # The SD3 T5 Max Sequence Length set based on the default in diffusers.
 SD3_T5_MAX_SEQ_LEN = 256
@@ -69,6 +68,15 @@ def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
         if self.t5_encoder is not None:
             t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)

+        # Move all embeddings to CPU for storage to save VRAM
+        # They will be moved to the appropriate device when used by the denoiser
+        clip_l_embeddings = clip_l_embeddings.detach().to("cpu")
+        clip_l_pooled_embeddings = clip_l_pooled_embeddings.detach().to("cpu")
+        clip_g_embeddings = clip_g_embeddings.detach().to("cpu")
+        clip_g_pooled_embeddings = clip_g_pooled_embeddings.detach().to("cpu")
+        if t5_embeddings is not None:
+            t5_embeddings = t5_embeddings.detach().to("cpu")
+
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SD3ConditioningInfo(
@@ -117,7 +125,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
                     f" {max_seq_len} tokens: {removed_text}"
                 )

-            prompt_embeds = t5_text_encoder(text_input_ids.to(TorchDevice.choose_torch_device()))[0]
+            prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]

         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
@@ -180,7 +188,7 @@ def _clip_encode(
                     f" {tokenizer_max_length} tokens: {removed_text}"
                 )
             prompt_embeds = clip_text_encoder(
-                input_ids=text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
+                input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
             )
             pooled_prompt_embeds = prompt_embeds[0]
             prompt_embeds = prompt_embeds.hidden_states[-2]
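For context, here is a minimal sketch of the offload pattern this commit applies: the text-encoder node parks the conditioning tensors on the CPU, and the consumer moves them back to its execution device right before use. The helper names and the device/dtype parameters below are illustrative assumptions, not InvokeAI's actual denoiser API; the commit also switches the encoder calls from TorchDevice.choose_torch_device() to each encoder's own .device attribute so the input ids follow wherever the model manager loaded the model.

import torch

# Hypothetical helpers illustrating the CPU-offload pattern from this commit.
# Neither function exists in InvokeAI; they only show the two halves of the idea.

def offload_conditioning(embeds: torch.Tensor) -> torch.Tensor:
    # Detach from the autograd graph and store on the CPU to free VRAM.
    return embeds.detach().to("cpu")

def load_conditioning(embeds: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    # Move the stored embeddings onto the denoiser's device (and dtype) at use time.
    return embeds.to(device=device, dtype=dtype)

# Example: encode on the GPU, store on the CPU, reload for denoising.
# stored = offload_conditioning(clip_l_embeddings)
# clip_l_embeddings = load_conditioning(stored, device=torch.device("cuda"), dtype=torch.float16)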