
Commit ec2eedb

fix(flux2): resolve device mismatch in Klein text encoder (#8851)
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent 77e1ac1 · commit ec2eedb

1 file changed: +6 -5 lines

invokeai/app/invocations/flux2_klein_text_encoder.py

Lines changed: 6 additions & 5 deletions
@@ -100,7 +100,6 @@ def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torc
             Shape: (1, hidden_size)
         """
         prompt = self.prompt
-        device = TorchDevice.choose_torch_device()

         text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
         tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
@@ -109,6 +108,9 @@ def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torc
             (cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
             (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())

+            # you can now define the device, as the text_encoder exists here
+            device = text_encoder.device
+
             # Apply LoRA models to the text encoder
             lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
             exit_stack.enter_context(
@@ -157,28 +159,27 @@ def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torc
                 max_length=self.max_seq_len,
             )

-            input_ids = inputs["input_ids"]
-            attention_mask = inputs["attention_mask"]
+            input_ids = inputs["input_ids"].to(device)
+            attention_mask = inputs["attention_mask"].to(device)

             # Move to device
             input_ids = input_ids.to(device)
             attention_mask = attention_mask.to(device)

             # Forward pass through the model - matching diffusers exactly
+            # Explicitly move inputs to the same device as the text_encoder
             outputs = text_encoder(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 output_hidden_states=True,
                 use_cache=False,
             )
-
             # Validate hidden_states output
             if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
                 raise RuntimeError(
                     "Text encoder did not return hidden_states. "
                     "Ensure output_hidden_states=True is supported by this model."
                 )
-
             num_hidden_layers = len(outputs.hidden_states)

             # Extract and stack hidden states - EXACTLY like diffusers:
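
For context, the pattern this commit relies on is general PyTorch/Transformers practice: ask the loaded module which device it lives on and move the tokenizer outputs there, instead of choosing a device before the model has actually been placed. The sketch below illustrates that pattern outside of InvokeAI. It is a minimal example under stated assumptions: the model name ("Qwen/Qwen3-0.6B"), the encode_prompt helper, and the plain from_pretrained loading are illustrative stand-ins for the node's model-manager loading and are not taken from this commit.

import torch
from transformers import AutoModel, AutoTokenizer

# Illustrative stand-in for the Klein/Qwen3 text encoder; the real node loads
# the encoder and tokenizer through InvokeAI's model manager instead.
MODEL_NAME = "Qwen/Qwen3-0.6B"


def encode_prompt(prompt: str, max_seq_len: int = 512) -> torch.Tensor:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    text_encoder = AutoModel.from_pretrained(MODEL_NAME)
    if torch.cuda.is_available():
        text_encoder = text_encoder.to("cuda")

    # Core idea of the fix: read the device from the loaded encoder rather
    # than choosing one up front. The two can disagree when a model manager
    # or offloading policy decides where the weights actually end up.
    device = text_encoder.device

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_seq_len)
    # Move the tokenizer outputs onto the encoder's device before the forward
    # pass; otherwise PyTorch raises a device-mismatch RuntimeError on CUDA systems.
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
    # Return the final hidden state; the actual node stacks hidden states to
    # match diffusers' behavior.
    return outputs.hidden_states[-1]

The commit applies the same ordering inside the node: device is read from text_encoder only after model_on_device() has placed it, and input_ids / attention_mask are moved with .to(device) before the forward pass.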
