Skip to content

Commit 280fb32

Browse files
Fix FLUX2 Klein load-time VRAM spikes on low-memory GPUs.
In low-VRAM mode, keep the transformer and the Qwen text encoder off CUDA during the initial load and quantization phase. This lets model startup avoid a full-model out-of-memory error before offloading and quantization can take effect.
Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent de7d22c commit 280fb32

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

extensions_built_in/diffusion_models/flux2/flux2_klein_model.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,14 @@ def load_te(self):
4444
self.flux2_klein_te_path,
4545
torch_dtype=dtype,
4646
)
47-
text_encoder.to(self.device_torch, dtype=dtype)
48-
49-
flush()
50-
5147
if self.model_config.quantize_te:
5248
self.print_and_status_update("Quantizing Qwen3")
53-
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
49+
quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
5450
freeze(text_encoder)
5551
flush()
52+
elif not self.model_config.low_vram:
53+
text_encoder.to(self.device_torch, dtype=dtype)
54+
flush()
5655

5756
if (
5857
self.model_config.layer_offloading

extensions_built_in/diffusion_models/flux2/flux2_model.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,11 @@ def load_model(self):
155155

156156
transformer.load_state_dict(transformer_state_dict, assign=True)
157157

158-
transformer.to(self.quantize_device, dtype=dtype)
159-
160158
if self.model_config.quantize:
161159
# patch the state dict method
162160
patch_dequantization_on_save(transformer)
161+
# Avoid full-model peak VRAM allocation before quantization.
162+
self.print_and_status_update("Keeping transformer on CPU for quantization")
163163
self.print_and_status_update("Quantizing Transformer")
164164
quantize_model(self, transformer)
165165
flush()
@@ -234,10 +234,16 @@ def load_model(self):
234234

235235
flush()
236236
# just to make sure everything is on the right device and dtype
237-
text_encoder[0].to(self.device_torch)
237+
if self.model_config.low_vram:
238+
text_encoder[0].to("cpu")
239+
else:
240+
text_encoder[0].to(self.device_torch)
238241
text_encoder[0].requires_grad_(False)
239242
text_encoder[0].eval()
240-
pipe.transformer = pipe.transformer.to(self.device_torch)
243+
if self.model_config.low_vram:
244+
pipe.transformer = pipe.transformer.to("cpu")
245+
else:
246+
pipe.transformer = pipe.transformer.to(self.device_torch)
241247
flush()
242248

243249
# save it to the model class

0 commit comments

Comments (0)