- 
                Notifications
    
You must be signed in to change notification settings  - Fork 6.5k
 
Description
Describe the bug
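I observed inconsistent results when running Stable Diffusion (v1.4, v2.0-base) multiple times with different batch sizes. The UNet's noise prediction for the same sample changes depending on the other samples in the batch and on their order. Please see the code and log below. I am not sure whether this is a bug or an inherent property of the model.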
from typing import List
from six import iteritems
import torch
from torch import Tensor
from diffusers import StableDiffusionPipeline
device = "cuda"
# device = "cpu"
repo = "CompVis/stable-diffusion-v1-4"
# repo = "sd-legacy/stable-diffusion-v1-5"
# repo = "stabilityai/stable-diffusion-2-base"
torch.set_printoptions(linewidth=10000)
# The experiment below compares float32 predictions bit-for-bit. By default,
# PyTorch allows cuDNN to run float32 convolutions in TF32 on Ampere-and-newer
# GPUs (e.g. the H100 used here), and TF32 keeps far fewer mantissa bits than
# float32. That rounding is a likely contributor to the ~1e-3 differences in
# the log, so pin full float32 precision for cuDNN convolutions and for CUDA
# matmuls. Remove these two lines to go back to PyTorch's default behaviour.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
# ------------------------------------------ #
pipeline = StableDiffusionPipeline.from_pretrained(
    repo, torch_dtype=torch.float32)
pipeline = pipeline.to(device)
# Inference only: switch every sub-model to eval mode and freeze its weights.
for module in (pipeline.vae, pipeline.text_encoder, pipeline.unet):
    module.eval()
    module.requires_grad_(False)
# pipeline.unet.enable_xformers_memory_efficient_attention()
# ------------------------------------------ #
# ------------------------------------------ #
def net_fn(
    x_t: Tensor, t: Tensor,
    text_embeds: Tensor,
    neg_text_embeds: Tensor,
    guidance_scale: float = 7.5
):
    """Run the global pipeline's UNet on ``x_t`` at timesteps ``t``.

    If ``guidance_scale`` differs from 1, classifier-free guidance is applied.
    The inputs are duplicated so the conditional (``text_embeds``) and
    unconditional (``neg_text_embeds``) passes share one UNet call. The two
    halves of the output are then combined as ``null + scale * (cond - null)``.
    If ``guidance_scale`` equals 1, only the conditional pass runs and
    ``neg_text_embeds`` is ignored.

    Returns the first element of the UNet's tuple output, after guidance if
    guidance was applied.
    """
    use_cfg = guidance_scale != 1
    if use_cfg:
        # First half: conditional branch; second half: unconditional branch.
        latents = torch.cat([x_t, x_t], dim=0)
        timesteps = torch.cat([t, t], dim=0)
        context = torch.cat([text_embeds, neg_text_embeds], dim=0)
    else:
        latents, timesteps, context = x_t, t, text_embeds
    unet_out = pipeline.unet(
        latents, timesteps,
        encoder_hidden_states=context,
        timestep_cond=None,
        cross_attention_kwargs=None,
        added_cond_kwargs=None,
        return_dict=False,
    )[0]
    if not use_cfg:
        return unet_out
    cond_out, null_out = unet_out.chunk(2)
    return null_out + guidance_scale * (cond_out - null_out)
def encode_text(
    prompts: List[str],
):
    """Encode ``prompts`` with the global pipeline's text encoder.

    Calls ``pipeline.encode_prompt`` on ``device`` with classifier-free
    guidance enabled and one image per prompt. It returns
    ``(text_embeds, neg_text_embeds)``. ``negative_prompt`` is left as None,
    so the pipeline chooses the unconditional embedding itself (presumably the
    empty-prompt embedding; confirm against the diffusers version in use).
    For one prompt on SD v1.4, the log above shows both tensors with shape
    ``torch.Size([1, 77, 768])``.
    """
    encode_kwargs = dict(
        prompt=prompts,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        lora_scale=None,
        clip_skip=None,
    )
    cond_embeds, uncond_embeds = pipeline.encode_prompt(**encode_kwargs)
    return cond_embeds, uncond_embeds
# ------------------------------------------ #
# batch_size = 3
# batch_size = 7
batch_size = 20
# batch_size = 30
# batch_size = 50
# Seed for the random latents. Every comparison below reuses the same x_1,
# but seeding also makes the printed numbers reproducible from run to run.
seed = 0


def _report_diff(tag: str, reference: Tensor, other: Tensor) -> None:
    """Print how far ``other`` deviates from the reference prediction.

    Prints two values. The first is the largest element-wise absolute
    difference. The second is the summed absolute difference divided by the
    number of samples in ``reference``, i.e. its first dimension.
    """
    abs_diff = (reference - other).abs()
    print()
    print(f"max diff between 1 and {tag}: {abs_diff.max()}")
    print(f"total diff between 1 and {tag}: {abs_diff.sum() / len(reference)}")


# NOTE: GPU libraries generally do not promise bit-identical results across
# different batch shapes. A different batch size (or layout) may select a
# different convolution/matmul algorithm and hence a different floating-point
# accumulation order. Small deviations in experiments 3-6 therefore do not by
# themselves indicate a model bug.
with torch.no_grad():
    # -------------------------------------------- #
    # Shared text conditioning: one prompt, repeated across the batch.
    # -------------------------------------------- #
    selected_prompts = ["a green colored rabbit"]
    text_embeds, neg_text_embeds = encode_text(selected_prompts)
    print(f"text_embed.shape: {text_embeds.shape}")
    print(f"neg_text_embed.shape: {neg_text_embeds.shape}")
    net_kwargs = {
        'text_embeds': text_embeds.repeat(batch_size, 1, 1),
        'neg_text_embeds': neg_text_embeds.repeat(batch_size, 1, 1),
    }
    # -------------------------------------------- #
    # 1: reference prediction for `batch_size` random latents at t = 999.
    # -------------------------------------------- #
    torch.manual_seed(seed)
    x_1 = torch.randn([batch_size, 4, 64, 64], device=device, dtype=torch.float32)
    t_1 = 999 * torch.ones([batch_size], dtype=torch.long, device=device)
    pred_noise = net_fn(x_1, t_1, **net_kwargs)
    # -------------------------------------------- #
    # 2: the identical call again (same inputs, same batch size).
    # -------------------------------------------- #
    pred_noise_2 = net_fn(x_1, t_1, **net_kwargs)
    _report_diff("2", pred_noise, pred_noise_2)
    # -------------------------------------------- #
    # 3: whole batch duplicated (2 * batch_size); compare the first copy.
    # -------------------------------------------- #
    x_1_dup = torch.cat([x_1, x_1], dim=0)
    t_1_dup = torch.cat([t_1, t_1], dim=0)
    net_kwargs_dup = {
        key: torch.cat([val, val], dim=0) for key, val in net_kwargs.items()
    }
    pred_noise_dup = net_fn(x_1_dup, t_1_dup, **net_kwargs_dup)
    pred_noise_3 = pred_noise_dup[:len(x_1)]
    _report_diff("3", pred_noise, pred_noise_3)
    # -------------------------------------------- #
    # 4: one outlier sample (x = 10000, t = 0) appended (batch_size + 1).
    # -------------------------------------------- #
    x_1_inc1 = torch.cat([x_1, 10000 * torch.ones_like(x_1[:1])], dim=0)
    t_1_inc1 = torch.cat([t_1, torch.zeros_like(t_1)[:1]], dim=0)
    net_kwargs_inc1 = {
        key: torch.cat([val, val[:1]], dim=0) for key, val in net_kwargs.items()
    }
    pred_noise_inc1 = net_fn(x_1_inc1, t_1_inc1, **net_kwargs_inc1)
    pred_noise_4 = pred_noise_inc1[:len(x_1)]
    _report_diff("4", pred_noise, pred_noise_4)
    # -------------------------------------------- #
    # 5: 2 * batch_size auxiliary samples (x = 1, t = 0) appended.
    # -------------------------------------------- #
    z_1 = torch.ones([2 * batch_size, *x_1.shape[1:]], dtype=x_1.dtype, device=x_1.device)
    u_1 = torch.zeros([2 * batch_size, *t_1.shape[1:]], dtype=t_1.dtype, device=t_1.device)
    x_1_aux = torch.cat([x_1, z_1], dim=0)
    t_1_aux = torch.cat([t_1, u_1], dim=0)
    net_kwargs_aux = {
        key: torch.cat([val, val.repeat(2, 1, 1)], dim=0)
        for key, val in net_kwargs.items()
    }
    pred_noise_aux = net_fn(x_1_aux, t_1_aux, **net_kwargs_aux)
    pred_noise_5 = pred_noise_aux[:len(x_1)]
    _report_diff("5", pred_noise, pred_noise_5)
    # -------------------------------------------- #
    # 6: same batch size, shuffled sample order; undo the shuffle to compare.
    # -------------------------------------------- #
    rand_ids = torch.randperm(batch_size, dtype=torch.long, device=device)
    inv_ids = torch.argsort(rand_ids)
    net_kwargs_perm = {key: val[rand_ids] for key, val in net_kwargs.items()}
    pred_noise_perm = net_fn(x_1[rand_ids], t_1[rand_ids], **net_kwargs_perm)
    pred_noise_6 = pred_noise_perm[inv_ids]
    _report_diff("6", pred_noise, pred_noise_6)
    # -------------------------------------------- #
Log
text_embed.shape: torch.Size([1, 77, 768])
neg_text_embed.shape: torch.Size([1, 77, 768])
max diff between 1 and 2: 0.0
total diff between 1 and 2: 0.0
max diff between 1 and 3: 0.0025451183319091797
total diff between 1 and 3: 5.266808986663818
max diff between 1 and 4: 0.0027496814727783203
total diff between 1 and 4: 5.184873104095459
max diff between 1 and 5: 0.0030808448791503906
total diff between 1 and 5: 5.262452602386475
max diff between 1 and 6: 0.0015063285827636719
total diff between 1 and 6: 0.30538126826286316
System Info
GPU: H100 80GB
torch: 2.5.1+cu118
torchvision: 0.20.1+cu118
diffusers: 0.32.2
transformers: 4.48.1
tokenizers: 0.21.0