Skip to content

Commit 5bacc2f

Browse files
[SAG] Support more schedulers, add better error message and make tests faster (#6465)
* finish * finish --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 6ae7e81 commit 5bacc2f

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,11 @@ def __call__(
681681
self.scheduler.set_timesteps(num_inference_steps, device=device)
682682
timesteps = self.scheduler.timesteps
683683

684+
if timesteps.dtype not in [torch.int16, torch.int32, torch.int64]:
685+
raise ValueError(
686+
f"{self.__class__.__name__} does not support using a scheduler of type {self.scheduler.__class__.__name__}. Please make sure to use one of 'DDIMScheduler, PNDMScheduler, DDPMScheduler, DEISMultistepScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler'."
687+
)
688+
684689
# 5. Prepare latent variables
685690
num_channels_latents = self.unet.config.in_channels
686691
latents = self.prepare_latents(
@@ -830,14 +835,14 @@ def sag_masking(self, original_latents, attn_map, map_size, t, eps):
830835
degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
831836

832837
# Noise it again to match the noise level
833-
degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t)
838+
degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t[None])
834839

835840
return degraded_latents
836841

837842
# Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
838843
# Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.)
839844
def pred_x0(self, sample, model_output, timestep):
840-
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
845+
alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device)
841846

842847
beta_prod_t = 1 - alpha_prod_t
843848
if self.scheduler.config.prediction_type == "epsilon":

tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
from diffusers import (
2424
AutoencoderKL,
2525
DDIMScheduler,
26+
DEISMultistepScheduler,
27+
DPMSolverMultistepScheduler,
28+
EulerDiscreteScheduler,
2629
StableDiffusionSAGPipeline,
2730
UNet2DConditionModel,
2831
)
@@ -45,14 +48,15 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes
4548
def get_dummy_components(self):
4649
torch.manual_seed(0)
4750
unet = UNet2DConditionModel(
48-
block_out_channels=(32, 64),
51+
block_out_channels=(4, 8),
4952
layers_per_block=2,
50-
sample_size=32,
53+
sample_size=8,
54+
norm_num_groups=1,
5155
in_channels=4,
5256
out_channels=4,
5357
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
5458
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
55-
cross_attention_dim=32,
59+
cross_attention_dim=8,
5660
)
5761
scheduler = DDIMScheduler(
5862
beta_start=0.00085,
@@ -63,7 +67,8 @@ def get_dummy_components(self):
6367
)
6468
torch.manual_seed(0)
6569
vae = AutoencoderKL(
66-
block_out_channels=[32, 64],
70+
block_out_channels=[4, 8],
71+
norm_num_groups=1,
6772
in_channels=3,
6873
out_channels=3,
6974
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
@@ -74,11 +79,11 @@ def get_dummy_components(self):
7479
text_encoder_config = CLIPTextConfig(
7580
bos_token_id=0,
7681
eos_token_id=2,
77-
hidden_size=32,
82+
hidden_size=8,
83+
num_hidden_layers=2,
7884
intermediate_size=37,
7985
layer_norm_eps=1e-05,
8086
num_attention_heads=4,
81-
num_hidden_layers=5,
8287
pad_token_id=1,
8388
vocab_size=1000,
8489
)
@@ -108,13 +113,35 @@ def get_dummy_inputs(self, device, seed=0):
108113
"num_inference_steps": 2,
109114
"guidance_scale": 1.0,
110115
"sag_scale": 1.0,
111-
"output_type": "numpy",
116+
"output_type": "np",
112117
}
113118
return inputs
114119

115120
def test_inference_batch_single_identical(self):
    """Check that batched inference matches single-sample inference.

    Delegates to the shared mixin implementation with a loosened tolerance.
    NOTE(review): 3e-3 is looser than the mixin default — presumably to absorb
    extra numerical drift from SAG's attention perturbation; confirm if tightened.
    """
    super().test_inference_batch_single_identical(expected_max_diff=3e-3)
117122

123+
@unittest.skip("Not necessary to test here.")
def test_xformers_attention_forwardGenerator_pass(self):
    """Intentionally skipped: the shared xformers test adds no coverage for SAG."""
    pass
126+
127+
def test_pipeline_different_schedulers(self):
    """Run the SAG pipeline under several compatible schedulers and check the
    incompatible one is rejected.

    Schedulers that emit integer timesteps should all produce an image of the
    expected shape; EulerDiscreteScheduler (float timesteps) must trigger the
    pipeline's ValueError guard.
    """
    pipe = self.pipeline_class(**self.get_dummy_components())
    inputs = self.get_dummy_inputs("cpu")

    expected_shape = (16, 16, 3)
    # All of these produce integer timesteps, which SAG supports.
    for scheduler_class in (DDIMScheduler, DEISMultistepScheduler, DPMSolverMultistepScheduler):
        pipe.scheduler = scheduler_class.from_config(pipe.scheduler.config)
        output_image = pipe(**inputs).images[0]
        assert output_image.shape == expected_shape

    # EulerDiscreteScheduler yields non-integer timesteps, so the pipeline
    # is expected to reject it with a ValueError.
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    with self.assertRaises(ValueError):
        pipe(**inputs).images[0]
144+
118145

119146
@nightly
120147
@require_torch_gpu

0 commit comments

Comments
 (0)