Closed
Labels
bug (Something isn't working)
Description
Describe the bug
Something strange happened when I wrote my own code to train ControlNet: as soon as I ran the first backpropagation, noise_pred became NaN. I did a lot of debugging (gradient decay, mixed precision training, removing EMA and other parts), but the result was always NaN once backpropagation was applied.
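One way to narrow this down is to check, right after `backward()` and again after `optimizer.step()`, which tensors first become non-finite. The helper below is a minimal sketch of that check. `find_nonfinite` is a hypothetical name (it is not part of diffusers or PyTorch), and the tiny `Linear` model only stands in for the ControlNet.

```python
import torch


def find_nonfinite(module: torch.nn.Module) -> list[str]:
    """Return the names of parameters and gradients that contain NaN or Inf."""
    bad = []
    for name, param in module.named_parameters():
        if not torch.isfinite(param).all():
            bad.append(f"param:{name}")
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(f"grad:{name}")
    return bad


# Toy stand-in for the ControlNet (assumption, for illustration only).
model = torch.nn.Linear(4, 2)
loss = model(torch.randn(3, 4)).pow(2).mean()
loss.backward()
clean_report = find_nonfinite(model)

# Simulate a poisoned gradient to show what a hit looks like.
model.weight.grad[0, 0] = float("nan")
poisoned_report = find_nonfinite(model)
```

If the gradients are still finite after `backward()` but the parameters are not finite after `optimizer.step()`, the optimizer update itself is the more likely source than the forward pass.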
Reproduction
My model and dataset settings:
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, ControlNetModel, DDIMScheduler, UNet2DConditionModel
# ExponentialMovingAverage: its import is not shown in this report
tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder='tokenizer')
tokenizer_2 = CLIPTokenizer.from_pretrained(sd_path, subfolder='tokenizer_2')
text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder='text_encoder', torch_dtype=torch.float16).to(device)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(sd_path, subfolder='text_encoder_2', torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("/data2/lixq22/models/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder='unet', torch_dtype=torch.float16).to(device)
controlnet = ControlNetModel.from_pretrained(cn_path, torch_dtype=torch.float16).to(device)
scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
unet.requires_grad_(False)
vae.requires_grad_(False)
controlnet.train()
controlnet = DDP(controlnet, device_ids=[rank])
optimizer = torch.optim.AdamW(controlnet.parameters(), lr=1e-5, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-8)
dataset = MyDataset()
train_sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
train_loader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=collate_fn,
    sampler=train_sampler,
    num_workers=1,
    pin_memory=True,
)
lr_scheduler = OneCycleLR(optimizer, max_lr=1e-4, total_steps=total_steps, pct_start=0.1, anneal_strategy='cos')
ema = ExponentialMovingAverage(controlnet.parameters(), decay=0.995)
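The setup above loads the trainable ControlNet with `torch_dtype=torch.float16` and passes those fp16 parameters directly to AdamW with `eps=1e-8`. One possible contributor to the NaN (an assumption, not something confirmed in this report) is that `1e-8` is smaller than the smallest positive fp16 value, so the Adam denominator `sqrt(v) + eps` can round to zero once the squared-gradient estimate `v` underflows. The stdlib-only sketch below illustrates the rounding with `struct`'s half-precision format; `to_fp16` is a hypothetical helper and the gradient value is made up.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


eps = 1e-8                               # AdamW eps from the setup above
eps_fp16 = to_fp16(eps)                  # rounds to zero in half precision
eps_fp32 = struct.unpack("<f", struct.pack("<f", eps))[0]  # still non-zero in float32

grad = 1e-5                              # illustrative small gradient (assumption)
v_fp16 = to_fp16(to_fp16(grad) ** 2)     # squared gradient underflows in fp16
denom_fp16 = to_fp16(v_fp16 ** 0.5 + eps)  # sqrt(v) + eps, rounded to fp16
```

With a zero denominator the update `m / denom` evaluates to NaN or Inf, and the next forward pass then produces NaN. Keeping the trainable weights in float32 avoids this particular failure mode.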
# My forward and backpropagation code
for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)
    controlnet.train()
    optimizer.zero_grad(set_to_none=True)
    epoch_loss = 0
    tokenizers = [tokenizer, tokenizer_2]
    text_encoders = [text_encoder, text_encoder_2]
    for i, data in enumerate(train_loader):
        # Move tensors to the device and cast them to fp16; leave non-tensor entries (prompts) as-is
        data = {k: (v.to(device).to(torch.float16) if isinstance(v, torch.Tensor) else v) for k, v in data.items()}
        with torch.no_grad():
            prompt_embeds_list = []
            # Loop variables are named tok/enc so they do not shadow tokenizer/text_encoder
            for prompt, tok, enc in zip([data['prompts'], data['prompts']], tokenizers, text_encoders):
                text_input_ids = tok(
                    prompt,
                    padding="max_length",
                    max_length=tok.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids
                prompt_embeds = enc(
                    text_input_ids.to(device),
                    output_hidden_states=True)
                # Overwritten on each pass; only the value from text_encoder_2 is used below
                pooled_prompt_embeds = prompt_embeds[0]  # [b,1280]
                prompt_embeds = prompt_embeds.hidden_states[-2]
                prompt_embeds_list.append(prompt_embeds)
            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)  # [b,77,2048]
            add_time_ids = list((1024, 1024) + (0, 0) + (1024, 1024))
            add_time_ids = torch.tensor([add_time_ids], dtype=torch.float16).to(device).repeat(len(data['prompts']), 1)  # [b,6]
            latents = vae.encode(data['pixel_values']).latent_dist.sample() * 0.18215  # [b,4,128,128]
        controlnet_image = data['conditioning_pixel_values']  # [b,3,1024,1024]
        bsz = len(latents)
        timesteps = torch.randint(0, 1000, (bsz,), device=latents.device).long()
        noise = torch.randn_like(latents)  # [b,4,128,128]
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)  # [b,4,128,128]
        down_block_res_samples, mid_block_res_sample = controlnet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
            controlnet_cond=controlnet_image,
            return_dict=False,
        )
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            down_block_additional_residuals=[sample for sample in down_block_res_samples],
            mid_block_additional_residual=mid_block_res_sample,
            added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
            return_dict=False,
        )[0]
        # NOTE: the code that computes mse_loss (and weighted_loss, logged below) is not included in this report
        mse_loss.backward()
        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(controlnet.parameters(), max_norm=1.0)
            optimizer.step()
            ema.update(controlnet.parameters())
            lr_scheduler.step()
            optimizer.zero_grad()
            if rank == 0:
                current_lr = lr_scheduler.get_last_lr()[0]
                writer.add_scalar('Loss/train', weighted_loss.item(), global_step)
                writer.add_scalar('Learning Rate', current_lr, global_step)
                writer.flush()
            global_step += 1
            ema.copy_to(controlnet.parameters())
        epoch_loss += mse_loss.item()
Logs
No response
System Info
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.
- 🤗 Diffusers version: 0.30.2
- Platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.12.2
- PyTorch version (GPU?): 2.4.1+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): 0.8.2 (cpu)
- Jax version: 0.4.25
- JaxLib version: 0.4.25
- Huggingface_hub version: 0.24.6
- Transformers version: 4.39.3
- Accelerate version: 0.28.0
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.4.2
- xFormers version: not installed
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
No response