Skip to content

Commit 53116d4

Browse files
committed
Fix non-deterministic result when getting pipeline dtype and device
1 parent aad69ac commit 53116d4

File tree

2 files changed: +107 −4 lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,7 +1577,7 @@ def _get_signature_keys(cls, obj):
15771577
expected_modules.add(name)
15781578
optional_parameters.remove(name)
15791579

1580-
return expected_modules, optional_parameters
1580+
return sorted(expected_modules), sorted(optional_parameters)
15811581

15821582
@classmethod
15831583
def _get_signature_types(cls):
@@ -1619,10 +1619,12 @@ def components(self) -> Dict[str, Any]:
16191619
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
16201620
}
16211621

1622-
if set(components.keys()) != expected_modules:
1622+
actual = sorted(set(components.keys()))
1623+
expected = sorted(expected_modules)
1624+
if actual != expected:
16231625
raise ValueError(
16241626
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
1625-
f" {expected_modules} to be defined, but {components.keys()} are defined."
1627+
f" {expected} to be defined, but {actual} are defined."
16261628
)
16271629

16281630
return components

tests/pipelines/test_pipeline_utils.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
UNet2DConditionModel,
2020
)
2121
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
22-
from diffusers.utils.testing_utils import torch_device
22+
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
2323

2424

2525
class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -585,3 +585,104 @@ def test_video_to_video(self):
585585
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
586586
_ = pipe(**inputs)
587587
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
588+
589+
590+
@require_torch_gpu
class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
    """Check that ``pipe.device`` and ``pipe.dtype`` report a fixed value.

    Each test builds a miniature Stable Diffusion pipeline, places its
    sub-models on differing devices (or casts them to differing dtypes), and
    asserts that the pipeline-level property equals the class-level
    expectation declared below.
    """

    # Value ``pipe.device`` must equal after the sub-models are spread over CPU and CUDA.
    expected_pipe_device = torch.device("cuda:0")
    # Value ``pipe.dtype`` must equal after the sub-models are cast to mixed precisions.
    expected_pipe_dtype = torch.float64

    def get_dummy_components_image_generation(self):
        """Return the keyword arguments used to build a tiny ``StableDiffusionPipeline``.

        Returns:
            A dict with freshly constructed ``unet``, ``scheduler``, ``vae``,
            ``text_encoder`` and ``tokenizer`` entries; ``safety_checker``,
            ``feature_extractor`` and ``image_encoder`` are set to ``None``.
        """
        # Shared by the UNet's cross-attention width and the text encoder's hidden size.
        shared_dim = 8

        # The generator is re-seeded before the UNet, the VAE and the text
        # encoder, in this exact order, so the construction sequence is kept.
        torch.manual_seed(0)
        tiny_unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=1,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=shared_dim,
            norm_num_groups=2,
        )
        ddim = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )

        torch.manual_seed(0)
        tiny_vae = AutoencoderKL(
            block_out_channels=[4, 8],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=2,
        )

        torch.manual_seed(0)
        clip_text_model = CLIPTextModel(
            CLIPTextConfig(
                bos_token_id=0,
                eos_token_id=2,
                hidden_size=shared_dim,
                intermediate_size=16,
                layer_norm_eps=1e-05,
                num_attention_heads=2,
                num_hidden_layers=2,
                pad_token_id=1,
                vocab_size=1000,
            )
        )
        # NOTE(review): presumably fetched from the Hugging Face Hub when not cached - confirm.
        clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        return {
            "unet": tiny_unet,
            "scheduler": ddim,
            "vae": tiny_vae,
            "text_encoder": clip_text_model,
            "tokenizer": clip_tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
            "image_encoder": None,
        }

    def _build_fp32_pipeline(self):
        """Build the dummy pipeline and move it to ``torch_device`` in float32."""
        pipeline = StableDiffusionPipeline(**self.get_dummy_components_image_generation())
        pipeline.to(device=torch_device, dtype=torch.float32)
        return pipeline

    def test_deterministic_device(self):
        """``pipe.device`` equals ``expected_pipe_device`` with components on mixed devices."""
        pipe = self._build_fp32_pipeline()

        # Scatter the sub-models: UNet on CPU, VAE and text encoder on CUDA.
        placements = (
            ("unet", "cpu"),
            ("vae", "cuda"),
            ("text_encoder", "cuda:0"),
        )
        for component_name, target_device in placements:
            getattr(pipe, component_name).to(device=target_device)

        pipe_device = pipe.device

        self.assertEqual(
            self.expected_pipe_device,
            pipe_device,
            msg=f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",
        )

    def test_deterministic_dtype(self):
        """``pipe.dtype`` equals ``expected_pipe_dtype`` with components in mixed dtypes."""
        pipe = self._build_fp32_pipeline()

        # Give every sub-model a different floating-point precision.
        precisions = (
            ("unet", torch.float16),
            ("vae", torch.float32),
            ("text_encoder", torch.float64),
        )
        for component_name, target_dtype in precisions:
            getattr(pipe, component_name).to(dtype=target_dtype)

        pipe_dtype = pipe.dtype

        self.assertEqual(
            self.expected_pipe_dtype,
            pipe_dtype,
            msg=f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",
        )

0 commit comments

Comments
 (0)