Commit 489b194

HuangYuChuh and claude authored
Fix CPU/CUDA device mismatch in Klein edit control image encoding (#742)
When training Klein models with a `control_path` (edit/kontext-style paired datasets), `encode_image_refs()` returns tensors that reside on the VAE's device (CPU, since the VAE weights are loaded via `load_file(..., device="cpu")` and are never explicitly moved to the training device). Concatenating those CPU tensors with the training latents (`packed_latents`), which live on CUDA, raises:

    RuntimeError: Expected all tensors to be on the same device

Fix: move `img_cond_seq` and `img_cond_seq_ids` to the same device (and dtype) as `img_input` / `img_input_ids` before concatenation.

Co-authored-by: HuangYuChuh <HuangYuChuh@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 89d2090 commit 489b194

File tree

1 file changed: +2 −2 lines


extensions_built_in/diffusion_models/flux2/flux2_model.py

Lines changed: 2 additions & 2 deletions
@@ -412,8 +412,8 @@ def get_noise_prediction(
             assert img_cond_seq_ids is not None, (
                 "You need to provide either both or neither of the sequence conditioning"
             )
-            img_input = torch.cat((img_input, img_cond_seq), dim=1)
-            img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
+            img_input = torch.cat((img_input, img_cond_seq.to(img_input.device, img_input.dtype)), dim=1)
+            img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids.to(img_input_ids.device)), dim=1)

             guidance_vec = torch.full(
                 (img_input.shape[0],),
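The device mismatch and its fix can be sketched in isolation (a minimal illustration, not the repository's code; the function and variable names below are hypothetical):

```python
import torch

def concat_with_cond(img_input: torch.Tensor, img_cond_seq: torch.Tensor) -> torch.Tensor:
    # torch.cat requires all inputs on the same device, so move the
    # conditioning sequence to the device and dtype of the training
    # latents before concatenating along the sequence dimension.
    img_cond_seq = img_cond_seq.to(img_input.device, img_input.dtype)
    return torch.cat((img_input, img_cond_seq), dim=1)

device = "cuda" if torch.cuda.is_available() else "cpu"
latents = torch.randn(2, 16, 64, device=device)               # training latents
cond = torch.randn(2, 16, 64, device="cpu", dtype=torch.float16)  # e.g. from a CPU-resident VAE
out = concat_with_cond(latents, cond)
print(out.shape)  # torch.Size([2, 32, 64])
```

Without the `.to(...)` call, running this with `latents` on CUDA and `cond` on CPU reproduces the `RuntimeError` described in the commit message.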
