Skip to content

Inconsistent Stable Diffusion results when the batch size changes #11016

@clarken92

Description

@clarken92

Describe the bug

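I observe inconsistent UNet predictions from Stable Diffusion (v1.4, v2.0-base) for the same inputs when the batch size changes. Re-running an identical batch gives bit-identical outputs. However, duplicating the batch, padding it with extra rows, or permuting its rows changes the outputs for the original samples (max absolute difference up to ~3e-3 in the log). See the code and log below. I am not sure whether this is a bug or expected numerical behavior of the model.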

from typing import List
from six import iteritems

import torch
from torch import Tensor

from diffusers import StableDiffusionPipeline


# Target device and model checkpoint; alternatives are kept commented out as toggles.
device = "cuda"
# device = "cpu"

repo = "CompVis/stable-diffusion-v1-4"
# repo = "sd-legacy/stable-diffusion-v1-5"
# repo = "stabilityai/stable-diffusion-2-base"

# Wide print lines so full tensors are not wrapped in the log.
torch.set_printoptions(linewidth=10000)

# ------------------------------------------ #
# Load the pipeline in full fp32 precision and move it to the target device.
pipeline = StableDiffusionPipeline.from_pretrained(
    repo, torch_dtype=torch.float32).to(device)

# Put each sub-model into eval mode and freeze its parameters
# (same order as before: VAE, text encoder, UNet).
for _submodel in (pipeline.vae, pipeline.text_encoder, pipeline.unet):
    _submodel.eval()
    _submodel.requires_grad_(False)
del _submodel

# pipeline.unet.enable_xformers_memory_efficient_attention()
# ------------------------------------------ #


# ------------------------------------------ #
def net_fn(
    x_t: Tensor, t: Tensor,
    text_embeds: Tensor,
    neg_text_embeds: Tensor,
    guidance_scale: float = 7.5
):
    """Run ``pipeline.unet`` on ``x_t`` at timesteps ``t`` and return its output.

    When ``guidance_scale != 1`` the batch is doubled: the first half is
    conditioned on ``text_embeds`` and the second half on ``neg_text_embeds``.
    Both halves go through a single UNet call, and the result is combined as
    ``neg_out + guidance_scale * (pos_out - neg_out)``. When
    ``guidance_scale == 1`` only ``text_embeds`` is used and the raw UNet output
    is returned.

    NOTE(review): assumes ``x_t``, ``t`` and both embedding tensors share the
    same leading batch dimension (the caller builds them that way) — confirm
    for other callers.
    """
    use_guidance = guidance_scale != 1

    if use_guidance:
        # Positive-conditioned rows first, negative-conditioned rows second;
        # the chunk(2) below relies on this ordering.
        unet_latents = torch.cat((x_t, x_t), dim=0)
        unet_steps = torch.cat((t, t), dim=0)
        unet_context = torch.cat((text_embeds, neg_text_embeds), dim=0)
    else:
        unet_latents, unet_steps, unet_context = x_t, t, text_embeds

    raw_out = pipeline.unet(
        unet_latents,
        unet_steps,
        encoder_hidden_states=unet_context,
        timestep_cond=None,
        cross_attention_kwargs=None,
        added_cond_kwargs=None,
        return_dict=False,
    )[0]

    if not use_guidance:
        return raw_out

    pos_out, neg_out = raw_out.chunk(2)
    return neg_out + guidance_scale * (pos_out - neg_out)


def encode_text(
    prompts: List[str],
):
    """Encode ``prompts`` with ``pipeline.encode_prompt`` for classifier-free guidance.

    Requests one image per prompt, passes no explicit negative prompt and uses no
    LoRA scale or CLIP skip. Returns the two tensors produced by
    ``encode_prompt`` as a ``(text_embeds, neg_text_embeds)`` tuple.
    """
    prompt_options = dict(
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=None,
        lora_scale=None,
        clip_skip=None,
    )
    # Unpack explicitly so an unexpected number of outputs fails loudly here.
    pos_embeds, neg_embeds = pipeline.encode_prompt(prompt=prompts, **prompt_options)
    return pos_embeds, neg_embeds
# ------------------------------------------ #


# batch_size = 3
# batch_size = 7
batch_size = 20
# batch_size = 30
# batch_size = 50


def _report_diff(label: str, reference: Tensor, other: Tensor) -> None:
    """Print how far ``other`` deviates from the baseline prediction ``reference``.

    Prints a blank line, then the maximum absolute element-wise difference, then
    the summed absolute difference divided by the number of samples in
    ``reference`` (the mean per-sample L1 distance, labelled "total diff").
    """
    abs_diff = (reference - other).abs()
    print()
    print(f"max diff between 1 and {label}: {abs_diff.max()}")
    # len(reference) == batch_size: the baseline always holds exactly one batch.
    print(f"total diff between 1 and {label}: {abs_diff.sum() / len(reference)}")


# Every experiment below pushes the *same* `batch_size` samples through the UNet,
# embedded in a batch of different size or order, and compares the predictions
# for those samples against the baseline from experiment 1.
# NOTE(review): diffs of this magnitude that depend only on batch composition are
# commonly caused by batch-size-dependent GPU kernel selection (cuBLAS/cuDNN
# algorithms, TF32 matmul/conv) rather than by the model itself. Confirm by
# rerunning with TF32 disabled and `torch.use_deterministic_algorithms(True)`.
with torch.no_grad():
    # ------------------------------------------ #
    selected_prompts = ["a green colored rabbit"]
    text_embeds, neg_text_embeds = encode_text(selected_prompts)

    print(f"text_embed.shape: {text_embeds.shape}")
    print(f"neg_text_embed.shape: {neg_text_embeds.shape}")

    # A single prompt, repeated so every sample in the batch shares the same embeddings.
    net_kwargs = {
        'text_embeds': text_embeds.repeat(batch_size, 1, 1),
        'neg_text_embeds': neg_text_embeds.repeat(batch_size, 1, 1),
    }
    # ------------------------------------------ #

    # 1: baseline prediction at timestep 999.
    x_1 = torch.randn([batch_size, 4, 64, 64], device=device, dtype=torch.float32)
    t_1 = 999 * torch.ones([batch_size], dtype=torch.long, device=device)
    pred_noise = net_fn(x_1, t_1, **net_kwargs)

    # 2: identical call repeated; checks run-to-run determinism at a fixed shape.
    pred_noise_2 = net_fn(x_1, t_1, **net_kwargs)
    _report_diff("2", pred_noise, pred_noise_2)

    # 3: whole batch duplicated (2 * batch_size rows); compare the first copy.
    x_1_dup = torch.cat([x_1, x_1], dim=0)
    t_1_dup = torch.cat([t_1, t_1], dim=0)
    net_kwargs_dup = {key: torch.cat([val, val], dim=0) for key, val in net_kwargs.items()}
    pred_noise_dup = net_fn(x_1_dup, t_1_dup, **net_kwargs_dup)
    pred_noise_3 = pred_noise_dup[:len(x_1)]
    _report_diff("3", pred_noise, pred_noise_3)

    # 4: one extra, deliberately extreme row appended (x = 10000, t = 0).
    x_1_inc1 = torch.cat([x_1, 10000 * torch.ones_like(x_1[:1])], dim=0)
    t_1_inc1 = torch.cat([t_1, torch.zeros_like(t_1)[:1]], dim=0)
    net_kwargs_inc1 = {key: torch.cat([val, val[:1]], dim=0) for key, val in net_kwargs.items()}
    pred_noise_inc1 = net_fn(x_1_inc1, t_1_inc1, **net_kwargs_inc1)
    pred_noise_4 = pred_noise_inc1[:len(x_1)]
    _report_diff("4", pred_noise, pred_noise_4)

    # 5: 2 * batch_size auxiliary rows (x = 1, t = 0) appended after the real samples.
    z_1 = torch.ones([2 * batch_size, *x_1.shape[1:]], dtype=x_1.dtype, device=x_1.device)
    u_1 = torch.zeros([2 * batch_size, *t_1.shape[1:]], dtype=t_1.dtype, device=t_1.device)
    x_1_aux = torch.cat([x_1, z_1], dim=0)
    t_1_aux = torch.cat([t_1, u_1], dim=0)
    net_kwargs_aux = {key: torch.cat([val, val.repeat(2, 1, 1)], dim=0) for key, val in net_kwargs.items()}
    pred_noise_aux = net_fn(x_1_aux, t_1_aux, **net_kwargs_aux)
    pred_noise_5 = pred_noise_aux[:len(x_1)]
    _report_diff("5", pred_noise, pred_noise_5)

    # 6: same batch size, rows randomly permuted, then un-permuted via the inverse index.
    rand_ids = torch.randperm(batch_size, dtype=torch.long, device=device)
    inv_ids = torch.argsort(rand_ids)
    net_kwargs_perm = {key: val[rand_ids] for key, val in net_kwargs.items()}
    pred_noise_perm = net_fn(x_1[rand_ids], t_1[rand_ids], **net_kwargs_perm)
    pred_noise_6 = pred_noise_perm[inv_ids]
    _report_diff("6", pred_noise, pred_noise_6)

Log

text_embed.shape: torch.Size([1, 77, 768])
neg_text_embed.shape: torch.Size([1, 77, 768])

max diff between 1 and 2: 0.0
total diff between 1 and 2: 0.0

max diff between 1 and 3: 0.0025451183319091797
total diff between 1 and 3: 5.266808986663818

max diff between 1 and 4: 0.0027496814727783203
total diff between 1 and 4: 5.184873104095459

max diff between 1 and 5: 0.0030808448791503906
total diff between 1 and 5: 5.262452602386475

max diff between 1 and 6: 0.0015063285827636719
total diff between 1 and 6: 0.30538126826286316

System Info

GPU: H100 80GB
torch: 2.5.1+cu118
torchvision: 0.20.1+cu118
diffusers: 0.32.2
transformers: 4.48.1
tokenizers: 0.21.0

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), stale (Issues that haven't received updates)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions