
A strange thing happened when I wrote my own code to train ControlNet SDXL: as soon as I did the first backpropagation, noise_pred became NaN. #9422

@Li-Zn-H

Description

Describe the bug

A strange thing happened when I wrote my own code to train ControlNet: as soon as I did the first backpropagation, noise_pred became NaN. I have done a lot of debugging and tried gradient clipping, mixed-precision training, removing the EMA, and other changes, but the result is always NaN once backpropagation is applied.
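
A minimal sketch of this kind of NaN check (illustrative only; the helper name report_non_finite is made up, and controlnet / model_pred refer to the reproduction below):

import torch

def report_non_finite(name, tensor):
    # print which tensor contains NaN/Inf values
    if tensor is not None and not torch.isfinite(tensor).all():
        print(f"non-finite values in {name}")

# torch.autograd.set_detect_anomaly(True)  # optionally raise at the op that produced the NaN
# after the backward() call in the training loop:
# report_non_finite("model_pred", model_pred)
# for n, p in controlnet.named_parameters():
#     report_non_finite(f"grad:{n}", p.grad)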

Reproduction

My model and dataset setup:

import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from torch_ema import ExponentialMovingAverage  # assuming the torch_ema package for EMA
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel, DDIMScheduler

tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder='tokenizer')
tokenizer_2 = CLIPTokenizer.from_pretrained(sd_path, subfolder='tokenizer_2')
text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder='text_encoder', torch_dtype=torch.float16).to(device)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(sd_path, subfolder='text_encoder_2', torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("/data2/lixq22/models/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder='unet', torch_dtype=torch.float16).to(device)
controlnet = ControlNetModel.from_pretrained(cn_path, torch_dtype=torch.float16).to(device)
scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
unet.requires_grad_(False)
vae.requires_grad_(False)
controlnet.train()

controlnet = DDP(controlnet, device_ids=[rank])
optimizer = torch.optim.AdamW(controlnet.parameters(), lr=1e-5, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-8)
dataset = MyDataset()

train_sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

train_loader = DataLoader(
    dataset, 
    batch_size=2, 
    collate_fn=collate_fn, 
    sampler=train_sampler, 
    num_workers=1, 
    pin_memory=True
)
lr_scheduler = OneCycleLR(optimizer, max_lr=1e-4, total_steps=total_steps, pct_start=0.1, anneal_strategy='cos')
ema = ExponentialMovingAverage(controlnet.parameters(), decay=0.995)

# my forward and backpropagation code
for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)
    controlnet.train()
    optimizer.zero_grad(set_to_none=True)
    epoch_loss = 0
    tokenizers = [tokenizer, tokenizer_2]
    text_encoders = [text_encoder, text_encoder_2]
    for i, data in enumerate(train_loader):
        data = {k: (v.to(device).to(torch.float16) if isinstance(v, torch.Tensor) else v) for k, v in data.items()}
        with torch.no_grad():
            prompt_embeds_list = []
            for prompt, tokenizer, text_encoder in zip([data['prompts'], data['prompts']], tokenizers, text_encoders):
                text_input_ids = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids

                prompt_embeds = text_encoder(
                    text_input_ids.to(device),
                    output_hidden_states=True)
                pooled_prompt_embeds = prompt_embeds[0]  # [b,1280]
                prompt_embeds = prompt_embeds.hidden_states[-2]  # penultimate hidden states
                prompt_embeds_list.append(prompt_embeds)
            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)  #  [b,77,2048]

            add_time_ids = list((1024, 1024) + (0, 0) + (1024, 1024))
            add_time_ids = torch.tensor([add_time_ids], dtype=torch.float16).to(device).repeat(len(data['prompts']), 1)  # [b,6]

            latents = vae.encode(data['pixel_values']).latent_dist.sample() * 0.18215   # [b,4,128,128]
            controlnet_image = data['conditioning_pixel_values']   # [b,3,1024,1024]
            bsz = len(latents)
            timesteps = torch.randint(0, 1000, (bsz,), device=latents.device).long()
        noise = torch.randn_like(latents)  # [b,4,128,128]
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)  # [b,4,128,128]

        down_block_res_samples, mid_block_res_sample = controlnet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
            controlnet_cond=controlnet_image,
            return_dict=False,
        )
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            down_block_additional_residuals=[sample for sample in down_block_res_samples],
            mid_block_additional_residual=mid_block_res_sample,
            added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
            return_dict=False,
        )[0]

        # assuming the standard epsilon-prediction MSE loss between the UNet output and the added noise
        mse_loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
        mse_loss.backward()

        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(controlnet.parameters(), max_norm=1.0)
            optimizer.step()
            ema.update(controlnet.parameters())
            lr_scheduler.step()
            optimizer.zero_grad()
            if rank == 0:
                current_lr = lr_scheduler.get_last_lr()[0]
                writer.add_scalar('Loss/train', mse_loss.item(), global_step)
                writer.add_scalar('Learning Rate', current_lr, global_step)
                writer.flush()
            global_step += 1
            ema.copy_to(controlnet.parameters())
        epoch_loss += mse_loss.item()
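
For comparison with the setup above, where every model including the trainable ControlNet is cast to float16, here is a minimal self-contained sketch of the usual PyTorch mixed-precision recipe (toy nn.Linear model, not the ControlNet pipeline): the trainable weights stay in float32, the forward pass runs under autocast, and a GradScaler rescales gradients before clipping and stepping.

import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda"
model = nn.Linear(16, 16).to(device)                  # trainable weights stay in fp32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 16, device=device)
target = torch.randn(4, 16, device=device)

with torch.autocast(device_type="cuda", dtype=torch.float16):
    pred = model(x)                                   # fp16 compute, fp32 master weights
    loss = F.mse_loss(pred.float(), target.float())

scaler.scale(loss).backward()                         # scale the loss so fp16 grads don't underflow
scaler.unscale_(optimizer)                            # unscale before clipping the real gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)                                # skips the update if grads contain inf/NaN
scaler.update()
optimizer.zero_grad(set_to_none=True)

The key difference from the snippet above is that the optimizer updates float32 parameters while only the forward/backward compute runs in float16.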

Logs

No response

System Info

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


  • 🤗 Diffusers version: 0.30.2
  • Platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.31
  • Running on Google Colab?: No
  • Python version: 3.12.2
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.8.2 (cpu)
  • Jax version: 0.4.25
  • JaxLib version: 0.4.25
  • Huggingface_hub version: 0.24.6
  • Transformers version: 4.39.3
  • Accelerate version: 0.28.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.2
  • xFormers version: not installed
  • Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response


    Labels

    bug (Something isn't working)
