Fix CPU/CUDA device mismatch when training Klein edit with control_path #742

Merged
jaretburkett merged 1 commit into ostris:main from
HuangYuChuh:fix/klein-edit-control-image-device-mismatch
Mar 25, 2026

Conversation


@HuangYuChuh (Contributor) commented Mar 20, 2026

Problem

When training Klein models (e.g. flux2_klein_9b) with a paired dataset using control_path, the training crashes with:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Fixes #652

Root Cause

In load_model() (flux2_model.py), the VAE weights are loaded via:

vae_state_dict = load_file(vae_path, device="cpu")
vae.load_state_dict(vae_state_dict, assign=True)

The VAE is never explicitly moved to the training device (unlike the transformer and text encoder which receive .to(self.device_torch) calls). As a result, self.vae lives on CPU throughout training.

When get_noise_prediction() processes a batch with control images, it calls:

img_cond_seq_item, img_cond_seq_ids_item = encode_image_refs(
    self.vae, controls, limit_pixels=control_image_max_res
)

Inside encode_image_refs (src/sampling.py), the input is moved to ae.device (CPU) for encoding and the returned tensors remain on CPU. These are then concatenated with packed_latents / img_ids which are on CUDA — causing the device mismatch crash.
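The failure mode can be sketched independently of the real models. `encode_refs` below is a hypothetical stand-in for `encode_image_refs` (not its actual signature); the point is only that the output follows the encoder's device, not the caller's:

```python
import torch

def encode_refs(ae: torch.nn.Module, img: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: input is moved to the autoencoder's device,
    so the result lives wherever the autoencoder lives."""
    ae_device = next(ae.parameters()).device
    return ae(img.to(ae_device))

ae = torch.nn.Linear(8, 8)              # never moved, like the VAE here
packed_latents = torch.randn(1, 6, 8)   # on CUDA during real training

cond = encode_refs(ae, torch.randn(1, 2, 8))
assert cond.device.type == "cpu"        # output follows the AE's device

# On a CUDA machine, concatenating the CUDA latents with the CPU `cond`
# raises: RuntimeError: Expected all tensors to be on the same device ...
if torch.cuda.is_available():
    try:
        torch.cat((packed_latents.cuda(), cond), dim=1)
    except RuntimeError as e:
        print(e)
```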

Fix

Move img_cond_seq and img_cond_seq_ids to the same device and dtype as img_input before concatenation:

# Before
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)

# After
img_input = torch.cat((img_input, img_cond_seq.to(img_input.device, img_input.dtype)), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids.to(img_input_ids.device)), dim=1)

This is a minimal, targeted fix at the concatenation point. It handles both the device mismatch (CPU→CUDA) and the dtype cast in one step, without changing VAE loading behaviour or affecting other code paths.
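The pattern can be exercised in isolation with synthetic shapes (the real tensors are packed sequence latents; here only the dtype differs so the sketch also runs on a CPU-only machine):

```python
import torch

# Mimics the packed training latents (bf16) and their position ids.
img_input = torch.randn(1, 6, 8, dtype=torch.bfloat16)
img_input_ids = torch.zeros(1, 6, 3, dtype=torch.int64)

# Control-image tensors as returned from the CPU-resident VAE (float32).
img_cond_seq = torch.randn(1, 2, 8)
img_cond_seq_ids = torch.ones(1, 2, 3, dtype=torch.int64)

# .to(device, dtype) is a no-op when nothing needs to change, so the
# cast is safe on code paths where the devices already agree.
img_input = torch.cat(
    (img_input, img_cond_seq.to(img_input.device, img_input.dtype)), dim=1
)
img_input_ids = torch.cat(
    (img_input_ids, img_cond_seq_ids.to(img_input_ids.device)), dim=1
)

print(img_input.shape, img_input.dtype)  # torch.Size([1, 8, 8]) torch.bfloat16
```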

Reproduction

Config: any Klein training config (arch: flux2_klein_9b or flux2_klein_4b) with a control_path set in the dataset section (edit / paired-image training). Training will crash at the first step when img_cond_seq is not None.
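A hypothetical dataset fragment along these lines would trigger it. Only `arch: flux2_klein_9b` and `control_path` come from this PR; the surrounding keys and paths are illustrative assumptions, so check the ai-toolkit example configs for the exact schema:

```yaml
model:
  arch: flux2_klein_9b            # or flux2_klein_4b
datasets:
  - folder_path: /data/targets    # edited/target images (illustrative path)
    control_path: /data/controls  # paired control images -> triggers the crash
```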

Commit message

When training Klein models with a `control_path` (edit/kontext-style
paired datasets), `encode_image_refs()` returns tensors that reside on
the VAE's device (CPU, since the VAE weights are loaded via
`load_file(..., device="cpu")` and are never explicitly moved to the
training device).  Concatenating those CPU tensors with the training
latents (`packed_latents`) that live on CUDA raises:

    RuntimeError: Expected all tensors to be on the same device

Fix: move `img_cond_seq` and `img_cond_seq_ids` to the same device
(and dtype) as `img_input` / `img_input_ids` before concatenation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@caiduoduo12138 commented

nice job! It also occurs when cache_latents_to_disk=true is set.

@jaretburkett jaretburkett merged commit 489b194 into ostris:main Mar 25, 2026
@jaretburkett

Thank you


Development

Successfully merging this pull request may close these issues.

Error when training flux2-klein with control_folder
