diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index f45e0a51d226..dc774d145c83 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -839,9 +839,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
- assert all(
- isinstance(tok, str) for tok in inserting_toks
- ), "All elements in inserting_toks should be strings."
+ assert all(isinstance(tok, str) for tok in inserting_toks), (
+ "All elements in inserting_toks should be strings."
+ )
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -1605,7 +1605,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 8cd1d777c00c..95ba53391cf3 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -200,7 +200,8 @@ def save_model_card(
"diffusers",
"diffusers-training",
lora,
- "template:sd-lora" "stable-diffusion",
+ "template:sd-lora",
+ "stable-diffusion",
"stable-diffusion-diffusers",
]
model_card = populate_model_card(model_card, tags=tags)
@@ -724,9 +725,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
- assert all(
- isinstance(tok, str) for tok in inserting_toks
- ), "All elements in inserting_toks should be strings."
+ assert all(isinstance(tok, str) for tok in inserting_toks), (
+ "All elements in inserting_toks should be strings."
+ )
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -746,9 +747,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
.to(dtype=self.dtype)
* std_token_embedding
)
- self.embeddings_settings[
- f"original_embeddings_{idx}"
- ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ self.embeddings_settings[f"original_embeddings_{idx}"] = (
+ text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ )
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1322,7 +1323,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index f8253715e64d..236dc20d621c 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -890,9 +890,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
- assert all(
- isinstance(tok, str) for tok in inserting_toks
- ), "All elements in inserting_toks should be strings."
+ assert all(isinstance(tok, str) for tok in inserting_toks), (
+ "All elements in inserting_toks should be strings."
+ )
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -912,9 +912,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
.to(dtype=self.dtype)
* std_token_embedding
)
- self.embeddings_settings[
- f"original_embeddings_{idx}"
- ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ self.embeddings_settings[f"original_embeddings_{idx}"] = (
+ text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ )
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1647,7 +1647,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py
index df44a0a63aeb..d71d9ccbb83e 100644
--- a/examples/amused/train_amused.py
+++ b/examples/amused/train_amused.py
@@ -720,7 +720,7 @@ def load_model_hook(models, input_dir):
# Train!
logger.info("***** Running training *****")
logger.info(f" Num training steps = {args.max_train_steps}")
- logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
index eed8305f4fbc..35d4d156225d 100644
--- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py
+++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
@@ -1138,7 +1138,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 74ea98cbac5e..bf09ff02ae38 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -1159,7 +1159,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/community/adaptive_mask_inpainting.py b/examples/community/adaptive_mask_inpainting.py
index df736956485b..81f9527b4703 100644
--- a/examples/community/adaptive_mask_inpainting.py
+++ b/examples/community/adaptive_mask_inpainting.py
@@ -1103,7 +1103,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `default_mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/hd_painter.py b/examples/community/hd_painter.py
index 91ebe076104a..9d7b95b62c6e 100644
--- a/examples/community/hd_painter.py
+++ b/examples/community/hd_painter.py
@@ -686,7 +686,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py
index 292c9aa2bc47..001e4cc5b2cf 100644
--- a/examples/community/img2img_inpainting.py
+++ b/examples/community/img2img_inpainting.py
@@ -362,7 +362,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py
index 129793dae6b0..814694f1e366 100644
--- a/examples/community/llm_grounded_diffusion.py
+++ b/examples/community/llm_grounded_diffusion.py
@@ -1120,7 +1120,7 @@ def latent_lmd_guidance(
if verbose:
logger.info(
- f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
+ f"time index {index}, loss: {loss.item() / loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
)
try:
@@ -1184,7 +1184,7 @@ def latent_lmd_guidance(
if verbose:
logger.info(
- f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
+ f"time index {index}, loss: {loss.item() / loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
)
finally:
diff --git a/examples/community/mod_controlnet_tile_sr_sdxl.py b/examples/community/mod_controlnet_tile_sr_sdxl.py
index 80bed2365d9f..3db2645a78a7 100644
--- a/examples/community/mod_controlnet_tile_sr_sdxl.py
+++ b/examples/community/mod_controlnet_tile_sr_sdxl.py
@@ -701,7 +701,7 @@ def check_inputs(
raise ValueError("`max_tile_size` cannot be None.")
elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280):
raise ValueError(
- f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type" f" {type(max_tile_size)}."
+ f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type {type(max_tile_size)}."
)
if tile_gaussian_sigma is None:
raise ValueError("`tile_gaussian_sigma` cannot be None.")
diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py
index 9d6be763a0a0..5dc321ea98a2 100644
--- a/examples/community/pipeline_flux_differential_img2img.py
+++ b/examples/community/pipeline_flux_differential_img2img.py
@@ -488,7 +488,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -496,7 +496,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index 736f00799eae..b9985542ccf7 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -907,12 +907,12 @@ def create_controller(
# reweight
if edit_type == "reweight":
- assert (
- equalizer_words is not None and equalizer_strengths is not None
- ), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
- assert len(equalizer_words) == len(
- equalizer_strengths
- ), "equalizer_words and equalizer_strengths must be of same length."
+ assert equalizer_words is not None and equalizer_strengths is not None, (
+ "To use reweight edit, please specify equalizer_words and equalizer_strengths."
+ )
+ assert len(equalizer_words) == len(equalizer_strengths), (
+ "equalizer_words and equalizer_strengths must be of same length."
+ )
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
return AttentionReweight(
prompts,
diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py
index 9377caf7ba2e..6aebb6c18df7 100644
--- a/examples/community/pipeline_sdxl_style_aligned.py
+++ b/examples/community/pipeline_sdxl_style_aligned.py
@@ -1738,7 +1738,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
index 8a709ab46757..6c63f53e815c 100644
--- a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
+++ b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
@@ -689,7 +689,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index 1269a69f0dc3..8459553f4e47 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -1028,7 +1028,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -1036,7 +1036,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -2050,7 +2050,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
index 8480117866cc..6a0ed3523dab 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -1578,7 +1578,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/examples/community/scheduling_ufogen.py b/examples/community/scheduling_ufogen.py
index 4b1b92ff183a..0b832394cf97 100644
--- a/examples/community/scheduling_ufogen.py
+++ b/examples/community/scheduling_ufogen.py
@@ -288,8 +288,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index 2045e7809310..28fc7c73e6eb 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -89,7 +89,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter
if "lora_down" in kohya_key:
- alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index 38fe94ed3fe5..61d883fdfb78 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -901,7 +901,7 @@ def load_model_hook(models, input_dir):
unet_ = accelerator.unwrap_model(unet)
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
unet_state_dict = {
- f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+ f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")
}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index fdb789c21628..4324f81b9695 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter
if "lora_down" in kohya_key:
- alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict
diff --git a/examples/custom_diffusion/retrieve.py b/examples/custom_diffusion/retrieve.py
index a28fe344d93b..27f4b4e0dc60 100644
--- a/examples/custom_diffusion/retrieve.py
+++ b/examples/custom_diffusion/retrieve.py
@@ -50,9 +50,11 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
total = 0
pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
- with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
- f"{class_data_dir}/images.txt", "w"
- ) as f3:
+ with (
+ open(f"{class_data_dir}/caption.txt", "w") as f1,
+ open(f"{class_data_dir}/urls.txt", "w") as f2,
+ open(f"{class_data_dir}/images.txt", "w") as f3,
+ ):
while total < num_class_images:
images = class_images[count]
count += 1
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index ea1449f9f382..fa2959cf41a1 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -731,18 +731,18 @@ def main(args):
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True, exist_ok=True)
if args.real_prior:
- assert (
- class_images_dir / "images"
- ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
- assert (
- len(list((class_images_dir / "images").iterdir())) == args.num_class_images
- ), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
- assert (
- class_images_dir / "caption.txt"
- ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
- assert (
- class_images_dir / "images.txt"
- ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (class_images_dir / "images").exists(), (
+ f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+ )
+ assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
+ f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+ )
+ assert (class_images_dir / "caption.txt").exists(), (
+ f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+ )
+ assert (class_images_dir / "images.txt").exists(), (
+ f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
+ )
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
args.concepts_list[i] = concept
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index b863f5641233..43e680610ee5 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -1014,7 +1014,7 @@ def load_model_hook(models, input_dir):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
- f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
+ f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 9584e7762dbd..7f8d06f34a35 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -982,7 +982,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index debdafd04ba1..febf7e51c6bd 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1294,7 +1294,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_lumina2.py b/examples/dreambooth/train_dreambooth_lora_lumina2.py
index a8bf4e1cdc61..d2cedc248636 100644
--- a/examples/dreambooth/train_dreambooth_lora_lumina2.py
+++ b/examples/dreambooth/train_dreambooth_lora_lumina2.py
@@ -1053,7 +1053,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index 674cb0d1ad1e..899b1ff679ab 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -1064,7 +1064,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 4a08daaf61f7..63cef5d17610 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1355,7 +1355,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 735d48b83400..37241b8f9ef0 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -118,7 +118,7 @@ def save_model_card(
)
model_description = f"""
-# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
+# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
@@ -1286,7 +1286,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 56c5f2a89a3a..2a9bfd949cde 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
torch_dtype=weight_dtype,
)
pipeline.load_lora_weights(args.output_dir)
- assert (
- pipeline.transformer.config.in_channels == initial_channels * 2
- ), f"{pipeline.transformer.config.in_channels=}"
+ assert pipeline.transformer.config.in_channels == initial_channels * 2, (
+ f"{pipeline.transformer.config.in_channels=}"
+ )
pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
@@ -954,7 +954,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
transformer_lora_state_dict = {
- f'{k.replace("transformer.", "")}': v
+ f"{k.replace('transformer.', '')}": v
for k, v in lora_state_dict.items()
if k.startswith("transformer.") and "lora" in k
}
diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py
index a8add8311006..b82e98fb71ff 100644
--- a/examples/model_search/pipeline_easy.py
+++ b/examples/model_search/pipeline_easy.py
@@ -1081,9 +1081,9 @@ def auto_load_textual_inversion(
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
)
- pretrained_model_name_or_paths[
- pretrained_model_name_or_paths.index(search_word)
- ] = textual_inversion_path.model_path
+ pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (
+ textual_inversion_path.model_path
+ )
self.load_textual_inversion(
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py
index 5c30b24efe88..2e96014c4193 100644
--- a/examples/research_projects/anytext/anytext.py
+++ b/examples/research_projects/anytext/anytext.py
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
return_tensors="pt",
)
tokens = batch_encoding["input_ids"]
- assert (
- torch.count_nonzero(tokens - 49407) == 2
- ), f"String '{string}' maps to more than a single token. Please use another string"
+ assert torch.count_nonzero(tokens - 49407) == 2, (
+ f"String '{string}' maps to more than a single token. Please use another string"
+ )
return tokens[0, 1]
diff --git a/examples/research_projects/anytext/ocr_recog/RecSVTR.py b/examples/research_projects/anytext/ocr_recog/RecSVTR.py
index 590a96995b26..3dc813b84a55 100644
--- a/examples/research_projects/anytext/ocr_recog/RecSVTR.py
+++ b/examples/research_projects/anytext/ocr_recog/RecSVTR.py
@@ -312,9 +312,9 @@ def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
def forward(self, x):
B, C, H, W = x.shape
- assert (
- H == self.img_size[0] and W == self.img_size[1]
- ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ assert H == self.img_size[0] and W == self.img_size[1], (
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ )
x = self.proj(x).flatten(2).permute(0, 2, 1)
return x
diff --git a/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/examples/research_projects/colossalai/train_dreambooth_colossalai.py
index 10c8e095a696..4e541b8d3a02 100644
--- a/examples/research_projects/colossalai/train_dreambooth_colossalai.py
+++ b/examples/research_projects/colossalai/train_dreambooth_colossalai.py
@@ -619,7 +619,7 @@ def collate_fn(examples):
optimizer.step()
lr_scheduler.step()
- logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
+ logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB", ranks=[0])
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1)
global_step += 1
diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py
index 829b0031156e..9744bc7be200 100644
--- a/examples/research_projects/controlnet/train_controlnet_webdataset.py
+++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py
@@ -803,21 +803,20 @@ def parse_args(input_args=None):
"--control_type",
type=str,
default="canny",
- help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."),
+        help=("The type of controlnet conditioning image to use. One of `canny`, `depth`. Defaults to `canny`."),
)
parser.add_argument(
"--transformer_layers_per_block",
type=str,
default=None,
- help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."),
+ help=("The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."),
)
parser.add_argument(
"--old_style_controlnet",
action="store_true",
default=False,
help=(
- "Use the old style controlnet, which is a single transformer layer with"
- " a single head. Defaults to False."
+ "Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False."
),
)
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
index ab88d4967766..0b9c248ed004 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
# create pipeline
pipeline = DiffusionPipeline.from_pretrained(
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
index 0297a06f5b2c..f0afa12e9ceb 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
index ed245e9cef7d..12eb67d4a7bb 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
@@ -683,7 +683,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
index 66a7a3652947..a5d89f77d687 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
- logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+ logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation:
if args.mixed_precision == "fp16":
@@ -790,7 +790,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
index ccaf3164a00c..cc535bbaaa85 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
@@ -783,7 +783,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb
index bde093802a5d..aa5951723aaf 100644
--- a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb
+++ b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb
@@ -1,3652 +1,3745 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "F88mignPnalS"
- },
- "source": [
- "# Introduction\n",
- "\n",
- "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n",
- "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n",
- "\n",
- "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n",
- "\n",
- "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n",
- "\n",
- "> Colab made by [natolambert](https://twitter.com/natolambert).\n",
- "\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7cnwXMocnuzB"
- },
- "source": [
- "## Installations\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Install Conda"
- ],
- "metadata": {
- "id": "ff9SxWnaNId9"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "1g_6zOabItDk"
- },
- "source": [
- "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "K0ofXobG5Y-X",
- "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "nvcc: NVIDIA (R) Cuda compiler driver\n",
- "Copyright (c) 2005-2021 NVIDIA Corporation\n",
- "Built on Sun_Feb_14_21:12:58_PST_2021\n",
- "Cuda compilation tools, release 11.2, V11.2.152\n",
- "Build cuda_11.2.r11.2/compiler.29618528_0\n"
- ]
- }
- ],
- "source": [
- "!nvcc --version"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "VfthW90vI0nw"
- },
- "source": [
- "Install Conda for some more complex dependencies for geometric networks."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "2WNFzSnbiE0k",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
- "\u001b[0m"
- ]
- }
- ],
- "source": [
- "!pip install -q condacolab"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "NUsbWYCUI7Km"
- },
- "source": [
- "Setup Conda"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "FZelreINdmd0",
- "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "✨🍰✨ Everything looks OK!\n"
- ]
- }
- ],
- "source": [
- "import condacolab\n",
- "condacolab.install()"
- ]
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "F88mignPnalS"
+ },
+ "source": [
+ "# Introduction\n",
+ "\n",
+    "This colab is designed to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n",
+ "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n",
+ "\n",
+ "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n",
+ "\n",
+ "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n",
+ "\n",
+ "> Colab made by [natolambert](https://twitter.com/natolambert).\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7cnwXMocnuzB"
+ },
+ "source": [
+ "## Installations\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ff9SxWnaNId9"
+ },
+ "source": [
+ "### Install Conda"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1g_6zOabItDk"
+ },
+ "source": [
+ "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "K0ofXobG5Y-X",
+ "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "nvcc: NVIDIA (R) Cuda compiler driver\n",
+ "Copyright (c) 2005-2021 NVIDIA Corporation\n",
+ "Built on Sun_Feb_14_21:12:58_PST_2021\n",
+ "Cuda compilation tools, release 11.2, V11.2.152\n",
+ "Build cuda_11.2.r11.2/compiler.29618528_0\n"
+ ]
+ }
+ ],
+ "source": [
+ "!nvcc --version"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VfthW90vI0nw"
+ },
+ "source": [
+ "Install Conda for some more complex dependencies for geometric networks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2WNFzSnbiE0k",
+ "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install -q condacolab"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NUsbWYCUI7Km"
+ },
+ "source": [
+ "Setup Conda"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "FZelreINdmd0",
+ "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✨🍰✨ Everything looks OK!\n"
+ ]
+ }
+ ],
+ "source": [
+ "import condacolab\n",
+ "\n",
+ "\n",
+ "condacolab.install()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JzDHaPU7I9Sn"
+ },
+ "source": [
+ "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "JMxRjHhL7w8V",
+ "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n",
+ "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
+ "\n",
+ "## Package Plan ##\n",
+ "\n",
+ " environment location: /usr/local\n",
+ "\n",
+ " added / updated specs:\n",
+ " - cudatoolkit=11.1\n",
+ " - pytorch\n",
+ " - torchaudio\n",
+ " - torchvision\n",
+ "\n",
+ "\n",
+ "The following packages will be downloaded:\n",
+ "\n",
+ " package | build\n",
+ " ---------------------------|-----------------\n",
+ " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n",
+ " ------------------------------------------------------------\n",
+ " Total: 960 KB\n",
+ "\n",
+ "The following packages will be UPDATED:\n",
+ "\n",
+ " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n",
+ "\n",
+ "\n",
+ "\n",
+ "Downloading and Extracting Packages\n",
+ "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n",
+ "Preparing transaction: / \b\bdone\n",
+ "Verifying transaction: \\ \b\bdone\n",
+ "Executing transaction: / \b\bdone\n",
+ "Retrieving notices: ...working... done\n"
+ ]
+ }
+ ],
+ "source": [
+ "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n",
+ "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QDS6FPZ0Tu5b"
+ },
+ "source": [
+ "Need to remove a pathspec for colab that specifies the incorrect cuda version."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "dq1lxR10TtrR",
+ "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n"
+ ]
+ }
+ ],
+ "source": [
+ "!rm /usr/local/conda-meta/pinned"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Z1L3DdZOJB30"
+ },
+ "source": [
+ "Install torch geometric (used in the model later)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "D5ukfCOWfjzK",
+ "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
+ "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
+ "\n",
+ "## Package Plan ##\n",
+ "\n",
+ " environment location: /usr/local\n",
+ "\n",
+ " added / updated specs:\n",
+ " - pytorch-geometric=1.7.2\n",
+ "\n",
+ "\n",
+ "The following packages will be downloaded:\n",
+ "\n",
+ " package | build\n",
+ " ---------------------------|-----------------\n",
+ " decorator-4.4.2 | py_0 11 KB conda-forge\n",
+ " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n",
+ " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n",
+ " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n",
+ " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n",
+ " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n",
+ " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n",
+ " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n",
+ " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n",
+ " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n",
+ " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n",
+ " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n",
+ " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n",
+ " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n",
+ " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n",
+ " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n",
+ " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n",
+ " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n",
+ " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n",
+ " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n",
+ " ------------------------------------------------------------\n",
+ " Total: 55.9 MB\n",
+ "\n",
+ "The following NEW packages will be INSTALLED:\n",
+ "\n",
+ " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n",
+ " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n",
+ " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n",
+ " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n",
+ " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n",
+ " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n",
+ " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n",
+ " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n",
+ " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n",
+ " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n",
+ " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n",
+ " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n",
+ " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n",
+ " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n",
+ " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n",
+ "\n",
+ "The following packages will be DOWNGRADED:\n",
+ "\n",
+ " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n",
+ "\n",
+ "\n",
+ "\n",
+ "Downloading and Extracting Packages\n",
+ "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n",
+ "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n",
+ "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n",
+ "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n",
+ "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n",
+ "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n",
+ "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n",
+ "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n",
+ "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n",
+ "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n",
+ "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n",
+ "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n",
+ "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n",
+ "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n",
+ "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n",
+ "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n",
+ "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n",
+ "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n",
+ "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n",
+ "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n",
+ "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n",
+ "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
+ "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
+ "Retrieving notices: ...working... done\n"
+ ]
+ }
+ ],
+ "source": [
+ "!conda install -c rusty1s pytorch-geometric=1.7.2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ppxv6Mdkalbc"
+ },
+ "source": [
+ "### Install Diffusers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "mgQA_XN-XGY2",
+ "outputId": "85392615-b6a4-4052-9d2a-79604be62c94"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/content\n",
+ "Cloning into 'diffusers'...\n",
+ "remote: Enumerating objects: 9298, done.\u001b[K\n",
+ "remote: Counting objects: 100% (40/40), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (23/23), done.\u001b[K\n",
+ "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n",
+ "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n",
+ "Resolving deltas: 100% (6168/6168), done.\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "%cd /content\n",
+ "\n",
+ "# install latest HF diffusers (will update to the release once added)\n",
+ "!git clone https://github.com/huggingface/diffusers.git\n",
+ "!pip install -q /content/diffusers\n",
+ "\n",
+ "# dependencies for diffusers\n",
+ "!pip install -q datasets transformers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LZO6AJKuJKO8"
+ },
+ "source": [
+    "Check that torch is installed correctly and is utilizing the GPU in the Colab"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 53
},
+ "id": "gZt7BNi1e1PA",
+ "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889"
+ },
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "JzDHaPU7I9Sn"
- },
- "source": [
- "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)"
- ]
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "True\n"
+ ]
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "JMxRjHhL7w8V",
- "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8"
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
},
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n",
- "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
- "\n",
- "## Package Plan ##\n",
- "\n",
- " environment location: /usr/local\n",
- "\n",
- " added / updated specs:\n",
- " - cudatoolkit=11.1\n",
- " - pytorch\n",
- " - torchaudio\n",
- " - torchvision\n",
- "\n",
- "\n",
- "The following packages will be downloaded:\n",
- "\n",
- " package | build\n",
- " ---------------------------|-----------------\n",
- " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n",
- " ------------------------------------------------------------\n",
- " Total: 960 KB\n",
- "\n",
- "The following packages will be UPDATED:\n",
- "\n",
- " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n",
- "\n",
- "\n",
- "\n",
- "Downloading and Extracting Packages\n",
- "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n",
- "Preparing transaction: / \b\bdone\n",
- "Verifying transaction: \\ \b\bdone\n",
- "Executing transaction: / \b\bdone\n",
- "Retrieving notices: ...working... done\n"
- ]
- }
- ],
- "source": [
- "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n",
- "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge"
+ "text/plain": [
+ "'1.8.2'"
]
- },
- {
- "cell_type": "markdown",
- "source": [
- "Need to remove a pathspec for colab that specifies the incorrect cuda version."
- ],
- "metadata": {
- "id": "QDS6FPZ0Tu5b"
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "\n",
+ "print(torch.cuda.is_available())\n",
+ "torch.__version__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KLE7CqlfJNUO"
+ },
+ "source": [
+ "### Install Chemistry-specific Dependencies\n",
+ "\n",
+    "Install RDKit, a tool for working with and visualizing chemistry in Python (you will use this to visualize the generated molecules later). A quick sanity-check sketch follows the install cell below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0CPv_NvehRz3",
+ "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting rdkit\n",
+ " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n",
+ "Installing collected packages: rdkit\n",
+ "Successfully installed rdkit-2022.3.5\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install rdkit"
+ ]
+ },
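+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The cell below is a quick, purely illustrative sanity check of the RDKit install (not part of the GeoDiff workflow itself): it parses an arbitrary example SMILES string and renders the molecule in 2D."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sanity check of the RDKit install; the SMILES below (caffeine) is an arbitrary example.\n",
+    "from rdkit import Chem\n",
+    "from rdkit.Chem import Draw\n",
+    "\n",
+    "example_mol = Chem.MolFromSmiles(\"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\")\n",
+    "Draw.MolToImage(example_mol, size=(300, 300))"
+   ]
+  },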
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "88GaDbDPxJ5I"
+ },
+ "source": [
+ "### Get viewer from nglview\n",
+ "\n",
+    "The model you will use outputs a position matrix tensor. The PyTorch Geometric data object fed to the model carries many features (positions, node features, edge features -- all tensors).\n",
+    "The data given to the model also carries an rdmol object (from which features can be extracted for PyTorch Geometric if needed).\n",
+    "The rdmol in this object is the source of ground truth for the generated molecules.\n",
+    "\n",
+    "You will use one rendering function from nglview later; a short, purely illustrative preview sketch follows the install cell below.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "jcl8GCS2mz6t",
+ "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting nglview\n",
+ " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n",
+ "Collecting jupyterlab-widgets\n",
+ " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ipywidgets>=7\n",
+ " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting widgetsnbextension~=4.0\n",
+ " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ipython>=6.1.0\n",
+ " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ipykernel>=4.5.1\n",
+ " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting traitlets>=4.3.1\n",
+ " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n",
+ "Collecting pyzmq>=17\n",
+ " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting matplotlib-inline>=0.1\n",
+ " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n",
+ "Collecting tornado>=6.1\n",
+ " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nest-asyncio\n",
+ " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n",
+ "Collecting debugpy>=1.0\n",
+ " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting psutil\n",
+ " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting jupyter-client>=6.1.12\n",
+ " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting pickleshare\n",
+ " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n",
+ "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n",
+ "Collecting backcall\n",
+ " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n",
+ "Collecting pexpect>4.3\n",
+ " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting pygments\n",
+ " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting jedi>=0.16\n",
+ " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n",
+ " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n",
+ "Collecting parso<0.9.0,>=0.8.0\n",
+ " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n",
+ "Collecting entrypoints\n",
+ " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n",
+ "Collecting jupyter-core>=4.9.2\n",
+ " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ptyprocess>=0.5\n",
+ " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n",
+ "Collecting wcwidth\n",
+ " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n",
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n",
+ "Building wheels for collected packages: nglview\n",
+ " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n",
+ " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n",
+ "Successfully built nglview\n",
+ "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n",
+ "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.colab-display-data+json": {
+ "pip_warning": {
+ "packages": [
+ "pexpect",
+ "pickleshare",
+ "wcwidth"
+ ]
+ }
}
- },
- {
- "cell_type": "code",
- "source": [
- "!rm /usr/local/conda-meta/pinned"
- ],
- "metadata": {
- "id": "dq1lxR10TtrR",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8"
- },
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Z1L3DdZOJB30"
- },
- "source": [
- "Install torch geometric (used in the model later)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "D5ukfCOWfjzK",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
- "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
- "\n",
- "## Package Plan ##\n",
- "\n",
- " environment location: /usr/local\n",
- "\n",
- " added / updated specs:\n",
- " - pytorch-geometric=1.7.2\n",
- "\n",
- "\n",
- "The following packages will be downloaded:\n",
- "\n",
- " package | build\n",
- " ---------------------------|-----------------\n",
- " decorator-4.4.2 | py_0 11 KB conda-forge\n",
- " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n",
- " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n",
- " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n",
- " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n",
- " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n",
- " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n",
- " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n",
- " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n",
- " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n",
- " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n",
- " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n",
- " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n",
- " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n",
- " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n",
- " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n",
- " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n",
- " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n",
- " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n",
- " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n",
- " ------------------------------------------------------------\n",
- " Total: 55.9 MB\n",
- "\n",
- "The following NEW packages will be INSTALLED:\n",
- "\n",
- " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n",
- " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n",
- " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n",
- " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n",
- " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n",
- " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n",
- " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n",
- " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n",
- " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n",
- " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n",
- " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n",
- " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n",
- " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n",
- " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n",
- " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n",
- " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n",
- " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n",
- " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n",
- " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n",
- "\n",
- "The following packages will be DOWNGRADED:\n",
- "\n",
- " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n",
- "\n",
- "\n",
- "\n",
- "Downloading and Extracting Packages\n",
- "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n",
- "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n",
- "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n",
- "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n",
- "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n",
- "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n",
- "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n",
- "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n",
- "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n",
- "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n",
- "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n",
- "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n",
- "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n",
- "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n",
- "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n",
- "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n",
- "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n",
- "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n",
- "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n",
- "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n",
- "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n",
- "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
- "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
- "Retrieving notices: ...working... done\n"
- ]
- }
- ],
- "source": [
- "!conda install -c rusty1s pytorch-geometric=1.7.2"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ppxv6Mdkalbc"
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "!pip install nglview"
+ ]
+ },
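+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a purely illustrative preview (this is *not* the rendering function used later in the notebook), the sketch below embeds an arbitrary example molecule in 3D with RDKit and displays it with `nglview.show_rdkit`. In Colab the custom widget manager may need to be enabled for the viewer to render."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative 3D preview with nglview (not the rendering helper used later in this notebook).\n",
+    "import nglview\n",
+    "from rdkit import Chem\n",
+    "from rdkit.Chem import AllChem\n",
+    "\n",
+    "# In Colab, custom widget support may need to be enabled for the viewer to show up.\n",
+    "try:\n",
+    "    from google.colab import output\n",
+    "    output.enable_custom_widget_manager()\n",
+    "except ImportError:\n",
+    "    pass\n",
+    "\n",
+    "mol = Chem.AddHs(Chem.MolFromSmiles(\"CCO\"))  # ethanol, an arbitrary example\n",
+    "AllChem.EmbedMolecule(mol, randomSeed=0)  # generate a quick 3D conformer\n",
+    "nglview.show_rdkit(mol)"
+   ]
+  },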
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8t8_e_uVLdKB"
+ },
+ "source": [
+ "## Create a diffusion model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "G0rMncVtNSqU"
+ },
+ "source": [
+ "### Model class(es)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "L5FEXz5oXkzt"
+ },
+ "source": [
+ "Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-3-P4w5sXkRU"
+ },
+ "outputs": [],
+ "source": [
+ "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n",
+ "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n",
+ "from dataclasses import dataclass\n",
+ "from typing import Callable, Tuple, Union\n",
+ "\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from torch import Tensor, nn\n",
+ "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n",
+ "from torch_geometric.nn import MessagePassing, radius, radius_graph\n",
+ "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n",
+ "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n",
+ "from torch_scatter import scatter_add\n",
+ "from torch_sparse import SparseTensor, coalesce\n",
+ "\n",
+ "from diffusers.configuration_utils import ConfigMixin, register_to_config\n",
+ "from diffusers.modeling_utils import ModelMixin\n",
+ "from diffusers.utils import BaseOutput"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EzJQXPN_XrMX"
+ },
+ "source": [
+ "Helper classes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "oR1Y56QiLY90"
+ },
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class MoleculeGNNOutput(BaseOutput):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n",
+ " Hidden states output. Output of last layer of model.\n",
+ " \"\"\"\n",
+ "\n",
+ " sample: torch.Tensor\n",
+ "\n",
+ "\n",
+ "class MultiLayerPerceptron(nn.Module):\n",
+ " \"\"\"\n",
+ " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n",
+ " Args:\n",
+ " input_dim (int): input dimension\n",
+    "        hidden_dims (list of int): hidden dimensions\n",
+ " activation (str or function, optional): activation function\n",
+ " dropout (float, optional): dropout rate\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n",
+ " super(MultiLayerPerceptron, self).__init__()\n",
+ "\n",
+ " self.dims = [input_dim] + hidden_dims\n",
+ " if isinstance(activation, str):\n",
+ " self.activation = getattr(F, activation)\n",
+ " else:\n",
+    "            print(f\"Warning: activation {activation} is not a string and will be ignored\")\n",
+ " self.activation = None\n",
+ " if dropout > 0:\n",
+ " self.dropout = nn.Dropout(dropout)\n",
+ " else:\n",
+ " self.dropout = None\n",
+ "\n",
+ " self.layers = nn.ModuleList()\n",
+ " for i in range(len(self.dims) - 1):\n",
+ " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n",
+ "\n",
+ " def forward(self, x):\n",
+ " \"\"\"\"\"\"\n",
+ " for i, layer in enumerate(self.layers):\n",
+ " x = layer(x)\n",
+ " if i < len(self.layers) - 1:\n",
+ " if self.activation:\n",
+ " x = self.activation(x)\n",
+ " if self.dropout:\n",
+ " x = self.dropout(x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "class ShiftedSoftplus(torch.nn.Module):\n",
+ " def __init__(self):\n",
+ " super(ShiftedSoftplus, self).__init__()\n",
+ " self.shift = torch.log(torch.tensor(2.0)).item()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return F.softplus(x) - self.shift\n",
+ "\n",
+ "\n",
+ "class CFConv(MessagePassing):\n",
+ " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n",
+ " super(CFConv, self).__init__(aggr=\"add\")\n",
+ " self.lin1 = Linear(in_channels, num_filters, bias=False)\n",
+ " self.lin2 = Linear(num_filters, out_channels)\n",
+ " self.nn = mlp\n",
+ " self.cutoff = cutoff\n",
+ " self.smooth = smooth\n",
+ "\n",
+ " self.reset_parameters()\n",
+ "\n",
+ " def reset_parameters(self):\n",
+ " torch.nn.init.xavier_uniform_(self.lin1.weight)\n",
+ " torch.nn.init.xavier_uniform_(self.lin2.weight)\n",
+ " self.lin2.bias.data.fill_(0)\n",
+ "\n",
+ " def forward(self, x, edge_index, edge_length, edge_attr):\n",
+ " if self.smooth:\n",
+ " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n",
+ " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n",
+ " else:\n",
+ " C = (edge_length <= self.cutoff).float()\n",
+ " W = self.nn(edge_attr) * C.view(-1, 1)\n",
+ "\n",
+ " x = self.lin1(x)\n",
+ " x = self.propagate(edge_index, x=x, W=W)\n",
+ " x = self.lin2(x)\n",
+ " return x\n",
+ "\n",
+ " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n",
+ " return x_j * W\n",
+ "\n",
+ "\n",
+ "class InteractionBlock(torch.nn.Module):\n",
+ " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n",
+ " super(InteractionBlock, self).__init__()\n",
+ " mlp = Sequential(\n",
+ " Linear(num_gaussians, num_filters),\n",
+ " ShiftedSoftplus(),\n",
+ " Linear(num_filters, num_filters),\n",
+ " )\n",
+ " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n",
+ " self.act = ShiftedSoftplus()\n",
+ " self.lin = Linear(hidden_channels, hidden_channels)\n",
+ "\n",
+ " def forward(self, x, edge_index, edge_length, edge_attr):\n",
+ " x = self.conv(x, edge_index, edge_length, edge_attr)\n",
+ " x = self.act(x)\n",
+ " x = self.lin(x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "class SchNetEncoder(Module):\n",
+ " def __init__(\n",
+ " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n",
+ " ):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.hidden_channels = hidden_channels\n",
+ " self.num_filters = num_filters\n",
+ " self.num_interactions = num_interactions\n",
+ " self.cutoff = cutoff\n",
+ "\n",
+ " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n",
+ "\n",
+ " self.interactions = ModuleList()\n",
+ " for _ in range(num_interactions):\n",
+ " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n",
+ " self.interactions.append(block)\n",
+ "\n",
+ " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n",
+ " if embed_node:\n",
+ " assert z.dim() == 1 and z.dtype == torch.long\n",
+ " h = self.embedding(z)\n",
+ " else:\n",
+ " h = z\n",
+ " for interaction in self.interactions:\n",
+ " h = h + interaction(h, edge_index, edge_length, edge_attr)\n",
+ "\n",
+ " return h\n",
+ "\n",
+ "\n",
+ "class GINEConv(MessagePassing):\n",
+ " \"\"\"\n",
+    "    Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\"\n",
+    "    paper (https://arxiv.org/abs/1810.00826). Note that this implementation has the added option of a custom activation.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n",
+ " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n",
+ " self.nn = mlp\n",
+ " self.initial_eps = eps\n",
+ "\n",
+ " if isinstance(activation, str):\n",
+ " self.activation = getattr(F, activation)\n",
+ " else:\n",
+ " self.activation = None\n",
+ "\n",
+ " if train_eps:\n",
+ " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n",
+ " else:\n",
+ " self.register_buffer(\"eps\", torch.Tensor([eps]))\n",
+ "\n",
+ " def forward(\n",
+ " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n",
+ " ) -> torch.Tensor:\n",
+ " \"\"\"\"\"\"\n",
+ " if isinstance(x, torch.Tensor):\n",
+ " x: OptPairTensor = (x, x)\n",
+ "\n",
+    "        # Node and edge feature dimensionalities need to match.\n",
+ " if isinstance(edge_index, torch.Tensor):\n",
+ " assert edge_attr is not None\n",
+ " assert x[0].size(-1) == edge_attr.size(-1)\n",
+ " elif isinstance(edge_index, SparseTensor):\n",
+ " assert x[0].size(-1) == edge_index.size(-1)\n",
+ "\n",
+ " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n",
+ " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n",
+ "\n",
+ " x_r = x[1]\n",
+ " if x_r is not None:\n",
+ " out += (1 + self.eps) * x_r\n",
+ "\n",
+ " return self.nn(out)\n",
+ "\n",
+ " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n",
+ " if self.activation:\n",
+ " return self.activation(x_j + edge_attr)\n",
+ " else:\n",
+ " return x_j + edge_attr\n",
+ "\n",
+ " def __repr__(self):\n",
+ " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n",
+ "\n",
+ "\n",
+ "class GINEncoder(torch.nn.Module):\n",
+ " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.hidden_dim = hidden_dim\n",
+ " self.num_convs = num_convs\n",
+ " self.short_cut = short_cut\n",
+ " self.concat_hidden = concat_hidden\n",
+ " self.node_emb = nn.Embedding(100, hidden_dim)\n",
+ "\n",
+ " if isinstance(activation, str):\n",
+ " self.activation = getattr(F, activation)\n",
+ " else:\n",
+ " self.activation = None\n",
+ "\n",
+ " self.convs = nn.ModuleList()\n",
+ " for i in range(self.num_convs):\n",
+ " self.convs.append(\n",
+ " GINEConv(\n",
+ " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n",
+ " activation=activation,\n",
+ " )\n",
+ " )\n",
+ "\n",
+ " def forward(self, z, edge_index, edge_attr):\n",
+ " \"\"\"\n",
+ " Input:\n",
+    "            z: atom type tensor with shape (num_node,). edge_index: bond indices of the graph, shape (2, num_edge).\n",
+    "            edge_attr: edge feature tensor with shape (num_edge, hidden)\n",
+ " Output:\n",
+ " node_feature: graph feature\n",
+ " \"\"\"\n",
+ "\n",
+ " node_attr = self.node_emb(z) # (num_node, hidden)\n",
+ "\n",
+ " hiddens = []\n",
+ " conv_input = node_attr # (num_node, hidden)\n",
+ "\n",
+ " for conv_idx, conv in enumerate(self.convs):\n",
+ " hidden = conv(conv_input, edge_index, edge_attr)\n",
+ " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n",
+ " hidden = self.activation(hidden)\n",
+ " assert hidden.shape == conv_input.shape\n",
+ " if self.short_cut and hidden.shape == conv_input.shape:\n",
+ " hidden += conv_input\n",
+ "\n",
+ " hiddens.append(hidden)\n",
+ " conv_input = hidden\n",
+ "\n",
+ " if self.concat_hidden:\n",
+ " node_feature = torch.cat(hiddens, dim=-1)\n",
+ " else:\n",
+ " node_feature = hiddens[-1]\n",
+ "\n",
+ " return node_feature\n",
+ "\n",
+ "\n",
+ "class MLPEdgeEncoder(Module):\n",
+ " def __init__(self, hidden_dim=100, activation=\"relu\"):\n",
+ " super().__init__()\n",
+ " self.hidden_dim = hidden_dim\n",
+ " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n",
+ " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n",
+ "\n",
+ " @property\n",
+ " def out_channels(self):\n",
+ " return self.hidden_dim\n",
+ "\n",
+ " def forward(self, edge_length, edge_type):\n",
+ " \"\"\"\n",
+ " Input:\n",
+    "            edge_length: The length of edges, shape=(E, 1). edge_type: The type of edges, shape=(E,)\n",
+    "        Returns:\n",
+    "            edge_attr: The representation of edges, shape=(E, hidden_dim)\n",
+ " \"\"\"\n",
+ " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n",
+ " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n",
+ " return d_emb * edge_attr # (num_edge, hidden)\n",
+ "\n",
+ "\n",
+ "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n",
+ " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n",
+ " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n",
+ " return h_pair\n",
+ "\n",
+ "\n",
+ "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " num_nodes: Number of atoms.\n",
+ " edge_index: Bond indices of the original graph.\n",
+ " edge_type: Bond types of the original graph.\n",
+ " order: Extension order.\n",
+ " Returns:\n",
+ " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n",
+ " \"\"\"\n",
+ "\n",
+ " def binarize(x):\n",
+ " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n",
+ "\n",
+ " def get_higher_order_adj_matrix(adj, order):\n",
+ " \"\"\"\n",
+    "        Args:\n",
+    "            adj: (N, N) adjacency matrix.\n",
+    "            order: maximum number of hops to connect.\n",
+    "        Returns:\n",
+    "            order_mat: (N, N) matrix whose entry (i, j) is the smallest number of hops (at most `order`)\n",
+    "                connecting nodes i and j, or 0 if they are not connected within `order` hops.\n",
+ " \"\"\"\n",
+ " adj_mats = [\n",
+ " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n",
+ " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n",
+ " ]\n",
+ "\n",
+ " for i in range(2, order + 1):\n",
+ " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n",
+ " order_mat = torch.zeros_like(adj)\n",
+ "\n",
+ " for i in range(1, order + 1):\n",
+ " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n",
+ "\n",
+ " return order_mat\n",
+ "\n",
+ " num_types = 22\n",
+ " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n",
+ " # from rdkit.Chem.rdchem import BondType as BT\n",
+ " N = num_nodes\n",
+ " adj = to_dense_adj(edge_index).squeeze(0)\n",
+ " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n",
+ "\n",
+ " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n",
+ " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n",
+ " assert (type_mat * type_highorder == 0).all()\n",
+ " type_new = type_mat + type_highorder\n",
+ "\n",
+ " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n",
+ " _, edge_order = dense_to_sparse(adj_order)\n",
+ "\n",
+ " # data.bond_edge_index = data.edge_index # Save original edges\n",
+ " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n",
+ "\n",
+ " return new_edge_index, new_edge_type\n",
+ "\n",
+ "\n",
+ "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n",
+ " assert edge_type.dim() == 1\n",
+ " N = pos.size(0)\n",
+ "\n",
+ " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n",
+ "\n",
+ " if is_sidechain is None:\n",
+ " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n",
+ " else:\n",
+ " # fetch sidechain and its batch index\n",
+ " is_sidechain = is_sidechain.bool()\n",
+ " dummy_index = torch.arange(pos.size(0), device=pos.device)\n",
+ " sidechain_pos = pos[is_sidechain]\n",
+ " sidechain_index = dummy_index[is_sidechain]\n",
+ " sidechain_batch = batch[is_sidechain]\n",
+ "\n",
+ " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n",
+ " r_edge_index_x = assign_index[1]\n",
+ " r_edge_index_y = assign_index[0]\n",
+ " r_edge_index_y = sidechain_index[r_edge_index_y]\n",
+ "\n",
+ " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n",
+ " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n",
+ " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n",
+ " # delete self loop\n",
+ " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n",
+ "\n",
+ " rgraph_adj = torch.sparse.LongTensor(\n",
+ " rgraph_edge_index,\n",
+ " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n",
+ " torch.Size([N, N]),\n",
+ " )\n",
+ "\n",
+ " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n",
+ "\n",
+ " new_edge_index = composed_adj.indices()\n",
+ " new_edge_type = composed_adj.values().long()\n",
+ "\n",
+ " return new_edge_index, new_edge_type\n",
+ "\n",
+ "\n",
+ "def extend_graph_order_radius(\n",
+ " num_nodes,\n",
+ " pos,\n",
+ " edge_index,\n",
+ " edge_type,\n",
+ " batch,\n",
+ " order=3,\n",
+ " cutoff=10.0,\n",
+ " extend_order=True,\n",
+ " extend_radius=True,\n",
+ " is_sidechain=None,\n",
+ "):\n",
+ " if extend_order:\n",
+ " edge_index, edge_type = _extend_graph_order(\n",
+ " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n",
+ " )\n",
+ "\n",
+ " if extend_radius:\n",
+ " edge_index, edge_type = _extend_to_radius_graph(\n",
+ " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n",
+ " )\n",
+ "\n",
+ " return edge_index, edge_type\n",
+ "\n",
+ "\n",
+ "def get_distance(pos, edge_index):\n",
+ " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n",
+ "\n",
+ "\n",
+ "def graph_field_network(score_d, pos, edge_index, edge_length):\n",
+ " \"\"\"\n",
+    "    Transformation to make the epsilon predicted from the diffusion model roto-translationally equivariant. See\n",
+    "    equations 5-7 of the GeoDiff paper https://arxiv.org/pdf/2203.02923.pdf\n",
+ " \"\"\"\n",
+ " N = pos.size(0)\n",
+ " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n",
+ " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n",
+ " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n",
+ " ) # (N, 3)\n",
+ " return score_pos\n",
+ "\n",
+ "\n",
+ "def clip_norm(vec, limit, p=2):\n",
+    "    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)\n",
+ " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n",
+ " return vec * denom\n",
+ "\n",
+ "\n",
+ "def is_local_edge(edge_type):\n",
+ " return edge_type > 0"
+ ]
+ },
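+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To get a feel for what `_extend_graph_order` does, the purely illustrative cell below runs it on a tiny hand-made path graph 0-1-2 (toy tensors, not real data): the original bonds keep their types, while the newly added 2-hop edge between atoms 0 and 2 receives a \"virtual\" higher-order type offset by `num_types`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Toy demonstration of _extend_graph_order on a 3-atom path graph 0-1-2 (illustrative only).\n",
+    "toy_edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # bonds 0-1 and 1-2, both directions\n",
+    "toy_edge_type = torch.tensor([1, 1, 1, 1])  # single bonds\n",
+    "\n",
+    "new_edge_index, new_edge_type = _extend_graph_order(\n",
+    "    num_nodes=3, edge_index=toy_edge_index, edge_type=toy_edge_type, order=3\n",
+    ")\n",
+    "print(new_edge_index)\n",
+    "print(new_edge_type)  # bond edges keep type 1; the added 0-2 edge gets a higher-order type"
+   ]
+  },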
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QWrHJFcYXyUB"
+ },
+ "source": [
+ "Main model class!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MCeZA1qQXzoK"
+ },
+ "outputs": [],
+ "source": [
+ "class MoleculeGNN(ModelMixin, ConfigMixin):\n",
+ " @register_to_config\n",
+ " def __init__(\n",
+ " self,\n",
+ " hidden_dim=128,\n",
+ " num_convs=6,\n",
+ " num_convs_local=4,\n",
+ " cutoff=10.0,\n",
+ " mlp_act=\"relu\",\n",
+ " edge_order=3,\n",
+ " edge_encoder=\"mlp\",\n",
+ " smooth_conv=True,\n",
+ " ):\n",
+ " super().__init__()\n",
+ " self.cutoff = cutoff\n",
+ " self.edge_encoder = edge_encoder\n",
+ " self.edge_order = edge_order\n",
+ "\n",
+ " \"\"\"\n",
+ " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n",
+ " in SchNetEncoder\n",
+ " \"\"\"\n",
+ " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n",
+ " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n",
+ "\n",
+ " \"\"\"\n",
+ " The graph neural network that extracts node-wise features.\n",
+ " \"\"\"\n",
+ " self.encoder_global = SchNetEncoder(\n",
+ " hidden_channels=hidden_dim,\n",
+ " num_filters=hidden_dim,\n",
+ " num_interactions=num_convs,\n",
+ " edge_channels=self.edge_encoder_global.out_channels,\n",
+ " cutoff=cutoff,\n",
+ " smooth=smooth_conv,\n",
+ " )\n",
+ " self.encoder_local = GINEncoder(\n",
+ " hidden_dim=hidden_dim,\n",
+ " num_convs=num_convs_local,\n",
+ " )\n",
+ "\n",
+ " \"\"\"\n",
+ " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n",
+ " gradients w.r.t. edge_length (out_dim = 1).\n",
+ " \"\"\"\n",
+ " self.grad_global_dist_mlp = MultiLayerPerceptron(\n",
+ " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
+ " )\n",
+ "\n",
+ " self.grad_local_dist_mlp = MultiLayerPerceptron(\n",
+ " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
+ " )\n",
+ "\n",
+ " \"\"\"\n",
+ " Incorporate parameters together\n",
+ " \"\"\"\n",
+ " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n",
+ " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n",
+ "\n",
+ " def _forward(\n",
+ " self,\n",
+ " atom_type,\n",
+ " pos,\n",
+ " bond_index,\n",
+ " bond_type,\n",
+ " batch,\n",
+ " time_step, # NOTE, model trained without timestep performed best\n",
+ " edge_index=None,\n",
+ " edge_type=None,\n",
+ " edge_length=None,\n",
+ " return_edges=False,\n",
+ " extend_order=True,\n",
+ " extend_radius=True,\n",
+ " is_sidechain=None,\n",
+ " ):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " atom_type: Types of atoms, (N, ).\n",
+ " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n",
+ " bond_type: Bond types, (E, ).\n",
+ " batch: Node index to graph index, (N, ).\n",
+ " \"\"\"\n",
+ " N = atom_type.size(0)\n",
+ " if edge_index is None or edge_type is None or edge_length is None:\n",
+ " edge_index, edge_type = extend_graph_order_radius(\n",
+ " num_nodes=N,\n",
+ " pos=pos,\n",
+ " edge_index=bond_index,\n",
+ " edge_type=bond_type,\n",
+ " batch=batch,\n",
+ " order=self.edge_order,\n",
+ " cutoff=self.cutoff,\n",
+ " extend_order=extend_order,\n",
+ " extend_radius=extend_radius,\n",
+ " is_sidechain=is_sidechain,\n",
+ " )\n",
+ " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n",
+ " local_edge_mask = is_local_edge(edge_type) # (E, )\n",
+ "\n",
+ " # with the parameterization of NCSNv2\n",
+    "        # the DDPM loss implicitly handles the noise variance scale conditioning\n",
+ " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n",
+ "\n",
+ " # Encoding global\n",
+ " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n",
+ "\n",
+ " # Global\n",
+ " node_attr_global = self.encoder_global(\n",
+ " z=atom_type,\n",
+ " edge_index=edge_index,\n",
+ " edge_length=edge_length,\n",
+ " edge_attr=edge_attr_global,\n",
+ " )\n",
+ " # Assemble pairwise features\n",
+ " h_pair_global = assemble_atom_pair_feature(\n",
+ " node_attr=node_attr_global,\n",
+ " edge_index=edge_index,\n",
+ " edge_attr=edge_attr_global,\n",
+ " ) # (E_global, 2H)\n",
+ " # Invariant features of edges (radius graph, global)\n",
+ " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n",
+ "\n",
+ " # Encoding local\n",
+ " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n",
+ " # edge_attr += temb_edge\n",
+ "\n",
+ " # Local\n",
+ " node_attr_local = self.encoder_local(\n",
+ " z=atom_type,\n",
+ " edge_index=edge_index[:, local_edge_mask],\n",
+ " edge_attr=edge_attr_local[local_edge_mask],\n",
+ " )\n",
+ " # Assemble pairwise features\n",
+ " h_pair_local = assemble_atom_pair_feature(\n",
+ " node_attr=node_attr_local,\n",
+ " edge_index=edge_index[:, local_edge_mask],\n",
+ " edge_attr=edge_attr_local[local_edge_mask],\n",
+ " ) # (E_local, 2H)\n",
+ "\n",
+ " # Invariant features of edges (bond graph, local)\n",
+ " if isinstance(sigma_edge, torch.Tensor):\n",
+ " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n",
+ " 1.0 / sigma_edge[local_edge_mask]\n",
+ " ) # (E_local, 1)\n",
+ " else:\n",
+ " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n",
+ "\n",
+ " if return_edges:\n",
+ " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n",
+ " else:\n",
+ " return edge_inv_global, edge_inv_local\n",
+ "\n",
+ " def forward(\n",
+ " self,\n",
+ " sample,\n",
+ " timestep: Union[torch.Tensor, float, int],\n",
+ " return_dict: bool = True,\n",
+ " sigma=1.0,\n",
+ " global_start_sigma=0.5,\n",
+ " w_global=1.0,\n",
+ " extend_order=False,\n",
+ " extend_radius=True,\n",
+ " clip_local=None,\n",
+ " clip_global=1000.0,\n",
+ " ) -> Union[MoleculeGNNOutput, Tuple]:\n",
+ " r\"\"\"\n",
+ " Args:\n",
+ " sample: packed torch geometric object\n",
+    "            timestep (`torch.Tensor` or `float` or `int`): TODO verify type and shape (batch) timesteps\n",
+ " return_dict (`bool`, *optional*, defaults to `True`):\n",
+ " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n",
+ " Returns:\n",
+ " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n",
+ " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n",
+ " \"\"\"\n",
+ "\n",
+ " # unpack sample\n",
+ " atom_type = sample.atom_type\n",
+ " bond_index = sample.edge_index\n",
+ " bond_type = sample.edge_type\n",
+ " num_graphs = sample.num_graphs\n",
+ " pos = sample.pos\n",
+ "\n",
+ " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n",
+ "\n",
+ " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n",
+ " atom_type=atom_type,\n",
+ " pos=sample.pos,\n",
+ " bond_index=bond_index,\n",
+ " bond_type=bond_type,\n",
+ " batch=sample.batch,\n",
+ " time_step=timesteps,\n",
+ " return_edges=True,\n",
+ " extend_order=extend_order,\n",
+ " extend_radius=extend_radius,\n",
+ " ) # (E_global, 1), (E_local, 1)\n",
+ "\n",
+ " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n",
+ " node_eq_local = graph_field_network(\n",
+ " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n",
+ " )\n",
+ " if clip_local is not None:\n",
+ " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n",
+ "\n",
+ " # Global\n",
+ " if sigma < global_start_sigma:\n",
+ " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n",
+ " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n",
+ " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n",
+ " else:\n",
+ " node_eq_global = 0\n",
+ "\n",
+ " # Sum\n",
+ " eps_pos = node_eq_local + node_eq_global * w_global\n",
+ "\n",
+ " if not return_dict:\n",
+ " return (-eps_pos,)\n",
+ "\n",
+ " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
+ ]
+ },
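+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick, purely illustrative sanity check that the class above is self-consistent, the cell below builds a randomly initialized `MoleculeGNN` with its default config and counts its parameters. The actual pretrained weights are loaded in the next section."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check only: randomly initialized MoleculeGNN with default config (pretrained weights are loaded below).\n",
+    "_tmp_model = MoleculeGNN()\n",
+    "print(f\"MoleculeGNN parameters: {sum(p.numel() for p in _tmp_model.parameters()):,}\")\n",
+    "del _tmp_model"
+   ]
+  },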
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CCIrPYSJj9wd"
+ },
+ "source": [
+ "### Load pretrained model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YdrAr6Ch--Ab"
+ },
+ "source": [
+ "#### Load a model\n",
+    "The model used is built around an equivariant convolutional layer, named graph field network (GFN).\n",
+    "\n",
+    "The warning about `betas` and `alphas` can be ignored; those were moved to the scheduler."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 172,
+ "referenced_widgets": [
+ "d90f304e9560472eacfbdd11e46765eb",
+ "1c6246f15b654f4daa11c9bcf997b78c",
+ "c2321b3bff6f490ca12040a20308f555",
+ "b7feb522161f4cf4b7cc7c1a078ff12d",
+ "e2d368556e494ae7ae4e2e992af2cd4f",
+ "bbef741e76ec41b7ab7187b487a383df",
+ "561f742d418d4721b0670cc8dd62e22c",
+ "872915dd1bb84f538c44e26badabafdd",
+ "d022575f1fa2446d891650897f187b4d",
+ "fdc393f3468c432aa0ada05e238a5436",
+ "2c9362906e4b40189f16d14aa9a348da",
+ "6010fc8daa7a44d5aec4b830ec2ebaa1",
+ "7e0bb1b8d65249d3974200686b193be2",
+ "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
+ "6526646be5ed415c84d1245b040e629b",
+ "24d31fc3576e43dd9f8301d2ef3a37ab",
+ "2918bfaadc8d4b1a9832522c40dfefb8",
+ "a4bfdca35cc54dae8812720f1b276a08",
+ "e4901541199b45c6a18824627692fc39",
+ "f915cf874246446595206221e900b2fe",
+ "a9e388f22a9742aaaf538e22575c9433",
+ "42f6c3db29d7484ba6b4f73590abd2f4"
+ ]
+ },
+ "id": "DyCo0nsqjbml",
+ "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d90f304e9560472eacfbdd11e46765eb",
+ "version_major": 2,
+ "version_minor": 0
},
- "source": [
- "### Install Diffusers"
+ "text/plain": [
+ "Downloading: 0%| | 0.00/3.27M [00:00, ?B/s]"
]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "mgQA_XN-XGY2",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "85392615-b6a4-4052-9d2a-79604be62c94"
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6010fc8daa7a44d5aec4b830ec2ebaa1",
+ "version_major": 2,
+ "version_minor": 0
},
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "/content\n",
- "Cloning into 'diffusers'...\n",
- "remote: Enumerating objects: 9298, done.\u001b[K\n",
- "remote: Counting objects: 100% (40/40), done.\u001b[K\n",
- "remote: Compressing objects: 100% (23/23), done.\u001b[K\n",
- "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n",
- "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n",
- "Resolving deltas: 100% (6168/6168), done.\n",
- " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
- " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
- " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
- "\u001b[0m"
- ]
- }
- ],
- "source": [
- "%cd /content\n",
- "\n",
- "# install latest HF diffusers (will update to the release once added)\n",
- "!git clone https://github.com/huggingface/diffusers.git\n",
- "!pip install -q /content/diffusers\n",
- "\n",
- "# dependencies for diffusers\n",
- "!pip install -q datasets transformers"
+ "text/plain": [
+ "Downloading: 0%| | 0.00/401 [00:00, ?B/s]"
]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The config attributes {'type': 'diffusion', 'network': 'dualenc', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'num_diffusion_timesteps': 5000} were passed to MoleculeGNN, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
+ "Some weights of the model checkpoint at fusing/gfn-molecule-gen-drugs were not used when initializing MoleculeGNN: ['betas', 'alphas']\n",
+ "- This IS expected if you are initializing MoleculeGNN from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing MoleculeGNN from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "DEVICE = \"cuda\"\n",
+ "model = MoleculeGNN.from_pretrained(\"fusing/gfn-molecule-gen-drugs\").to(DEVICE)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HdclRaqoUWUD"
+ },
+ "source": [
+ "The warnings above are because the pre-trained model was uploaded before cleaning the code!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PlOkPySoJ1m9"
+ },
+ "source": [
+ "#### Create scheduler\n",
+ "Note, other schedulers are used in the paper for slightly improved performance over DDPM."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nNHnIk9CkAb2"
+ },
+ "outputs": [],
+ "source": [
+ "from diffusers import DDPMScheduler"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RnDJdDBztjFF"
+ },
+ "outputs": [],
+ "source": [
+ "num_timesteps = 1000\n",
+ "scheduler = DDPMScheduler(\n",
+ " num_train_timesteps=num_timesteps, beta_schedule=\"sigmoid\", beta_start=1e-7, beta_end=2e-3, clip_sample=False\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1vh3fpSAflkL"
+ },
+ "source": [
+ "### Get a dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "B6qzaGjVKFVk"
+ },
+ "source": [
+ "Grab a google tool so we can upload our data directly. Note you need to download the data from ***this [file](https://huggingface.co/datasets/fusing/geodiff-example-data/blob/main/data/molecules.pkl)***\n",
+ "\n",
+ "(direct downloading from the hub does not yet work for this datatype)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jbLl3EJdgj3x"
+ },
+ "outputs": [],
+ "source": [
+ "# from google.colab import files"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "E591lVuTgxPE"
+ },
+ "outputs": [],
+ "source": [
+ "# uploaded = files.upload()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KUNxfK3ln98Q"
+ },
+ "source": [
+ "Load the dataset with torch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "7L4iOShTpcQX",
+ "outputId": "7f2dcd29-493e-44de-98d1-3ad50f109a4a"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--2022-10-12 18:32:19-- https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
+ "Resolving huggingface.co (huggingface.co)... 44.195.102.200, 52.5.54.249, 54.210.225.113, ...\n",
+ "Connecting to huggingface.co (huggingface.co)|44.195.102.200|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 127774 (125K) [application/octet-stream]\n",
+ "Saving to: ‘molecules.pkl’\n",
+ "\n",
+ "molecules.pkl 100%[===================>] 124.78K 180KB/s in 0.7s \n",
+ "\n",
+ "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "\n",
+ "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
+ "dataset = torch.load(\"/content/molecules.pkl\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QZcmy1EvKQRk"
+ },
+ "source": [
+ "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "JVjz6iH_H6Eh",
+ "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe"
+ },
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "LZO6AJKuJKO8"
- },
- "source": [
- "Check that torch is installed correctly and utilizing the GPU in the colab"
+ "data": {
+ "text/plain": [
+ "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "gZt7BNi1e1PA",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 53
- },
- "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "True\n"
- ]
- },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "'1.8.2'"
- ],
- "application/vnd.google.colaboratory.intrinsic+json": {
- "type": "string"
- }
- },
- "metadata": {},
- "execution_count": 8
- }
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vHNiZAUxNgoy"
+ },
+ "source": [
+ "## Run the diffusion process"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jZ1KZrxKqENg"
+ },
+ "source": [
+ "#### Helper Functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "s240tYueqKKf"
+ },
+ "outputs": [],
+ "source": [
+ "import copy\n",
+ "import os\n",
+ "\n",
+ "from torch_geometric.data import Batch, Data\n",
+ "from torch_scatter import scatter_mean\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "\n",
+ "def repeat_data(data: Data, num_repeat) -> Batch:\n",
+ " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n",
+ " return Batch.from_data_list(datas)\n",
+ "\n",
+ "\n",
+ "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n",
+ " datas = batch.to_data_list()\n",
+ " new_data = []\n",
+ " for i in range(num_repeat):\n",
+ " new_data += copy.deepcopy(datas)\n",
+ " return Batch.from_data_list(new_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AMnQTk0eqT7Z"
+ },
+ "source": [
+ "#### Constants"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WYGkzqgzrHmF"
+ },
+ "outputs": [],
+ "source": [
+ "num_samples = 1 # solutions per molecule\n",
+ "num_molecules = 3\n",
+ "\n",
+ "DEVICE = \"cuda\"\n",
+ "sampling_type = \"ddpm_noisy\" #'' # paper also uses \"generalize\" and \"ld\"\n",
+ "# constants for inference\n",
+ "w_global = 0.5 # 0,.3 for qm9\n",
+ "global_start_sigma = 0.5\n",
+ "eta = 1.0\n",
+ "clip_local = None\n",
+ "clip_pos = None\n",
+ "\n",
+ "# constands for data handling\n",
+ "save_traj = False\n",
+ "save_data = False\n",
+ "output_dir = \"/content/\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-xD5bJ3SqM7t"
+ },
+ "source": [
+ "#### Generate samples!\n",
+ "Note that the 3d representation of a molecule is referred to as the **conformation**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "x9xuLUNg26z1",
+ "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " after removing the cwd from sys.path.\n",
+ "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pickle\n",
+ "\n",
+ "\n",
+ "results = []\n",
+ "\n",
+ "# define sigmas\n",
+ "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n",
+ "sigmas = sigmas.to(DEVICE)\n",
+ "\n",
+ "for count, data in enumerate(tqdm(dataset)):\n",
+ " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n",
+ "\n",
+ " data_input = data.clone()\n",
+ " data_input[\"pos_ref\"] = None\n",
+ " batch = repeat_data(data_input, num_samples).to(DEVICE)\n",
+ "\n",
+ " # initial configuration\n",
+ " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n",
+ "\n",
+ " # for logging animation of denoising\n",
+ " pos_traj = []\n",
+ " with torch.no_grad():\n",
+ " # scale initial sample\n",
+ " pos = pos_init * sigmas[-1]\n",
+ " for t in scheduler.timesteps:\n",
+ " batch.pos = pos\n",
+ "\n",
+ " # generate geometry with model, then filter it\n",
+ " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n",
+ "\n",
+ " # Update\n",
+ " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n",
+ "\n",
+ " pos = reconstructed_pos\n",
+ "\n",
+ " if torch.isnan(pos).any():\n",
+ " print(\"NaN detected. Please restart.\")\n",
+ " raise FloatingPointError()\n",
+ "\n",
+ " # recenter graph of positions for next iteration\n",
+ " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n",
+ "\n",
+ " # optional clipping\n",
+ " if clip_pos is not None:\n",
+ " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n",
+ " pos_traj.append(pos.clone().cpu())\n",
+ "\n",
+ " pos_gen = pos.cpu()\n",
+ " if save_traj:\n",
+ " pos_gen_traj = pos_traj.cpu()\n",
+ " data.pos_gen = torch.stack(pos_gen_traj)\n",
+ " else:\n",
+ " data.pos_gen = pos_gen\n",
+ " results.append(data)\n",
+ "\n",
+ "\n",
+ "if save_data:\n",
+ " save_path = os.path.join(output_dir, \"samples_all.pkl\")\n",
+ "\n",
+ " with open(save_path, \"wb\") as f:\n",
+ " pickle.dump(results, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fSApwSaZNndW"
+ },
+ "source": [
+ "## Render the results!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "d47Zxo2OKdgZ"
+ },
+ "source": [
+ "This function allows us to render 3d in colab."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "e9Cd0kCAv9b8"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import output\n",
+ "\n",
+ "\n",
+ "output.enable_custom_widget_manager()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RjaVuR15NqzF"
+ },
+ "source": [
+ "### Helper functions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "28rBYa9NKhlz"
+ },
+ "source": [
+ "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LKdKdwxcyTQ6"
+ },
+ "outputs": [],
+ "source": [
+ "from copy import deepcopy\n",
+ "\n",
+ "\n",
+ "def set_rdmol_positions(rdkit_mol, pos):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
+ " pos: (N_atoms, 3)\n",
+ " \"\"\"\n",
+ " mol = deepcopy(rdkit_mol)\n",
+ " set_rdmol_positions_(mol, pos)\n",
+ " return mol\n",
+ "\n",
+ "\n",
+ "def set_rdmol_positions_(mol, pos):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
+ " pos: (N_atoms, 3)\n",
+ " \"\"\"\n",
+ " for i in range(pos.shape[0]):\n",
+ " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n",
+ " return mol"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NuE10hcpKmzK"
+ },
+ "source": [
+ "Process the generated data to make it easy to view."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "KieVE1vc0_Vs",
+ "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "collect 5 generated molecules in `mols`\n"
+ ]
+ }
+ ],
+ "source": [
+ "# the model can generate multiple conformations per 2d geometry\n",
+ "num_gen = results[0][\"pos_gen\"].shape[0]\n",
+ "\n",
+ "# init storage objects\n",
+ "mols_gen = []\n",
+ "mols_orig = []\n",
+ "for to_process in results:\n",
+ " # store the reference 3d position\n",
+ " to_process[\"pos_ref\"] = to_process[\"pos_ref\"].reshape(-1, to_process[\"rdmol\"].GetNumAtoms(), 3)\n",
+ "\n",
+ " # store the generated 3d position\n",
+ " to_process[\"pos_gen\"] = to_process[\"pos_gen\"].reshape(-1, to_process[\"rdmol\"].GetNumAtoms(), 3)\n",
+ "\n",
+ " # copy data to new object\n",
+ " new_mol = set_rdmol_positions(to_process.rdmol, to_process[\"pos_gen\"][0])\n",
+ "\n",
+ " # append results\n",
+ " mols_gen.append(new_mol)\n",
+ " mols_orig.append(to_process.rdmol)\n",
+ "\n",
+ "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tin89JwMKp4v"
+ },
+ "source": [
+ "Import tools to visualize the 2d chemical diagram of the molecule."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yqV6gllSZn38"
+ },
+ "outputs": [],
+ "source": [
+ "from IPython.display import SVG, display\n",
+ "from rdkit import Chem\n",
+ "from rdkit.Chem.Draw import rdMolDraw2D as MD2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TFNKmGddVoOk"
+ },
+ "source": [
+ "Select molecule to visualize"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KzuwLlrrVaGc"
+ },
+ "outputs": [],
+ "source": [
+ "idx = 0\n",
+ "assert idx < len(results), \"selected molecule that was not generated\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hkb8w0_SNtU8"
+ },
+ "source": [
+ "### Viewing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I3R4QBQeKttN"
+ },
+ "source": [
+ "This 2D rendering is the equivalent of the **input to the model**!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 321
+ },
+ "id": "gkQRWjraaKex",
+ "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47"
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ ""
],
- "source": [
- "import torch\n",
- "print(torch.cuda.is_available())\n",
- "torch.__version__"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "KLE7CqlfJNUO"
- },
- "source": [
- "### Install Chemistry-specific Dependencies\n",
- "\n",
- "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)."
+ "text/plain": [
+ ""
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "0CPv_NvehRz3",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc"
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "mc = Chem.MolFromSmiles(dataset[0][\"smiles\"])\n",
+ "molSize = (450, 300)\n",
+ "drawer = MD2.MolDraw2DSVG(molSize[0], molSize[1])\n",
+ "drawer.DrawMolecule(mc)\n",
+ "drawer.FinishDrawing()\n",
+ "svg = drawer.GetDrawingText()\n",
+ "display(SVG(svg.replace(\"svg:\", \"\")))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z4FDMYMxKw2I"
+ },
+ "source": [
+ "Generate the 3d molecule!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17,
+ "referenced_widgets": [
+ "695ab5bbf30a4ab19df1f9f33469f314",
+ "eac6a8dcdc9d4335a2e51031793ead29"
+ ]
+ },
+ "id": "aT1Bkb8YxJfV",
+ "outputId": "b98870ae-049d-4386-b676-166e9526bda2"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "695ab5bbf30a4ab19df1f9f33469f314",
+ "version_major": 2,
+ "version_minor": 0
},
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
- "Collecting rdkit\n",
- " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n",
- "Installing collected packages: rdkit\n",
- "Successfully installed rdkit-2022.3.5\n",
- "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
- "\u001b[0m"
- ]
+ "text/plain": []
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
}
- ],
- "source": [
- "!pip install rdkit"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "88GaDbDPxJ5I"
+ }
+ }
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from nglview import show_rdkit as show"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 337,
+ "referenced_widgets": [
+ "be446195da2b4ff2aec21ec5ff963a54",
+ "c6596896148b4a8a9c57963b67c7782f",
+ "2489b5e5648541fbbdceadb05632a050",
+ "01e0ba4e5da04914b4652b8d58565d7b",
+ "c30e6c2f3e2a44dbbb3d63bd519acaa4",
+ "f31c6e40e9b2466a9064a2669933ecd5",
+ "19308ccac642498ab8b58462e3f1b0bb",
+ "4a081cdc2ec3421ca79dd933b7e2b0c4",
+ "e5c0d75eb5e1447abd560c8f2c6017e1",
+ "5146907ef6764654ad7d598baebc8b58",
+ "144ec959b7604a2cabb5ca46ae5e5379",
+ "abce2a80e6304df3899109c6d6cac199",
+ "65195cb7a4134f4887e9dd19f3676462"
+ ]
+ },
+ "id": "pxtq8I-I18C-",
+ "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "be446195da2b4ff2aec21ec5ff963a54",
+ "version_major": 2,
+ "version_minor": 0
},
- "source": [
- "### Get viewer from nglview\n",
- "\n",
- "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n",
- "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n",
- "The rdmol in this object is a source of ground truth for the generated molecules.\n",
- "\n",
- "You will use one rendering function from nglviewer later!\n",
- "\n"
+ "text/plain": [
+ "NGLWidget()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "jcl8GCS2mz6t",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 1000
- },
- "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
- "Collecting nglview\n",
- " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
- " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
- " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n",
- "Collecting jupyterlab-widgets\n",
- " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting ipywidgets>=7\n",
- " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting widgetsnbextension~=4.0\n",
- " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting ipython>=6.1.0\n",
- " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting ipykernel>=4.5.1\n",
- " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting traitlets>=4.3.1\n",
- " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n",
- "Collecting pyzmq>=17\n",
- " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting matplotlib-inline>=0.1\n",
- " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n",
- "Collecting tornado>=6.1\n",
- " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting nest-asyncio\n",
- " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n",
- "Collecting debugpy>=1.0\n",
- " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting psutil\n",
- " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting jupyter-client>=6.1.12\n",
- " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting pickleshare\n",
- " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n",
- "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n",
- "Collecting backcall\n",
- " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n",
- "Collecting pexpect>4.3\n",
- " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting pygments\n",
- " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting jedi>=0.16\n",
- " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n",
- " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n",
- "Collecting parso<0.9.0,>=0.8.0\n",
- " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n",
- "Collecting entrypoints\n",
- " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n",
- "Collecting jupyter-core>=4.9.2\n",
- " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting ptyprocess>=0.5\n",
- " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n",
- "Collecting wcwidth\n",
- " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n",
- "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n",
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n",
- "Building wheels for collected packages: nglview\n",
- " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n",
- " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n",
- "Successfully built nglview\n",
- "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n",
- "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n",
- "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
- "\u001b[0m"
- ]
- },
- {
- "output_type": "display_data",
- "data": {
- "application/vnd.colab-display-data+json": {
- "pip_warning": {
- "packages": [
- "pexpect",
- "pickleshare",
- "wcwidth"
- ]
- }
- }
- },
- "metadata": {}
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
}
- ],
- "source": [
- "!pip install nglview"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Create a diffusion model"
- ],
- "metadata": {
- "id": "8t8_e_uVLdKB"
+ }
}
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Model class(es)"
- ],
- "metadata": {
- "id": "G0rMncVtNSqU"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "Imports"
- ],
- "metadata": {
- "id": "L5FEXz5oXkzt"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n",
- "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n",
- "from dataclasses import dataclass\n",
- "from typing import Callable, Tuple, Union\n",
- "\n",
- "import numpy as np\n",
- "import torch\n",
- "import torch.nn.functional as F\n",
- "from torch import Tensor, nn\n",
- "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n",
- "\n",
- "from torch_geometric.nn import MessagePassing, radius, radius_graph\n",
- "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n",
- "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n",
- "from torch_scatter import scatter_add\n",
- "from torch_sparse import SparseTensor, coalesce\n",
- "\n",
- "from diffusers.configuration_utils import ConfigMixin, register_to_config\n",
- "from diffusers.modeling_utils import ModelMixin\n",
- "from diffusers.utils import BaseOutput\n"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# new molecule\n",
+ "show(mols_gen[idx])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KJr4h2mwXeTo"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": []
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "01e0ba4e5da04914b4652b8d58565d7b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1",
+ "IPY_MODEL_5146907ef6764654ad7d598baebc8b58"
],
- "metadata": {
- "id": "-3-P4w5sXkRU"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "Helper classes"
+ "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379"
+ }
+ },
+ "144ec959b7604a2cabb5ca46ae5e5379": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "19308ccac642498ab8b58462e3f1b0bb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1c6246f15b654f4daa11c9bcf997b78c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df",
+ "placeholder": "",
+ "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c",
+ "value": "Downloading: 100%"
+ }
+ },
+ "2489b5e5648541fbbdceadb05632a050": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ButtonModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ButtonModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ButtonView",
+ "button_style": "",
+ "description": "",
+ "disabled": false,
+ "icon": "compress",
+ "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199",
+ "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462",
+ "tooltip": ""
+ }
+ },
+ "24d31fc3576e43dd9f8301d2ef3a37ab": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2918bfaadc8d4b1a9832522c40dfefb8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2c9362906e4b40189f16d14aa9a348da": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "42f6c3db29d7484ba6b4f73590abd2f4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "4a081cdc2ec3421ca79dd933b7e2b0c4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "SliderStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "SliderStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": "",
+ "handle_color": null
+ }
+ },
+ "5146907ef6764654ad7d598baebc8b58": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "IntSliderModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "IntSliderModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "IntSliderView",
+ "continuous_update": true,
+ "description": "",
+ "description_tooltip": null,
+ "disabled": false,
+ "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb",
+ "max": 0,
+ "min": 0,
+ "orientation": "horizontal",
+ "readout": true,
+ "readout_format": "d",
+ "step": 1,
+ "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4",
+ "value": 0
+ }
+ },
+ "561f742d418d4721b0670cc8dd62e22c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "6010fc8daa7a44d5aec4b830ec2ebaa1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2",
+ "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
+ "IPY_MODEL_6526646be5ed415c84d1245b040e629b"
],
- "metadata": {
- "id": "EzJQXPN_XrMX"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "@dataclass\n",
- "class MoleculeGNNOutput(BaseOutput):\n",
- " \"\"\"\n",
- " Args:\n",
- " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n",
- " Hidden states output. Output of last layer of model.\n",
- " \"\"\"\n",
- "\n",
- " sample: torch.Tensor\n",
- "\n",
- "\n",
- "class MultiLayerPerceptron(nn.Module):\n",
- " \"\"\"\n",
- " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n",
- " Args:\n",
- " input_dim (int): input dimension\n",
- " hidden_dim (list of int): hidden dimensions\n",
- " activation (str or function, optional): activation function\n",
- " dropout (float, optional): dropout rate\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n",
- " super(MultiLayerPerceptron, self).__init__()\n",
- "\n",
- " self.dims = [input_dim] + hidden_dims\n",
- " if isinstance(activation, str):\n",
- " self.activation = getattr(F, activation)\n",
- " else:\n",
- " print(f\"Warning, activation passed {activation} is not string and ignored\")\n",
- " self.activation = None\n",
- " if dropout > 0:\n",
- " self.dropout = nn.Dropout(dropout)\n",
- " else:\n",
- " self.dropout = None\n",
- "\n",
- " self.layers = nn.ModuleList()\n",
- " for i in range(len(self.dims) - 1):\n",
- " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n",
- "\n",
- " def forward(self, x):\n",
- " \"\"\"\"\"\"\n",
- " for i, layer in enumerate(self.layers):\n",
- " x = layer(x)\n",
- " if i < len(self.layers) - 1:\n",
- " if self.activation:\n",
- " x = self.activation(x)\n",
- " if self.dropout:\n",
- " x = self.dropout(x)\n",
- " return x\n",
- "\n",
- "\n",
- "class ShiftedSoftplus(torch.nn.Module):\n",
- " def __init__(self):\n",
- " super(ShiftedSoftplus, self).__init__()\n",
- " self.shift = torch.log(torch.tensor(2.0)).item()\n",
- "\n",
- " def forward(self, x):\n",
- " return F.softplus(x) - self.shift\n",
- "\n",
- "\n",
- "class CFConv(MessagePassing):\n",
- " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n",
- " super(CFConv, self).__init__(aggr=\"add\")\n",
- " self.lin1 = Linear(in_channels, num_filters, bias=False)\n",
- " self.lin2 = Linear(num_filters, out_channels)\n",
- " self.nn = mlp\n",
- " self.cutoff = cutoff\n",
- " self.smooth = smooth\n",
- "\n",
- " self.reset_parameters()\n",
- "\n",
- " def reset_parameters(self):\n",
- " torch.nn.init.xavier_uniform_(self.lin1.weight)\n",
- " torch.nn.init.xavier_uniform_(self.lin2.weight)\n",
- " self.lin2.bias.data.fill_(0)\n",
- "\n",
- " def forward(self, x, edge_index, edge_length, edge_attr):\n",
- " if self.smooth:\n",
- " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n",
- " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n",
- " else:\n",
- " C = (edge_length <= self.cutoff).float()\n",
- " W = self.nn(edge_attr) * C.view(-1, 1)\n",
- "\n",
- " x = self.lin1(x)\n",
- " x = self.propagate(edge_index, x=x, W=W)\n",
- " x = self.lin2(x)\n",
- " return x\n",
- "\n",
- " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n",
- " return x_j * W\n",
- "\n",
- "\n",
- "class InteractionBlock(torch.nn.Module):\n",
- " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n",
- " super(InteractionBlock, self).__init__()\n",
- " mlp = Sequential(\n",
- " Linear(num_gaussians, num_filters),\n",
- " ShiftedSoftplus(),\n",
- " Linear(num_filters, num_filters),\n",
- " )\n",
- " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n",
- " self.act = ShiftedSoftplus()\n",
- " self.lin = Linear(hidden_channels, hidden_channels)\n",
- "\n",
- " def forward(self, x, edge_index, edge_length, edge_attr):\n",
- " x = self.conv(x, edge_index, edge_length, edge_attr)\n",
- " x = self.act(x)\n",
- " x = self.lin(x)\n",
- " return x\n",
- "\n",
- "\n",
- "class SchNetEncoder(Module):\n",
- " def __init__(\n",
- " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n",
- " ):\n",
- " super().__init__()\n",
- "\n",
- " self.hidden_channels = hidden_channels\n",
- " self.num_filters = num_filters\n",
- " self.num_interactions = num_interactions\n",
- " self.cutoff = cutoff\n",
- "\n",
- " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n",
- "\n",
- " self.interactions = ModuleList()\n",
- " for _ in range(num_interactions):\n",
- " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n",
- " self.interactions.append(block)\n",
- "\n",
- " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n",
- " if embed_node:\n",
- " assert z.dim() == 1 and z.dtype == torch.long\n",
- " h = self.embedding(z)\n",
- " else:\n",
- " h = z\n",
- " for interaction in self.interactions:\n",
- " h = h + interaction(h, edge_index, edge_length, edge_attr)\n",
- "\n",
- " return h\n",
- "\n",
- "\n",
- "class GINEConv(MessagePassing):\n",
- " \"\"\"\n",
- " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n",
- " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n",
- " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n",
- " self.nn = mlp\n",
- " self.initial_eps = eps\n",
- "\n",
- " if isinstance(activation, str):\n",
- " self.activation = getattr(F, activation)\n",
- " else:\n",
- " self.activation = None\n",
- "\n",
- " if train_eps:\n",
- " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n",
- " else:\n",
- " self.register_buffer(\"eps\", torch.Tensor([eps]))\n",
- "\n",
- " def forward(\n",
- " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n",
- " ) -> torch.Tensor:\n",
- " \"\"\"\"\"\"\n",
- " if isinstance(x, torch.Tensor):\n",
- " x: OptPairTensor = (x, x)\n",
- "\n",
- " # Node and edge feature dimensionalites need to match.\n",
- " if isinstance(edge_index, torch.Tensor):\n",
- " assert edge_attr is not None\n",
- " assert x[0].size(-1) == edge_attr.size(-1)\n",
- " elif isinstance(edge_index, SparseTensor):\n",
- " assert x[0].size(-1) == edge_index.size(-1)\n",
- "\n",
- " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n",
- " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n",
- "\n",
- " x_r = x[1]\n",
- " if x_r is not None:\n",
- " out += (1 + self.eps) * x_r\n",
- "\n",
- " return self.nn(out)\n",
- "\n",
- " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n",
- " if self.activation:\n",
- " return self.activation(x_j + edge_attr)\n",
- " else:\n",
- " return x_j + edge_attr\n",
- "\n",
- " def __repr__(self):\n",
- " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n",
- "\n",
- "\n",
- "class GINEncoder(torch.nn.Module):\n",
- " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n",
- " super().__init__()\n",
- "\n",
- " self.hidden_dim = hidden_dim\n",
- " self.num_convs = num_convs\n",
- " self.short_cut = short_cut\n",
- " self.concat_hidden = concat_hidden\n",
- " self.node_emb = nn.Embedding(100, hidden_dim)\n",
- "\n",
- " if isinstance(activation, str):\n",
- " self.activation = getattr(F, activation)\n",
- " else:\n",
- " self.activation = None\n",
- "\n",
- " self.convs = nn.ModuleList()\n",
- " for i in range(self.num_convs):\n",
- " self.convs.append(\n",
- " GINEConv(\n",
- " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n",
- " activation=activation,\n",
- " )\n",
- " )\n",
- "\n",
- " def forward(self, z, edge_index, edge_attr):\n",
- " \"\"\"\n",
- " Input:\n",
- " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n",
- " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n",
- " Output:\n",
- " node_feature: graph feature\n",
- " \"\"\"\n",
- "\n",
- " node_attr = self.node_emb(z) # (num_node, hidden)\n",
- "\n",
- " hiddens = []\n",
- " conv_input = node_attr # (num_node, hidden)\n",
- "\n",
- " for conv_idx, conv in enumerate(self.convs):\n",
- " hidden = conv(conv_input, edge_index, edge_attr)\n",
- " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n",
- " hidden = self.activation(hidden)\n",
- " assert hidden.shape == conv_input.shape\n",
- " if self.short_cut and hidden.shape == conv_input.shape:\n",
- " hidden += conv_input\n",
- "\n",
- " hiddens.append(hidden)\n",
- " conv_input = hidden\n",
- "\n",
- " if self.concat_hidden:\n",
- " node_feature = torch.cat(hiddens, dim=-1)\n",
- " else:\n",
- " node_feature = hiddens[-1]\n",
- "\n",
- " return node_feature\n",
- "\n",
- "\n",
- "class MLPEdgeEncoder(Module):\n",
- " def __init__(self, hidden_dim=100, activation=\"relu\"):\n",
- " super().__init__()\n",
- " self.hidden_dim = hidden_dim\n",
- " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n",
- " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n",
- "\n",
- " @property\n",
- " def out_channels(self):\n",
- " return self.hidden_dim\n",
- "\n",
- " def forward(self, edge_length, edge_type):\n",
- " \"\"\"\n",
- " Input:\n",
- " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n",
- " Returns:\n",
- " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n",
- " \"\"\"\n",
- " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n",
- " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n",
- " return d_emb * edge_attr # (num_edge, hidden)\n",
- "\n",
- "\n",
- "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n",
- " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n",
- " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n",
- " return h_pair\n",
- "\n",
- "\n",
- "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n",
- " \"\"\"\n",
- " Args:\n",
- " num_nodes: Number of atoms.\n",
- " edge_index: Bond indices of the original graph.\n",
- " edge_type: Bond types of the original graph.\n",
- " order: Extension order.\n",
- " Returns:\n",
- " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n",
- " \"\"\"\n",
- "\n",
- " def binarize(x):\n",
- " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n",
- "\n",
- " def get_higher_order_adj_matrix(adj, order):\n",
- " \"\"\"\n",
- "        Args:\n",
- "            adj: Binary adjacency matrix, (N, N).\n",
- "            order: Maximum neighbor order to include.\n",
- "        Returns:\n",
- "            order_mat: (N, N) matrix whose entry (i, j) is the number of bonds on the shortest path between nodes i\n",
- "                and j if that path is at most `order` bonds long, and 0 otherwise.\n",
- " \"\"\"\n",
- " adj_mats = [\n",
- " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n",
- " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n",
- " ]\n",
- "\n",
- " for i in range(2, order + 1):\n",
- " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n",
- " order_mat = torch.zeros_like(adj)\n",
- "\n",
- " for i in range(1, order + 1):\n",
- " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n",
- "\n",
- " return order_mat\n",
- "\n",
- " num_types = 22\n",
- " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n",
- " # from rdkit.Chem.rdchem import BondType as BT\n",
- " N = num_nodes\n",
- " adj = to_dense_adj(edge_index).squeeze(0)\n",
- " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n",
- "\n",
- " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n",
- " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n",
- " assert (type_mat * type_highorder == 0).all()\n",
- " type_new = type_mat + type_highorder\n",
- "\n",
- " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n",
- " _, edge_order = dense_to_sparse(adj_order)\n",
- "\n",
- " # data.bond_edge_index = data.edge_index # Save original edges\n",
- " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n",
- "\n",
- " return new_edge_index, new_edge_type\n",
- "\n",
- "\n",
- "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n",
- " assert edge_type.dim() == 1\n",
- " N = pos.size(0)\n",
- "\n",
- " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n",
- "\n",
- " if is_sidechain is None:\n",
- " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n",
- " else:\n",
- " # fetch sidechain and its batch index\n",
- " is_sidechain = is_sidechain.bool()\n",
- " dummy_index = torch.arange(pos.size(0), device=pos.device)\n",
- " sidechain_pos = pos[is_sidechain]\n",
- " sidechain_index = dummy_index[is_sidechain]\n",
- " sidechain_batch = batch[is_sidechain]\n",
- "\n",
- " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n",
- " r_edge_index_x = assign_index[1]\n",
- " r_edge_index_y = assign_index[0]\n",
- " r_edge_index_y = sidechain_index[r_edge_index_y]\n",
- "\n",
- " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n",
- " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n",
- " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n",
- " # delete self loop\n",
- " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n",
- "\n",
- " rgraph_adj = torch.sparse.LongTensor(\n",
- " rgraph_edge_index,\n",
- " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n",
- " torch.Size([N, N]),\n",
- " )\n",
- "\n",
- " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n",
- "\n",
- " new_edge_index = composed_adj.indices()\n",
- " new_edge_type = composed_adj.values().long()\n",
- "\n",
- " return new_edge_index, new_edge_type\n",
- "\n",
- "\n",
- "def extend_graph_order_radius(\n",
- " num_nodes,\n",
- " pos,\n",
- " edge_index,\n",
- " edge_type,\n",
- " batch,\n",
- " order=3,\n",
- " cutoff=10.0,\n",
- " extend_order=True,\n",
- " extend_radius=True,\n",
- " is_sidechain=None,\n",
- "):\n",
- " if extend_order:\n",
- " edge_index, edge_type = _extend_graph_order(\n",
- " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n",
- " )\n",
- "\n",
- " if extend_radius:\n",
- " edge_index, edge_type = _extend_to_radius_graph(\n",
- " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n",
- " )\n",
- "\n",
- " return edge_index, edge_type\n",
- "\n",
- "\n",
- "def get_distance(pos, edge_index):\n",
- " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n",
- "\n",
- "\n",
- "def graph_field_network(score_d, pos, edge_index, edge_length):\n",
- " \"\"\"\n",
- " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n",
- " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n",
- " \"\"\"\n",
- " N = pos.size(0)\n",
- " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n",
- " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n",
- " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n",
- " ) # (N, 3)\n",
- " return score_pos\n",
- "\n",
- "\n",
- "def clip_norm(vec, limit, p=2):\n",
- "    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)\n",
- " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n",
- " return vec * denom\n",
- "\n",
- "\n",
- "def is_local_edge(edge_type):\n",
- " return edge_type > 0\n"
+ "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab"
+ }
+ },
+ "65195cb7a4134f4887e9dd19f3676462": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ButtonStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ButtonStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "button_color": null,
+ "font_weight": ""
+ }
+ },
+ "6526646be5ed415c84d1245b040e629b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433",
+ "placeholder": "",
+ "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4",
+ "value": " 401/401 [00:00<00:00, 13.5kB/s]"
+ }
+ },
+ "695ab5bbf30a4ab19df1f9f33469f314": {
+ "model_module": "nglview-js-widgets",
+ "model_module_version": "3.0.1",
+ "model_name": "ColormakerRegistryModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "nglview-js-widgets",
+ "_model_module_version": "3.0.1",
+ "_model_name": "ColormakerRegistryModel",
+ "_msg_ar": [],
+ "_msg_q": [],
+ "_ready": false,
+ "_view_count": null,
+ "_view_module": "nglview-js-widgets",
+ "_view_module_version": "3.0.1",
+ "_view_name": "ColormakerRegistryView",
+ "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29"
+ }
+ },
+ "7e0bb1b8d65249d3974200686b193be2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8",
+ "placeholder": "",
+ "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08",
+ "value": "Downloading: 100%"
+ }
+ },
+ "872915dd1bb84f538c44e26badabafdd": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a4bfdca35cc54dae8812720f1b276a08": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "a9e388f22a9742aaaf538e22575c9433": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "abce2a80e6304df3899109c6d6cac199": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": "34px"
+ }
+ },
+ "b7feb522161f4cf4b7cc7c1a078ff12d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436",
+ "placeholder": "",
+ "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da",
+ "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]"
+ }
+ },
+ "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39",
+ "max": 401,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_f915cf874246446595206221e900b2fe",
+ "value": 401
+ }
+ },
+ "bbef741e76ec41b7ab7187b487a383df": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "be446195da2b4ff2aec21ec5ff963a54": {
+ "model_module": "nglview-js-widgets",
+ "model_module_version": "3.0.1",
+ "model_name": "NGLModel",
+ "state": {
+ "_camera_orientation": [
+ -15.519693580202304,
+ -14.065056548036177,
+ -23.53197484807691,
+ 0,
+ -23.357853515109753,
+ 20.94055073042662,
+ 2.888695042134944,
+ 0,
+ 14.352363398292775,
+ 18.870825741878015,
+ -20.744689572909344,
+ 0,
+ 0.2724999189376831,
+ 0.6940000057220459,
+ -0.3734999895095825,
+ 1
],
- "metadata": {
- "id": "oR1Y56QiLY90"
+ "_camera_str": "orthographic",
+ "_dom_classes": [],
+ "_gui_theme": null,
+ "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050",
+ "_igui": null,
+ "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b",
+ "_model_module": "nglview-js-widgets",
+ "_model_module_version": "3.0.1",
+ "_model_name": "NGLModel",
+ "_ngl_color_dict": {},
+ "_ngl_coordinate_resource": {},
+ "_ngl_full_stage_parameters": {
+ "ambientColor": 14540253,
+ "ambientIntensity": 0.2,
+ "backgroundColor": "white",
+ "cameraEyeSep": 0.3,
+ "cameraFov": 40,
+ "cameraType": "perspective",
+ "clipDist": 10,
+ "clipFar": 100,
+ "clipNear": 0,
+ "fogFar": 100,
+ "fogNear": 50,
+ "hoverTimeout": 0,
+ "impostor": true,
+ "lightColor": 14540253,
+ "lightIntensity": 1,
+ "mousePreset": "default",
+ "panSpeed": 1,
+ "quality": "medium",
+ "rotateSpeed": 2,
+ "sampleLevel": 0,
+ "tooltip": true,
+ "workerDefault": true,
+ "zoomSpeed": 1.2
},
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "Main model class!"
- ],
- "metadata": {
- "id": "QWrHJFcYXyUB"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "class MoleculeGNN(ModelMixin, ConfigMixin):\n",
- " @register_to_config\n",
- " def __init__(\n",
- " self,\n",
- " hidden_dim=128,\n",
- " num_convs=6,\n",
- " num_convs_local=4,\n",
- " cutoff=10.0,\n",
- " mlp_act=\"relu\",\n",
- " edge_order=3,\n",
- " edge_encoder=\"mlp\",\n",
- " smooth_conv=True,\n",
- " ):\n",
- " super().__init__()\n",
- " self.cutoff = cutoff\n",
- " self.edge_encoder = edge_encoder\n",
- " self.edge_order = edge_order\n",
- "\n",
- " \"\"\"\n",
- "        edge_encoder: Takes both edge type and edge length as input and outputs a vector. [Note]: node embedding is\n",
- "        done in SchNetEncoder.\n",
- " \"\"\"\n",
- " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n",
- " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n",
- "\n",
- " \"\"\"\n",
- " The graph neural network that extracts node-wise features.\n",
- " \"\"\"\n",
- " self.encoder_global = SchNetEncoder(\n",
- " hidden_channels=hidden_dim,\n",
- " num_filters=hidden_dim,\n",
- " num_interactions=num_convs,\n",
- " edge_channels=self.edge_encoder_global.out_channels,\n",
- " cutoff=cutoff,\n",
- " smooth=smooth_conv,\n",
- " )\n",
- " self.encoder_local = GINEncoder(\n",
- " hidden_dim=hidden_dim,\n",
- " num_convs=num_convs_local,\n",
- " )\n",
- "\n",
- " \"\"\"\n",
- " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n",
- " gradients w.r.t. edge_length (out_dim = 1).\n",
- " \"\"\"\n",
- " self.grad_global_dist_mlp = MultiLayerPerceptron(\n",
- " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
- " )\n",
- "\n",
- " self.grad_local_dist_mlp = MultiLayerPerceptron(\n",
- " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
- " )\n",
- "\n",
- " \"\"\"\n",
- " Incorporate parameters together\n",
- " \"\"\"\n",
- " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n",
- " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n",
- "\n",
- " def _forward(\n",
- " self,\n",
- " atom_type,\n",
- " pos,\n",
- " bond_index,\n",
- " bond_type,\n",
- " batch,\n",
- " time_step, # NOTE, model trained without timestep performed best\n",
- " edge_index=None,\n",
- " edge_type=None,\n",
- " edge_length=None,\n",
- " return_edges=False,\n",
- " extend_order=True,\n",
- " extend_radius=True,\n",
- " is_sidechain=None,\n",
- " ):\n",
- " \"\"\"\n",
- " Args:\n",
- " atom_type: Types of atoms, (N, ).\n",
- " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n",
- " bond_type: Bond types, (E, ).\n",
- " batch: Node index to graph index, (N, ).\n",
- " \"\"\"\n",
- " N = atom_type.size(0)\n",
- " if edge_index is None or edge_type is None or edge_length is None:\n",
- " edge_index, edge_type = extend_graph_order_radius(\n",
- " num_nodes=N,\n",
- " pos=pos,\n",
- " edge_index=bond_index,\n",
- " edge_type=bond_type,\n",
- " batch=batch,\n",
- " order=self.edge_order,\n",
- " cutoff=self.cutoff,\n",
- " extend_order=extend_order,\n",
- " extend_radius=extend_radius,\n",
- " is_sidechain=is_sidechain,\n",
- " )\n",
- " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n",
- " local_edge_mask = is_local_edge(edge_type) # (E, )\n",
- "\n",
- " # with the parameterization of NCSNv2\n",
- " # DDPM loss implicit handle the noise variance scale conditioning\n",
- " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n",
- "\n",
- " # Encoding global\n",
- " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n",
- "\n",
- " # Global\n",
- " node_attr_global = self.encoder_global(\n",
- " z=atom_type,\n",
- " edge_index=edge_index,\n",
- " edge_length=edge_length,\n",
- " edge_attr=edge_attr_global,\n",
- " )\n",
- " # Assemble pairwise features\n",
- " h_pair_global = assemble_atom_pair_feature(\n",
- " node_attr=node_attr_global,\n",
- " edge_index=edge_index,\n",
- " edge_attr=edge_attr_global,\n",
- " ) # (E_global, 2H)\n",
- " # Invariant features of edges (radius graph, global)\n",
- " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n",
- "\n",
- " # Encoding local\n",
- " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n",
- " # edge_attr += temb_edge\n",
- "\n",
- " # Local\n",
- " node_attr_local = self.encoder_local(\n",
- " z=atom_type,\n",
- " edge_index=edge_index[:, local_edge_mask],\n",
- " edge_attr=edge_attr_local[local_edge_mask],\n",
- " )\n",
- " # Assemble pairwise features\n",
- " h_pair_local = assemble_atom_pair_feature(\n",
- " node_attr=node_attr_local,\n",
- " edge_index=edge_index[:, local_edge_mask],\n",
- " edge_attr=edge_attr_local[local_edge_mask],\n",
- " ) # (E_local, 2H)\n",
- "\n",
- " # Invariant features of edges (bond graph, local)\n",
- " if isinstance(sigma_edge, torch.Tensor):\n",
- " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n",
- " 1.0 / sigma_edge[local_edge_mask]\n",
- " ) # (E_local, 1)\n",
- " else:\n",
- " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n",
- "\n",
- " if return_edges:\n",
- " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n",
- " else:\n",
- " return edge_inv_global, edge_inv_local\n",
- "\n",
- " def forward(\n",
- " self,\n",
- " sample,\n",
- " timestep: Union[torch.Tensor, float, int],\n",
- " return_dict: bool = True,\n",
- " sigma=1.0,\n",
- " global_start_sigma=0.5,\n",
- " w_global=1.0,\n",
- " extend_order=False,\n",
- " extend_radius=True,\n",
- " clip_local=None,\n",
- " clip_global=1000.0,\n",
- " ) -> Union[MoleculeGNNOutput, Tuple]:\n",
- " r\"\"\"\n",
- " Args:\n",
- " sample: packed torch geometric object\n",
- "            timestep (`torch.Tensor` or `float` or `int`): TODO verify type and shape (batch) timesteps\n",
- " return_dict (`bool`, *optional*, defaults to `True`):\n",
- " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n",
- " Returns:\n",
- " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n",
- " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n",
- " \"\"\"\n",
- "\n",
- " # unpack sample\n",
- " atom_type = sample.atom_type\n",
- " bond_index = sample.edge_index\n",
- " bond_type = sample.edge_type\n",
- " num_graphs = sample.num_graphs\n",
- " pos = sample.pos\n",
- "\n",
- " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n",
- "\n",
- " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n",
- " atom_type=atom_type,\n",
- " pos=sample.pos,\n",
- " bond_index=bond_index,\n",
- " bond_type=bond_type,\n",
- " batch=sample.batch,\n",
- " time_step=timesteps,\n",
- " return_edges=True,\n",
- " extend_order=extend_order,\n",
- " extend_radius=extend_radius,\n",
- " ) # (E_global, 1), (E_local, 1)\n",
- "\n",
- " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n",
- " node_eq_local = graph_field_network(\n",
- " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n",
- " )\n",
- " if clip_local is not None:\n",
- " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n",
- "\n",
- " # Global\n",
- " if sigma < global_start_sigma:\n",
- " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n",
- " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n",
- " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n",
- " else:\n",
- " node_eq_global = 0\n",
- "\n",
- " # Sum\n",
- " eps_pos = node_eq_local + node_eq_global * w_global\n",
- "\n",
- " if not return_dict:\n",
- " return (-eps_pos,)\n",
- "\n",
- " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
- ],
- "metadata": {
- "id": "MCeZA1qQXzoK"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "CCIrPYSJj9wd"
- },
- "source": [
- "### Load pretrained model"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "YdrAr6Ch--Ab"
- },
- "source": [
- "#### Load a model\n",
- "The model used here is built around an\n",
- "equivariant convolutional layer, named graph field network (GFN).\n",
- "\n",
- "The warning about `betas` and `alphas` can be ignored; those parameters were moved to the scheduler."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "DyCo0nsqjbml",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 172,
- "referenced_widgets": [
- "d90f304e9560472eacfbdd11e46765eb",
- "1c6246f15b654f4daa11c9bcf997b78c",
- "c2321b3bff6f490ca12040a20308f555",
- "b7feb522161f4cf4b7cc7c1a078ff12d",
- "e2d368556e494ae7ae4e2e992af2cd4f",
- "bbef741e76ec41b7ab7187b487a383df",
- "561f742d418d4721b0670cc8dd62e22c",
- "872915dd1bb84f538c44e26badabafdd",
- "d022575f1fa2446d891650897f187b4d",
- "fdc393f3468c432aa0ada05e238a5436",
- "2c9362906e4b40189f16d14aa9a348da",
- "6010fc8daa7a44d5aec4b830ec2ebaa1",
- "7e0bb1b8d65249d3974200686b193be2",
- "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
- "6526646be5ed415c84d1245b040e629b",
- "24d31fc3576e43dd9f8301d2ef3a37ab",
- "2918bfaadc8d4b1a9832522c40dfefb8",
- "a4bfdca35cc54dae8812720f1b276a08",
- "e4901541199b45c6a18824627692fc39",
- "f915cf874246446595206221e900b2fe",
- "a9e388f22a9742aaaf538e22575c9433",
- "42f6c3db29d7484ba6b4f73590abd2f4"
- ]
+ "_ngl_msg_archive": [
+ {
+ "args": [
+ {
+ "binary": false,
+ "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n",
+ "type": "blob"
+ }
+ ],
+ "kwargs": {
+ "defaultRepresentation": true,
+ "ext": "pdb"
},
- "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45"
+ "methodName": "loadFile",
+ "reconstruc_color_scheme": false,
+ "target": "Stage",
+ "type": "call_method"
+ }
+ ],
+ "_ngl_original_stage_parameters": {
+ "ambientColor": 14540253,
+ "ambientIntensity": 0.2,
+ "backgroundColor": "white",
+ "cameraEyeSep": 0.3,
+ "cameraFov": 40,
+ "cameraType": "perspective",
+ "clipDist": 10,
+ "clipFar": 100,
+ "clipNear": 0,
+ "fogFar": 100,
+ "fogNear": 50,
+ "hoverTimeout": 0,
+ "impostor": true,
+ "lightColor": 14540253,
+ "lightIntensity": 1,
+ "mousePreset": "default",
+ "panSpeed": 1,
+ "quality": "medium",
+ "rotateSpeed": 2,
+ "sampleLevel": 0,
+ "tooltip": true,
+ "workerDefault": true,
+ "zoomSpeed": 1.2
},
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "Downloading: 0%| | 0.00/3.27M [00:00, ?B/s]"
- ],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "d90f304e9560472eacfbdd11e46765eb"
- }
+ "_ngl_repr_dict": {
+ "0": {
+ "0": {
+ "params": {
+ "aspectRatio": 1.5,
+ "assembly": "default",
+ "bondScale": 0.3,
+ "bondSpacing": 0.75,
+ "clipCenter": {
+ "x": 0,
+ "y": 0,
+ "z": 0
},
- "metadata": {}
- },
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "Downloading: 0%| | 0.00/401 [00:00, ?B/s]"
- ],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "6010fc8daa7a44d5aec4b830ec2ebaa1"
- }
+ "clipNear": 0,
+ "clipRadius": 0,
+ "colorMode": "hcl",
+ "colorReverse": false,
+ "colorScale": "",
+ "colorScheme": "element",
+ "colorValue": 9474192,
+ "cylinderOnly": false,
+ "defaultAssembly": "",
+ "depthWrite": true,
+ "diffuse": 16777215,
+ "diffuseInterior": false,
+ "disableImpostor": false,
+ "disablePicking": false,
+ "flatShaded": false,
+ "interiorColor": 2236962,
+ "interiorDarkening": 0,
+ "lazy": false,
+ "lineOnly": false,
+ "linewidth": 2,
+ "matrix": {
+ "elements": [
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1
+ ]
},
- "metadata": {}
- },
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "The config attributes {'type': 'diffusion', 'network': 'dualenc', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'num_diffusion_timesteps': 5000} were passed to MoleculeGNN, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
- "Some weights of the model checkpoint at fusing/gfn-molecule-gen-drugs were not used when initializing MoleculeGNN: ['betas', 'alphas']\n",
- "- This IS expected if you are initializing MoleculeGNN from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
- "- This IS NOT expected if you are initializing MoleculeGNN from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
- ]
+ "metalness": 0,
+ "multipleBond": "off",
+ "opacity": 1,
+ "openEnded": true,
+ "quality": "high",
+ "radialSegments": 20,
+ "radiusData": {},
+ "radiusScale": 2,
+ "radiusSize": 0.15,
+ "radiusType": "size",
+ "roughness": 0.4,
+ "sele": "",
+ "side": "double",
+ "sphereDetail": 2,
+ "useInteriorColor": true,
+ "visible": true,
+ "wireframe": false
+ },
+ "type": "ball+stick"
}
- ],
- "source": [
- "DEVICE = 'cuda'\n",
- "model = MoleculeGNN.from_pretrained(\"fusing/gfn-molecule-gen-drugs\").to(DEVICE)"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "The warnings above are because the pre-trained model was uploaded before cleaning the code!"
- ],
- "metadata": {
- "id": "HdclRaqoUWUD"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "PlOkPySoJ1m9"
- },
- "source": [
- "#### Create scheduler\n",
- "Note that other schedulers are used in the paper for slightly improved performance over DDPM."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "nNHnIk9CkAb2"
- },
- "outputs": [],
- "source": [
- "from diffusers import DDPMScheduler"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "RnDJdDBztjFF"
- },
- "outputs": [],
- "source": [
- "num_timesteps = 1000\n",
- "scheduler = DDPMScheduler(num_train_timesteps=num_timesteps, beta_schedule=\"sigmoid\", beta_start=1e-7, beta_end=2e-3, clip_sample=False)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "1vh3fpSAflkL"
- },
- "source": [
- "### Get a dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "B6qzaGjVKFVk"
- },
- "source": [
- "Grab a Google Colab tool so we can upload our data directly. Note that you need to download the data from ***this [file](https://huggingface.co/datasets/fusing/geodiff-example-data/blob/main/data/molecules.pkl)***\n",
- "\n",
- "(direct downloading from the Hub does not yet work for this data type)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "jbLl3EJdgj3x"
- },
- "outputs": [],
- "source": [
- "# from google.colab import files"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "E591lVuTgxPE"
- },
- "outputs": [],
- "source": [
- "# uploaded = files.upload()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "KUNxfK3ln98Q"
- },
- "source": [
- "Load the dataset with torch."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "7L4iOShTpcQX",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "7f2dcd29-493e-44de-98d1-3ad50f109a4a"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "--2022-10-12 18:32:19-- https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
- "Resolving huggingface.co (huggingface.co)... 44.195.102.200, 52.5.54.249, 54.210.225.113, ...\n",
- "Connecting to huggingface.co (huggingface.co)|44.195.102.200|:443... connected.\n",
- "HTTP request sent, awaiting response... 200 OK\n",
- "Length: 127774 (125K) [application/octet-stream]\n",
- "Saving to: ‘molecules.pkl’\n",
- "\n",
- "molecules.pkl 100%[===================>] 124.78K 180KB/s in 0.7s \n",
- "\n",
- "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n",
- "\n"
- ]
- }
- ],
- "source": [
- "import torch\n",
- "import numpy as np\n",
- "\n",
- "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
- "dataset = torch.load('/content/molecules.pkl')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "QZcmy1EvKQRk"
- },
- "source": [
- "Print out one entry of the dataset; it contains molecular formulas, atom types, positions, and more."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "JVjz6iH_H6Eh",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe"
- },
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")"
- ]
+ },
+ "1": {
+ "0": {
+ "params": {
+ "aspectRatio": 1.5,
+ "assembly": "default",
+ "bondScale": 0.3,
+ "bondSpacing": 0.75,
+ "clipCenter": {
+ "x": 0,
+ "y": 0,
+ "z": 0
},
- "metadata": {},
- "execution_count": 20
- }
- ],
- "source": [
- "dataset[0]"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Run the diffusion process"
- ],
- "metadata": {
- "id": "vHNiZAUxNgoy"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "jZ1KZrxKqENg"
- },
- "source": [
- "#### Helper Functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "s240tYueqKKf"
- },
- "outputs": [],
- "source": [
- "from torch_geometric.data import Data, Batch\n",
- "from torch_scatter import scatter_add, scatter_mean\n",
- "from tqdm import tqdm\n",
- "import copy\n",
- "import os\n",
- "\n",
- "def repeat_data(data: Data, num_repeat) -> Batch:\n",
- " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n",
- " return Batch.from_data_list(datas)\n",
- "\n",
- "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n",
- " datas = batch.to_data_list()\n",
- " new_data = []\n",
- " for i in range(num_repeat):\n",
- " new_data += copy.deepcopy(datas)\n",
- " return Batch.from_data_list(new_data)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "AMnQTk0eqT7Z"
- },
- "source": [
- "#### Constants"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "WYGkzqgzrHmF"
- },
- "outputs": [],
- "source": [
- "num_samples = 1 # solutions per molecule\n",
- "num_molecules = 3\n",
- "\n",
- "DEVICE = 'cuda'\n",
- "sampling_type = 'ddpm_noisy'  # paper also uses \"generalize\" and \"ld\"\n",
- "# constants for inference\n",
- "w_global = 0.5 #0,.3 for qm9\n",
- "global_start_sigma = 0.5\n",
- "eta = 1.0\n",
- "clip_local = None\n",
- "clip_pos = None\n",
- "\n",
- "# constants for data handling\n",
- "save_traj = False\n",
- "save_data = False\n",
- "output_dir = '/content/'"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-xD5bJ3SqM7t"
- },
- "source": [
- "#### Generate samples!\n",
- "Note that the 3D representation of a molecule is referred to as its **conformation**."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "x9xuLUNg26z1",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
- " after removing the cwd from sys.path.\n",
- "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n"
- ]
- }
- ],
- "source": [
- "results = []\n",
- "\n",
- "# define sigmas\n",
- "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n",
- "sigmas = sigmas.to(DEVICE)\n",
- "\n",
- "for count, data in enumerate(tqdm(dataset)):\n",
- " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n",
- "\n",
- " data_input = data.clone()\n",
- " data_input['pos_ref'] = None\n",
- " batch = repeat_data(data_input, num_samples).to(DEVICE)\n",
- "\n",
- " # initial configuration\n",
- " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n",
- "\n",
- " # for logging animation of denoising\n",
- " pos_traj = []\n",
- " with torch.no_grad():\n",
- "\n",
- " # scale initial sample\n",
- " pos = pos_init * sigmas[-1]\n",
- " for t in scheduler.timesteps:\n",
- " batch.pos = pos\n",
- "\n",
- " # generate geometry with model, then filter it\n",
- " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n",
- "\n",
- " # Update\n",
- " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n",
- "\n",
- " pos = reconstructed_pos\n",
- "\n",
- " if torch.isnan(pos).any():\n",
- " print(\"NaN detected. Please restart.\")\n",
- " raise FloatingPointError()\n",
- "\n",
- " # recenter graph of positions for next iteration\n",
- " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n",
- "\n",
- " # optional clipping\n",
- " if clip_pos is not None:\n",
- " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n",
- " pos_traj.append(pos.clone().cpu())\n",
- "\n",
- " pos_gen = pos.cpu()\n",
- " if save_traj:\n",
- "        data.pos_gen = torch.stack(pos_traj)  # entries in pos_traj are already on the CPU\n",
- " else:\n",
- " data.pos_gen = pos_gen\n",
- " results.append(data)\n",
- "\n",
- "\n",
- "if save_data:\n",
- " save_path = os.path.join(output_dir, 'samples_all.pkl')\n",
- "\n",
- " with open(save_path, 'wb') as f:\n",
- " pickle.dump(results, f)"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Render the results!"
- ],
- "metadata": {
- "id": "fSApwSaZNndW"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "d47Zxo2OKdgZ"
- },
- "source": [
- "This function allows us to render 3D structures in Colab."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "e9Cd0kCAv9b8"
- },
- "outputs": [],
- "source": [
- "from google.colab import output\n",
- "output.enable_custom_widget_manager()"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Helper functions"
- ],
- "metadata": {
- "id": "RjaVuR15NqzF"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "28rBYa9NKhlz"
- },
- "source": [
- "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "LKdKdwxcyTQ6"
- },
- "outputs": [],
- "source": [
- "from copy import deepcopy\n",
- "def set_rdmol_positions(rdkit_mol, pos):\n",
- " \"\"\"\n",
- " Args:\n",
- " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
- " pos: (N_atoms, 3)\n",
- " \"\"\"\n",
- " mol = deepcopy(rdkit_mol)\n",
- " set_rdmol_positions_(mol, pos)\n",
- " return mol\n",
- "\n",
- "def set_rdmol_positions_(mol, pos):\n",
- " \"\"\"\n",
- " Args:\n",
- "        mol: An `rdkit.Chem.rdchem.Mol` object, modified in place.\n",
- " pos: (N_atoms, 3)\n",
- " \"\"\"\n",
- " for i in range(pos.shape[0]):\n",
- " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n",
- " return mol\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "NuE10hcpKmzK"
- },
- "source": [
- "Process the generated data to make it easy to view."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "KieVE1vc0_Vs",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "collect 5 generated molecules in `mols`\n"
- ]
- }
- ],
- "source": [
- "# the model can generate multiple conformations per 2d geometry\n",
- "num_gen = results[0]['pos_gen'].shape[0]\n",
- "\n",
- "# init storage objects\n",
- "mols_gen = []\n",
- "mols_orig = []\n",
- "for to_process in results:\n",
- "\n",
- " # store the reference 3d position\n",
- " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n",
- "\n",
- " # store the generated 3d position\n",
- " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n",
- "\n",
- " # copy data to new object\n",
- " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n",
- "\n",
- " # append results\n",
- " mols_gen.append(new_mol)\n",
- " mols_orig.append(to_process.rdmol)\n",
- "\n",
- "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tin89JwMKp4v"
- },
- "source": [
- "Import tools to visualize the 2d chemical diagram of the molecule."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "yqV6gllSZn38"
- },
- "outputs": [],
- "source": [
- "from rdkit.Chem import AllChem\n",
- "from rdkit import Chem\n",
- "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n",
- "from IPython.display import SVG, display"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TFNKmGddVoOk"
- },
- "source": [
- "Select a molecule to visualize."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "KzuwLlrrVaGc"
- },
- "outputs": [],
- "source": [
- "idx = 0\n",
- "assert idx < len(results), \"selected molecule that was not generated\""
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Viewing"
- ],
- "metadata": {
- "id": "hkb8w0_SNtU8"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "I3R4QBQeKttN"
- },
- "source": [
- "This 2D rendering is the equivalent of the **input to the model**!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "gkQRWjraaKex",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 321
- },
- "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47"
- },
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- ""
- ],
- "image/svg+xml": ""
+ "clipNear": 0,
+ "clipRadius": 0,
+ "colorMode": "hcl",
+ "colorReverse": false,
+ "colorScale": "",
+ "colorScheme": "element",
+ "colorValue": 9474192,
+ "cylinderOnly": false,
+ "defaultAssembly": "",
+ "depthWrite": true,
+ "diffuse": 16777215,
+ "diffuseInterior": false,
+ "disableImpostor": false,
+ "disablePicking": false,
+ "flatShaded": false,
+ "interiorColor": 2236962,
+ "interiorDarkening": 0,
+ "lazy": false,
+ "lineOnly": false,
+ "linewidth": 2,
+ "matrix": {
+ "elements": [
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1
+ ]
},
- "metadata": {}
+ "metalness": 0,
+ "multipleBond": "off",
+ "opacity": 1,
+ "openEnded": true,
+ "quality": "high",
+ "radialSegments": 20,
+ "radiusData": {},
+ "radiusScale": 2,
+ "radiusSize": 0.15,
+ "radiusType": "size",
+ "roughness": 0.4,
+ "sele": "",
+ "side": "double",
+ "sphereDetail": 2,
+ "useInteriorColor": true,
+ "visible": true,
+ "wireframe": false
+ },
+ "type": "ball+stick"
}
- ],
- "source": [
- "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n",
- "molSize=(450,300)\n",
- "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n",
- "drawer.DrawMolecule(mc)\n",
- "drawer.FinishDrawing()\n",
- "svg = drawer.GetDrawingText()\n",
- "display(SVG(svg.replace('svg:','')))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "z4FDMYMxKw2I"
+ }
},
- "source": [
- "Generate the 3d molecule!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "aT1Bkb8YxJfV",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 17,
- "referenced_widgets": [
- "695ab5bbf30a4ab19df1f9f33469f314",
- "eac6a8dcdc9d4335a2e51031793ead29"
- ]
- },
- "outputId": "b98870ae-049d-4386-b676-166e9526bda2"
- },
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "695ab5bbf30a4ab19df1f9f33469f314"
- }
- },
- "metadata": {
- "application/vnd.jupyter.widget-view+json": {
- "colab": {
- "custom_widget_manager": {
- "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
- }
- }
- }
- }
- }
+ "_ngl_serialize": false,
+ "_ngl_version": "",
+ "_ngl_view_id": [
+ "FB989FD1-5B9C-446B-8914-6B58AF85446D"
],
- "source": [
- "from nglview import show_rdkit as show"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "pxtq8I-I18C-",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 337,
- "referenced_widgets": [
- "be446195da2b4ff2aec21ec5ff963a54",
- "c6596896148b4a8a9c57963b67c7782f",
- "2489b5e5648541fbbdceadb05632a050",
- "01e0ba4e5da04914b4652b8d58565d7b",
- "c30e6c2f3e2a44dbbb3d63bd519acaa4",
- "f31c6e40e9b2466a9064a2669933ecd5",
- "19308ccac642498ab8b58462e3f1b0bb",
- "4a081cdc2ec3421ca79dd933b7e2b0c4",
- "e5c0d75eb5e1447abd560c8f2c6017e1",
- "5146907ef6764654ad7d598baebc8b58",
- "144ec959b7604a2cabb5ca46ae5e5379",
- "abce2a80e6304df3899109c6d6cac199",
- "65195cb7a4134f4887e9dd19f3676462"
- ]
- },
- "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7"
- },
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "NGLWidget()"
- ],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "be446195da2b4ff2aec21ec5ff963a54"
- }
- },
- "metadata": {
- "application/vnd.jupyter.widget-view+json": {
- "colab": {
- "custom_widget_manager": {
- "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
- }
- }
- }
- }
- }
+ "_player_dict": {},
+ "_scene_position": {},
+ "_scene_rotation": {},
+ "_synced_model_ids": [],
+ "_synced_repr_model_ids": [],
+ "_view_count": null,
+ "_view_height": "",
+ "_view_module": "nglview-js-widgets",
+ "_view_module_version": "3.0.1",
+ "_view_name": "NGLView",
+ "_view_width": "",
+ "background": "white",
+ "frame": 0,
+ "gui_style": null,
+ "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f",
+ "max_frame": 0,
+ "n_components": 2,
+ "picked": {}
+ }
+ },
+ "c2321b3bff6f490ca12040a20308f555": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd",
+ "max": 3271865,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d",
+ "value": 3271865
+ }
+ },
+ "c30e6c2f3e2a44dbbb3d63bd519acaa4": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c6596896148b4a8a9c57963b67c7782f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d022575f1fa2446d891650897f187b4d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "d90f304e9560472eacfbdd11e46765eb": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c",
+ "IPY_MODEL_c2321b3bff6f490ca12040a20308f555",
+ "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d"
],
- "source": [
- "# new molecule\n",
- "show(mols_gen[idx])"
- ]
- },
- {
- "cell_type": "code",
- "source": [],
- "metadata": {
- "id": "KJr4h2mwXeTo"
- },
- "execution_count": null,
- "outputs": []
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "provenance": []
- },
- "gpuClass": "standard",
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "name": "python"
- },
- "widgets": {
- "application/vnd.jupyter.widget-state+json": {
- "d90f304e9560472eacfbdd11e46765eb": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HBoxModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HBoxModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HBoxView",
- "box_style": "",
- "children": [
- "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c",
- "IPY_MODEL_c2321b3bff6f490ca12040a20308f555",
- "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d"
- ],
- "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f"
- }
- },
- "1c6246f15b654f4daa11c9bcf997b78c": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df",
- "placeholder": "",
- "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c",
- "value": "Downloading: 100%"
- }
- },
- "c2321b3bff6f490ca12040a20308f555": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "FloatProgressModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "FloatProgressModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "ProgressView",
- "bar_style": "success",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd",
- "max": 3271865,
- "min": 0,
- "orientation": "horizontal",
- "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d",
- "value": 3271865
- }
- },
- "b7feb522161f4cf4b7cc7c1a078ff12d": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436",
- "placeholder": "",
- "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da",
- "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]"
- }
- },
- "e2d368556e494ae7ae4e2e992af2cd4f": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "bbef741e76ec41b7ab7187b487a383df": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "561f742d418d4721b0670cc8dd62e22c": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "872915dd1bb84f538c44e26badabafdd": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "d022575f1fa2446d891650897f187b4d": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ProgressStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ProgressStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "bar_color": null,
- "description_width": ""
- }
- },
- "fdc393f3468c432aa0ada05e238a5436": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "2c9362906e4b40189f16d14aa9a348da": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "6010fc8daa7a44d5aec4b830ec2ebaa1": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HBoxModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HBoxModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HBoxView",
- "box_style": "",
- "children": [
- "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2",
- "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
- "IPY_MODEL_6526646be5ed415c84d1245b040e629b"
- ],
- "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab"
- }
- },
- "7e0bb1b8d65249d3974200686b193be2": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8",
- "placeholder": "",
- "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08",
- "value": "Downloading: 100%"
- }
- },
- "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "FloatProgressModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "FloatProgressModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "ProgressView",
- "bar_style": "success",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39",
- "max": 401,
- "min": 0,
- "orientation": "horizontal",
- "style": "IPY_MODEL_f915cf874246446595206221e900b2fe",
- "value": 401
- }
- },
- "6526646be5ed415c84d1245b040e629b": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433",
- "placeholder": "",
- "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4",
- "value": " 401/401 [00:00<00:00, 13.5kB/s]"
- }
- },
- "24d31fc3576e43dd9f8301d2ef3a37ab": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "2918bfaadc8d4b1a9832522c40dfefb8": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "a4bfdca35cc54dae8812720f1b276a08": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "e4901541199b45c6a18824627692fc39": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "f915cf874246446595206221e900b2fe": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ProgressStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ProgressStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "bar_color": null,
- "description_width": ""
- }
- },
- "a9e388f22a9742aaaf538e22575c9433": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "42f6c3db29d7484ba6b4f73590abd2f4": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "695ab5bbf30a4ab19df1f9f33469f314": {
- "model_module": "nglview-js-widgets",
- "model_name": "ColormakerRegistryModel",
- "model_module_version": "3.0.1",
- "state": {
- "_dom_classes": [],
- "_model_module": "nglview-js-widgets",
- "_model_module_version": "3.0.1",
- "_model_name": "ColormakerRegistryModel",
- "_msg_ar": [],
- "_msg_q": [],
- "_ready": false,
- "_view_count": null,
- "_view_module": "nglview-js-widgets",
- "_view_module_version": "3.0.1",
- "_view_name": "ColormakerRegistryView",
- "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29"
- }
- },
- "eac6a8dcdc9d4335a2e51031793ead29": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "be446195da2b4ff2aec21ec5ff963a54": {
- "model_module": "nglview-js-widgets",
- "model_name": "NGLModel",
- "model_module_version": "3.0.1",
- "state": {
- "_camera_orientation": [
- -15.519693580202304,
- -14.065056548036177,
- -23.53197484807691,
- 0,
- -23.357853515109753,
- 20.94055073042662,
- 2.888695042134944,
- 0,
- 14.352363398292777,
- 18.870825741878015,
- -20.744689572909344,
- 0,
- 0.2724999189376831,
- 0.6940000057220459,
- -0.3734999895095825,
- 1
- ],
- "_camera_str": "orthographic",
- "_dom_classes": [],
- "_gui_theme": null,
- "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050",
- "_igui": null,
- "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b",
- "_model_module": "nglview-js-widgets",
- "_model_module_version": "3.0.1",
- "_model_name": "NGLModel",
- "_ngl_color_dict": {},
- "_ngl_coordinate_resource": {},
- "_ngl_full_stage_parameters": {
- "impostor": true,
- "quality": "medium",
- "workerDefault": true,
- "sampleLevel": 0,
- "backgroundColor": "white",
- "rotateSpeed": 2,
- "zoomSpeed": 1.2,
- "panSpeed": 1,
- "clipNear": 0,
- "clipFar": 100,
- "clipDist": 10,
- "fogNear": 50,
- "fogFar": 100,
- "cameraFov": 40,
- "cameraEyeSep": 0.3,
- "cameraType": "perspective",
- "lightColor": 14540253,
- "lightIntensity": 1,
- "ambientColor": 14540253,
- "ambientIntensity": 0.2,
- "hoverTimeout": 0,
- "tooltip": true,
- "mousePreset": "default"
- },
- "_ngl_msg_archive": [
- {
- "target": "Stage",
- "type": "call_method",
- "methodName": "loadFile",
- "reconstruc_color_scheme": false,
- "args": [
- {
- "type": "blob",
- "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n",
- "binary": false
- }
- ],
- "kwargs": {
- "defaultRepresentation": true,
- "ext": "pdb"
- }
- }
- ],
- "_ngl_original_stage_parameters": {
- "impostor": true,
- "quality": "medium",
- "workerDefault": true,
- "sampleLevel": 0,
- "backgroundColor": "white",
- "rotateSpeed": 2,
- "zoomSpeed": 1.2,
- "panSpeed": 1,
- "clipNear": 0,
- "clipFar": 100,
- "clipDist": 10,
- "fogNear": 50,
- "fogFar": 100,
- "cameraFov": 40,
- "cameraEyeSep": 0.3,
- "cameraType": "perspective",
- "lightColor": 14540253,
- "lightIntensity": 1,
- "ambientColor": 14540253,
- "ambientIntensity": 0.2,
- "hoverTimeout": 0,
- "tooltip": true,
- "mousePreset": "default"
- },
- "_ngl_repr_dict": {
- "0": {
- "0": {
- "type": "ball+stick",
- "params": {
- "lazy": false,
- "visible": true,
- "quality": "high",
- "sphereDetail": 2,
- "radialSegments": 20,
- "openEnded": true,
- "disableImpostor": false,
- "aspectRatio": 1.5,
- "lineOnly": false,
- "cylinderOnly": false,
- "multipleBond": "off",
- "bondScale": 0.3,
- "bondSpacing": 0.75,
- "linewidth": 2,
- "radiusType": "size",
- "radiusData": {},
- "radiusSize": 0.15,
- "radiusScale": 2,
- "assembly": "default",
- "defaultAssembly": "",
- "clipNear": 0,
- "clipRadius": 0,
- "clipCenter": {
- "x": 0,
- "y": 0,
- "z": 0
- },
- "flatShaded": false,
- "opacity": 1,
- "depthWrite": true,
- "side": "double",
- "wireframe": false,
- "colorScheme": "element",
- "colorScale": "",
- "colorReverse": false,
- "colorValue": 9474192,
- "colorMode": "hcl",
- "roughness": 0.4,
- "metalness": 0,
- "diffuse": 16777215,
- "diffuseInterior": false,
- "useInteriorColor": true,
- "interiorColor": 2236962,
- "interiorDarkening": 0,
- "matrix": {
- "elements": [
- 1,
- 0,
- 0,
- 0,
- 0,
- 1,
- 0,
- 0,
- 0,
- 0,
- 1,
- 0,
- 0,
- 0,
- 0,
- 1
- ]
- },
- "disablePicking": false,
- "sele": ""
- }
- }
- },
- "1": {
- "0": {
- "type": "ball+stick",
- "params": {
- "lazy": false,
- "visible": true,
- "quality": "high",
- "sphereDetail": 2,
- "radialSegments": 20,
- "openEnded": true,
- "disableImpostor": false,
- "aspectRatio": 1.5,
- "lineOnly": false,
- "cylinderOnly": false,
- "multipleBond": "off",
- "bondScale": 0.3,
- "bondSpacing": 0.75,
- "linewidth": 2,
- "radiusType": "size",
- "radiusData": {},
- "radiusSize": 0.15,
- "radiusScale": 2,
- "assembly": "default",
- "defaultAssembly": "",
- "clipNear": 0,
- "clipRadius": 0,
- "clipCenter": {
- "x": 0,
- "y": 0,
- "z": 0
- },
- "flatShaded": false,
- "opacity": 1,
- "depthWrite": true,
- "side": "double",
- "wireframe": false,
- "colorScheme": "element",
- "colorScale": "",
- "colorReverse": false,
- "colorValue": 9474192,
- "colorMode": "hcl",
- "roughness": 0.4,
- "metalness": 0,
- "diffuse": 16777215,
- "diffuseInterior": false,
- "useInteriorColor": true,
- "interiorColor": 2236962,
- "interiorDarkening": 0,
- "matrix": {
- "elements": [
- 1,
- 0,
- 0,
- 0,
- 0,
- 1,
- 0,
- 0,
- 0,
- 0,
- 1,
- 0,
- 0,
- 0,
- 0,
- 1
- ]
- },
- "disablePicking": false,
- "sele": ""
- }
- }
- }
- },
- "_ngl_serialize": false,
- "_ngl_version": "",
- "_ngl_view_id": [
- "FB989FD1-5B9C-446B-8914-6B58AF85446D"
- ],
- "_player_dict": {},
- "_scene_position": {},
- "_scene_rotation": {},
- "_synced_model_ids": [],
- "_synced_repr_model_ids": [],
- "_view_count": null,
- "_view_height": "",
- "_view_module": "nglview-js-widgets",
- "_view_module_version": "3.0.1",
- "_view_name": "NGLView",
- "_view_width": "",
- "background": "white",
- "frame": 0,
- "gui_style": null,
- "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f",
- "max_frame": 0,
- "n_components": 2,
- "picked": {}
- }
- },
- "c6596896148b4a8a9c57963b67c7782f": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "2489b5e5648541fbbdceadb05632a050": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ButtonModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ButtonModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "ButtonView",
- "button_style": "",
- "description": "",
- "disabled": false,
- "icon": "compress",
- "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199",
- "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462",
- "tooltip": ""
- }
- },
- "01e0ba4e5da04914b4652b8d58565d7b": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HBoxModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HBoxModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HBoxView",
- "box_style": "",
- "children": [
- "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1",
- "IPY_MODEL_5146907ef6764654ad7d598baebc8b58"
- ],
- "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379"
- }
- },
- "c30e6c2f3e2a44dbbb3d63bd519acaa4": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "f31c6e40e9b2466a9064a2669933ecd5": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "19308ccac642498ab8b58462e3f1b0bb": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "4a081cdc2ec3421ca79dd933b7e2b0c4": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "SliderStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "SliderStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": "",
- "handle_color": null
- }
- },
- "e5c0d75eb5e1447abd560c8f2c6017e1": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "PlayModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "PlayModel",
- "_playing": false,
- "_repeat": false,
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "PlayView",
- "description": "",
- "description_tooltip": null,
- "disabled": false,
- "interval": 100,
- "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4",
- "max": 0,
- "min": 0,
- "show_repeat": true,
- "step": 1,
- "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5",
- "value": 0
- }
- },
- "5146907ef6764654ad7d598baebc8b58": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "IntSliderModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "IntSliderModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "IntSliderView",
- "continuous_update": true,
- "description": "",
- "description_tooltip": null,
- "disabled": false,
- "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb",
- "max": 0,
- "min": 0,
- "orientation": "horizontal",
- "readout": true,
- "readout_format": "d",
- "step": 1,
- "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4",
- "value": 0
- }
- },
- "144ec959b7604a2cabb5ca46ae5e5379": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "abce2a80e6304df3899109c6d6cac199": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": "34px"
- }
- },
- "65195cb7a4134f4887e9dd19f3676462": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ButtonStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ButtonStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "button_color": null,
- "font_weight": ""
- }
- }
- }
+ "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f"
+ }
+ },
+ "e2d368556e494ae7ae4e2e992af2cd4f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e4901541199b45c6a18824627692fc39": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e5c0d75eb5e1447abd560c8f2c6017e1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "PlayModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "PlayModel",
+ "_playing": false,
+ "_repeat": false,
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "PlayView",
+ "description": "",
+ "description_tooltip": null,
+ "disabled": false,
+ "interval": 100,
+ "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4",
+ "max": 0,
+ "min": 0,
+ "show_repeat": true,
+ "step": 1,
+ "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5",
+ "value": 0
+ }
+ },
+ "eac6a8dcdc9d4335a2e51031793ead29": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f31c6e40e9b2466a9064a2669933ecd5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "f915cf874246446595206221e900b2fe": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "fdc393f3468c432aa0ada05e238a5436": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
}
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
\ No newline at end of file
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
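
Note: the long block of widget-state churn above is not a functional change. The notebook is simply re-serialized with alphabetically sorted keys ("model_module_version" now precedes "model_name", widget IDs appear in sorted order) and `nbformat_minor` is bumped to 4. A minimal sketch of the kind of normalization that produces this ordering follows; the exact tool used for this PR is not stated here, so the snippet is an illustrative assumption rather than the actual command.

import json

def normalize_notebook(path: str) -> None:
    # Load the raw notebook JSON (the widget state lives under metadata.widgets).
    with open(path, encoding="utf-8") as f:
        nb = json.load(f)
    # Bump the minor format version, matching the change in the hunk above.
    nb["nbformat_minor"] = 4
    # Re-serialize with indent=1 and sorted keys (the convention nbformat itself
    # appears to use), which is what reorders the widget-state entries.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(nb, f, indent=1, sort_keys=True, ensure_ascii=False)
        f.write("\n")
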
diff --git a/examples/research_projects/gligen/demo.ipynb b/examples/research_projects/gligen/demo.ipynb
index 571f1a0323a2..315aee710594 100644
--- a/examples/research_projects/gligen/demo.ipynb
+++ b/examples/research_projects/gligen/demo.ipynb
@@ -26,8 +26,7 @@
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
- "import torch\n",
- "from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
+ "from diffusers import StableDiffusionGLIGENPipeline"
]
},
{
@@ -36,28 +35,25 @@
"metadata": {},
"outputs": [],
"source": [
- "import os\n",
+ "from transformers import CLIPTextModel, CLIPTokenizer\n",
+ "\n",
"import diffusers\n",
"from diffusers import (\n",
" AutoencoderKL,\n",
" DDPMScheduler,\n",
- " UNet2DConditionModel,\n",
- " UniPCMultistepScheduler,\n",
" EulerDiscreteScheduler,\n",
+ " UNet2DConditionModel,\n",
")\n",
- "from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
+ "\n",
+ "\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n",
- "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
+ "pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
"\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
- "text_encoder = CLIPTextModel.from_pretrained(\n",
- " pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
- ")\n",
- "vae = AutoencoderKL.from_pretrained(\n",
- " pretrained_model_name_or_path, subfolder=\"vae\"\n",
- ")\n",
+ "text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
+ "vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n",
@@ -71,9 +67,7 @@
"metadata": {},
"outputs": [],
"source": [
- "unet = UNet2DConditionModel.from_pretrained(\n",
- " '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
- ")"
+ "unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
]
},
{
@@ -108,6 +102,9 @@
"metadata": {},
"outputs": [],
"source": [
+ "import numpy as np\n",
+ "\n",
+ "\n",
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
"\n",
@@ -117,10 +114,8 @@
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n",
- "prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
- "gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
- "\n",
- "import numpy as np\n",
+ "prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
+ "gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
"\n",
"boxes = np.array([x[1] for x in gen_boxes])\n",
"boxes = boxes / 512\n",
@@ -166,7 +161,7 @@
"metadata": {},
"outputs": [],
"source": [
- "diffusers.utils.make_image_grid(images, 4, len(images)//4)"
+ "diffusers.utils.make_image_grid(images, 4, len(images) // 4)"
]
},
{
@@ -179,7 +174,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "densecaption",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -197,5 +192,5 @@
}
},
"nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
}
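
The gligen demo changes above are mostly mechanical (double quotes, grouped imports, a generic "Python 3 (ipykernel)" kernelspec), but they preserve the demo's one substantive preprocessing step: pixel-space boxes are rescaled by the 512-pixel canvas so the pipeline sees fractional coordinates. A small standalone sketch of that step, using the boxes from the notebook cell above; how the normalized boxes are consumed further down the notebook is not shown in this hunk.

import numpy as np

# (caption, pixel box) pairs copied from the demo cell above.
gen_boxes = [("a steam boat", [232, 225, 257, 149]), ("a jumping pink dolphin", [21, 249, 189, 123])]

phrases = [x[0] for x in gen_boxes]
boxes = np.array([x[1] for x in gen_boxes], dtype=np.float64)
boxes = boxes / 512  # rescale to the 0-1 range relative to the 512 x 512 canvas

print(phrases)
print(boxes.round(3).tolist())
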
diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
index 1d9203be7e01..f94b1dd6b5c4 100644
--- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
+++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -15,8 +15,8 @@
# limitations under the License.
"""
- Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
- Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
+Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
+Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""
import argparse
diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
index 0f507b26d6a8..57c555e43fd8 100644
--- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
+++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
@@ -763,9 +763,9 @@ def main(args):
# Parse instance and class inputs, and double check that lengths match
instance_data_dir = args.instance_data_dir.split(",")
instance_prompt = args.instance_prompt.split(",")
- assert all(
- x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
- ), "Instance data dir and prompt inputs are not of the same length."
+ assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
+ "Instance data dir and prompt inputs are not of the same length."
+ )
if args.with_prior_preservation:
class_data_dir = args.class_data_dir.split(",")
@@ -788,9 +788,9 @@ def main(args):
negative_validation_prompts.append(None)
args.validation_negative_prompt = negative_validation_prompts
- assert num_of_validation_prompts == len(
- negative_validation_prompts
- ), "The length of negative prompts for validation is greater than the number of validation prompts."
+ assert num_of_validation_prompts == len(negative_validation_prompts), (
+ "The length of negative prompts for validation is greater than the number of validation prompts."
+ )
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
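
The two asserts above illustrate a formatting pattern repeated throughout this PR: the line break moves out of the condition and the long message is wrapped in parentheses instead. The parentheses only group the string, so runtime behavior is identical. A small sketch with hypothetical values:

# Same assert, new formatting; it raises AssertionError with the same message
# whenever the lengths disagree.
instance_data_dir = ["dir_a", "dir_b"]
instance_prompt = ["prompt_a", "prompt_b"]

assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
    "Instance data dir and prompt inputs are not of the same length."
)

# Pitfall this style does not introduce: wrapping the condition AND the message
# together would assert a non-empty tuple, which is always truthy and never fails.
# assert (condition, "message")   # wrong -- not what the hunks above do
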
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
index 19432142f541..75dcfccbd5b8 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -830,9 +830,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = get_mask(tokenizer, accelerator)
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
index 7f5dc8ece9fc..a881b06a94dc 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
+++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -886,9 +886,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
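
As with the asserts, the two textual-inversion hunks above only move the line break: the parenthesized right-hand side is the same masked, in-place copy that restores the original embedding rows for every token that is not a newly added placeholder. A toy sketch of the operation, with a hypothetical embedding size and placeholder id:

import torch

emb = torch.nn.Embedding(10, 4)                 # stand-in for the text encoder's token embedding
orig_embeds_params = emb.weight.data.clone()    # snapshot taken before training

index_no_updates = torch.ones(10, dtype=torch.bool)
index_no_updates[7] = False                     # hypothetical placeholder token id that may change

with torch.no_grad():
    # Identical whether the break falls inside the subscript or around the RHS:
    emb.weight[index_no_updates] = (
        orig_embeds_params[index_no_updates]
    )
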
diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
index 19c1f30d82da..51668a61cdc2 100644
--- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
+++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
@@ -663,8 +663,7 @@ def check_inputs(
self.check_image(image, prompt, prompt_embeds)
else:
raise ValueError(
- f"You have passed a list of images of length {len(image_pair)}."
- f"Make sure the list size equals to two."
+                    f"You have passed a list of images of length {len(image_pair)}. Make sure the list size equals two."
)
# Check `controlnet_conditioning_scale`
diff --git a/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
index 9719585d3dfb..6ae1a9a6c611 100644
--- a/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
+++ b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
@@ -173,7 +173,7 @@ def print_loss_closure(step, loss):
if not dataloader_exception:
xm.wait_device_ops()
total_time = time.time() - last_time
- print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
+ print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
else:
print("dataloader exception happen, skip result")
return
@@ -622,7 +622,7 @@ def collate_fn(examples):
num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal():
print("***** Running training *****")
- print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
+ print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
print(
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
)
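
The XLA hunks only touch whitespace inside f-string replacement fields (spaces around the operators, no stray space before the closing brace); the rendered strings are identical. For example, with made-up numbers:

total_time = 120.0
max_train_steps = 50
measure_start_step = 10
# Both lines print exactly the same text; only the source formatting differs.
print(f"Average step time: {total_time/(max_train_steps-measure_start_step)}")
print(f"Average step time: {total_time / (max_train_steps - measure_start_step)}")
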
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
index 26caba5a42c1..043f913893b1 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
@@ -1057,7 +1057,7 @@ def load_model_hook(models, input_dir):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
- f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
+ f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
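
The hunk above merges two adjacent f-string literals into one. Python concatenates adjacent string literals at compile time, so the emitted error message is unchanged; the leading space that used to live in the second literal is simply written explicitly in the merged string. A quick check with stand-in values (the real `low_precision_error_string` is defined elsewhere in the script):

dtype = "torch.float16"                                   # hypothetical value
low_precision_error_string = "(stand-in error message)"   # hypothetical value

before = f"Text encoder loaded as datatype {dtype}." f" {low_precision_error_string}"
after = f"Text encoder loaded as datatype {dtype}. {low_precision_error_string}"
assert before == after  # adjacent literals are concatenated at compile time
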
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
index 410cd74a5b7b..393f991387d6 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
@@ -1021,7 +1021,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
index c02a59a0077a..01ef67a55da4 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
@@ -118,7 +118,7 @@ def save_model_card(
)
model_description = f"""
-# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
+# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
@@ -1336,7 +1336,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
index abc439912664..c87f50e27245 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
@@ -750,7 +750,7 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
index f5bee58d4534..ebb9b129db7e 100644
--- a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
+++ b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
@@ -765,7 +765,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 2061f0c6775b..539d4a6575b0 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -767,7 +767,7 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 757a12045f10..51e220828cdf 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -910,9 +910,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index 11463943c448..f32c729195b0 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -965,12 +965,12 @@ def main():
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
- accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
- index_no_updates_2
- ] = orig_embeds_params_2[index_no_updates_2]
+ accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
+ accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
+ orig_embeds_params_2[index_no_updates_2]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py
index aa5d4c67b642..d13e102e7816 100644
--- a/examples/vqgan/test_vqgan.py
+++ b/examples/vqgan/test_vqgan.py
@@ -177,7 +177,7 @@ def test_vqmodel_checkpointing(self):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
- --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir}
--seed=0
""".split()
@@ -262,7 +262,7 @@ def test_vqmodel_checkpointing_use_ema(self):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
- --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir}
--use_ema
--seed=0
@@ -377,7 +377,7 @@ def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoi
--discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir}
--checkpointing_steps=2
- --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--checkpoints_total_limit=2
--seed=0
""".split()
diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py
index 992722fa7a78..33d234da52d7 100644
--- a/examples/vqgan/train_vqgan.py
+++ b/examples/vqgan/train_vqgan.py
@@ -653,15 +653,15 @@ def main():
try:
# Gets the resolution of the timm transformation after centercrop
timm_centercrop_transform = timm_transform.transforms[1]
- assert isinstance(
- timm_centercrop_transform, transforms.CenterCrop
- ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
+ f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ )
timm_model_resolution = timm_centercrop_transform.size[0]
# Gets final normalization
timm_model_normalization = timm_transform.transforms[-1]
- assert isinstance(
- timm_model_normalization, transforms.Normalize
- ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ assert isinstance(timm_model_normalization, transforms.Normalize), (
+ f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ )
except AssertionError as e:
raise NotImplementedError(e)
# Enable flash attention if asked
diff --git a/pyproject.toml b/pyproject.toml
index 299865a1225d..a864ea34b888 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ line-length = 119
[tool.ruff.lint]
# Never enforce `E501` (line length violations).
-ignore = ["C901", "E501", "E741", "F402", "F823"]
+ignore = ["C901", "E501", "E721", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"]
# Ignore import violations in all `__init__.py` files.
diff --git a/scripts/convert_amused.py b/scripts/convert_amused.py
index 21be29dfdb99..ddd1bf508b6d 100644
--- a/scripts/convert_amused.py
+++ b/scripts/convert_amused.py
@@ -468,7 +468,7 @@ def make_vqvae(old_vae):
# assert (old_output == new_output).all()
print("skipping full vae equivalence check")
- print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
+ print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
return new_vae
diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py
index 0f8b4ddca8ef..2b918280ca05 100644
--- a/scripts/convert_consistency_to_diffusers.py
+++ b/scripts/convert_consistency_to_diffusers.py
@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
- old_prefix = f"output_blocks.{current_layer-1}.1"
+ old_prefix = f"output_blocks.{current_layer - 1}.1"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
elif layer_type == "AttnUpBlock2D":
for j in range(layers_per_block + 1):
@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
- old_prefix = f"output_blocks.{current_layer-1}.2"
+ old_prefix = f"output_blocks.{current_layer - 1}.2"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
diff --git a/scripts/convert_dance_diffusion_to_diffusers.py b/scripts/convert_dance_diffusion_to_diffusers.py
index f9caa50dfc9b..e269a49070cc 100755
--- a/scripts/convert_dance_diffusion_to_diffusers.py
+++ b/scripts/convert_dance_diffusion_to_diffusers.py
@@ -261,9 +261,9 @@ def main(args):
model_name = args.model_path.split("/")[-1].split(".")[0]
if not os.path.isfile(args.model_path):
- assert (
- model_name == args.model_path
- ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
+ assert model_name == args.model_path, (
+ f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
+ )
args.model_path = download(model_name)
sample_rate = MODELS_MAP[model_name]["sample_rate"]
@@ -290,9 +290,9 @@ def main(args):
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
for key, value in renamed_state_dict.items():
- assert (
- diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
- ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
+ assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
+ f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
+ )
if key == "time_proj.weight":
value = value.squeeze()
diff --git a/scripts/convert_diffusers_to_original_sdxl.py b/scripts/convert_diffusers_to_original_sdxl.py
index 648d0376f72e..1aa792b3f06a 100644
--- a/scripts/convert_diffusers_to_original_sdxl.py
+++ b/scripts/convert_diffusers_to_original_sdxl.py
@@ -52,18 +52,18 @@
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i > 0:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(4):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i < 2:
@@ -75,12 +75,12 @@
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
@@ -89,7 +89,7 @@
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -137,20 +137,20 @@ def convert_unet_state_dict(unet_state_dict):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"up.{3-i}.upsample."
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
- sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
- sd_mid_res_prefix = f"mid.block_{i+1}."
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py
index d1b7df070c43..049dda7d42a7 100644
--- a/scripts/convert_diffusers_to_original_stable_diffusion.py
+++ b/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -47,36 +47,36 @@
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
@@ -85,7 +85,7 @@
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -133,20 +133,20 @@ def convert_unet_state_dict(unet_state_dict):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"up.{3-i}.upsample."
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
- sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
- sd_mid_res_prefix = f"mid.block_{i+1}."
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
diff --git a/scripts/convert_hunyuandit_controlnet_to_diffusers.py b/scripts/convert_hunyuandit_controlnet_to_diffusers.py
index 1c8383690890..5cef46c98983 100644
--- a/scripts/convert_hunyuandit_controlnet_to_diffusers.py
+++ b/scripts/convert_hunyuandit_controlnet_to_diffusers.py
@@ -21,9 +21,9 @@ def main(args):
model_config = HunyuanDiT2DControlNetModel.load_config(
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
)
- model_config[
- "use_style_cond_and_image_meta_size"
- ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
+ model_config["use_style_cond_and_image_meta_size"] = (
+ args.use_style_cond_and_image_meta_size
+ ) ### version <= v1.1: True; version >= v1.2: False
print(model_config)
for key in state_dict:
diff --git a/scripts/convert_hunyuandit_to_diffusers.py b/scripts/convert_hunyuandit_to_diffusers.py
index da3af8333ee3..65fcccb22a1a 100644
--- a/scripts/convert_hunyuandit_to_diffusers.py
+++ b/scripts/convert_hunyuandit_to_diffusers.py
@@ -13,15 +13,14 @@ def main(args):
state_dict = state_dict[args.load_key]
except KeyError:
raise KeyError(
- f"{args.load_key} not found in the checkpoint."
- f"Please load from the following keys:{state_dict.keys()}"
+ f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
)
device = "cuda"
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
- model_config[
- "use_style_cond_and_image_meta_size"
- ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
+ model_config["use_style_cond_and_image_meta_size"] = (
+ args.use_style_cond_and_image_meta_size
+ ) ### version <= v1.1: True; version >= v1.2: False
# input_size -> sample_size, text_dim -> cross_attention_dim
for key in state_dict:
diff --git a/scripts/convert_k_upscaler_to_diffusers.py b/scripts/convert_k_upscaler_to_diffusers.py
index 62abedd73785..cff845ef8099 100644
--- a/scripts/convert_k_upscaler_to_diffusers.py
+++ b/scripts/convert_k_upscaler_to_diffusers.py
@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
self_attention_prefix = f"{block_prefix}.{idx}"
- cross_attention_prefix = f"{block_prefix}.{idx }"
+ cross_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_index = 1 if not attention.add_self_attention else 2
idx = (
n * attention_idx + cross_attention_index
if block_type == "up"
else n * attention_idx + cross_attention_index + 1
)
- cross_attention_prefix = f"{block_prefix}.{idx }"
+ cross_attention_prefix = f"{block_prefix}.{idx}"
diffusers_checkpoint.update(
cross_attn_to_diffusers_checkpoint(
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
block_out_channels = original_config["channels"]
- assert (
- len(set(original_config["depths"])) == 1
- ), "UNet2DConditionModel currently do not support blocks with different number of layers"
+ assert len(set(original_config["depths"])) == 1, (
+ "UNet2DConditionModel currently do not support blocks with different number of layers"
+ )
layers_per_block = original_config["depths"][0]
class_labels_dim = original_config["mapping_cond_dim"]
diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py
index 9727deeb6b0c..64e4f69eac17 100644
--- a/scripts/convert_mochi_to_diffusers.py
+++ b/scripts/convert_mochi_to_diffusers.py
@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.0.weight"
+ f"blocks.0.{i + 1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.0.bias"
+ f"blocks.0.{i + 1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.2.weight"
+ f"blocks.0.{i + 1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.2.bias"
+ f"blocks.0.{i + 1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.3.weight"
+ f"blocks.0.{i + 1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.3.bias"
+ f"blocks.0.{i + 1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.5.weight"
+ f"blocks.0.{i + 1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.5.bias"
+ f"blocks.0.{i + 1}.stack.5.bias"
)
# Convert up_blocks (MochiUpBlock3D)
@@ -197,33 +197,35 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
for block in range(3):
for i in range(down_block_layers[block]):
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.0.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.0.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.2.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.2.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.3.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.3.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.5.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.5.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.proj.weight"
+ f"blocks.{block + 1}.proj.weight"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
+ f"blocks.{block + 1}.proj.bias"
)
- new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
@@ -267,133 +269,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.0.weight"
+ f"layers.{i + 1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.0.bias"
+ f"layers.{i + 1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.2.weight"
+ f"layers.{i + 1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.2.bias"
+ f"layers.{i + 1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.3.weight"
+ f"layers.{i + 1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.3.bias"
+ f"layers.{i + 1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.5.weight"
+ f"layers.{i + 1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.5.bias"
+ f"layers.{i + 1}.stack.5.bias"
)
# Convert down_blocks (MochiDownBlock3D)
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
for block in range(3):
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.0.weight"
+ f"layers.{block + 4}.layers.0.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.0.bias"
+ f"layers.{block + 4}.layers.0.bias"
)
for i in range(down_block_layers[block]):
# Convert resnets
- new_state_dict[
- f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
- ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
+ encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
+ )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.0.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.2.weight"
+ f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.2.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
+ encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
)
- new_state_dict[
- f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
- ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.3.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.5.weight"
+ f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.5.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
)
# Convert attentions
- qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
+ qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
)
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
# Convert resnets
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.0.weight"
+ f"layers.{i + 7}.stack.0.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.0.bias"
+ f"layers.{i + 7}.stack.0.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.2.weight"
+ f"layers.{i + 7}.stack.2.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.2.bias"
+ f"layers.{i + 7}.stack.2.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.3.weight"
+ f"layers.{i + 7}.stack.3.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.3.bias"
+ f"layers.{i + 7}.stack.3.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.5.weight"
+ f"layers.{i + 7}.stack.5.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.5.bias"
+ f"layers.{i + 7}.stack.5.bias"
)
# Convert attentions
- qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
+ qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.attn.out.weight"
+ f"layers.{i + 7}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.attn.out.bias"
+ f"layers.{i + 7}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.norm.weight"
+ f"layers.{i + 7}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.norm.bias"
+ f"layers.{i + 7}.attn_block.norm.bias"
)
# Convert output layers
diff --git a/scripts/convert_original_audioldm2_to_diffusers.py b/scripts/convert_original_audioldm2_to_diffusers.py
index 1dc7d739ea76..2c0695ce5595 100644
--- a/scripts/convert_original_audioldm2_to_diffusers.py
+++ b/scripts/convert_original_audioldm2_to_diffusers.py
@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_original_audioldm_to_diffusers.py b/scripts/convert_original_audioldm_to_diffusers.py
index 4f8e4f8f9f80..44183f1aea29 100644
--- a/scripts/convert_original_audioldm_to_diffusers.py
+++ b/scripts/convert_original_audioldm_to_diffusers.py
@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_original_musicldm_to_diffusers.py b/scripts/convert_original_musicldm_to_diffusers.py
index 61e5d16eea9e..00836fde2592 100644
--- a/scripts/convert_original_musicldm_to_diffusers.py
+++ b/scripts/convert_original_musicldm_to_diffusers.py
@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_stable_audio.py b/scripts/convert_stable_audio.py
index a0f9d0f87d90..b33c8b0608e7 100644
--- a/scripts/convert_stable_audio.py
+++ b/scripts/convert_stable_audio.py
@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
# get idx of the layer
idx = int(new_key.split("coder.layers.")[1].split(".")[0])
- new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")
+ new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}")
if "encoder" in new_key:
for i in range(3):
- new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
- new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
- new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1")
else:
for i in range(2, 5):
- new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
- new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
- new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1")
new_key = new_key.replace("layers.0.beta", "snake1.beta")
new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
if idx == num_autoencoder_layers + 1:
- new_key = new_key.replace(f"block.{idx-1}", "snake1")
+ new_key = new_key.replace(f"block.{idx - 1}", "snake1")
elif idx == num_autoencoder_layers + 2:
- new_key = new_key.replace(f"block.{idx-1}", "conv2")
+ new_key = new_key.replace(f"block.{idx - 1}", "conv2")
else:
new_key = new_key
diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py
index 3243ce294b26..e46410ccb3bd 100644
--- a/scripts/convert_svd_to_diffusers.py
+++ b/scripts/convert_svd_to_diffusers.py
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
# TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
- new_checkpoint[
- f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
- ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
+ new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
+ unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
+ )
if len(attentions):
paths = renew_attention_paths(attentions)
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
)
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
- new_checkpoint[
- f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
- ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+ new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
+ unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+ )
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py
index 7da6b4094986..fe62d18faff0 100644
--- a/scripts/convert_vq_diffusion_to_diffusers.py
+++ b/scripts/convert_vq_diffusion_to_diffusers.py
@@ -51,9 +51,9 @@
def vqvae_model_from_original_config(original_config):
- assert (
- original_config["target"] in PORTED_VQVAES
- ), f"{original_config['target']} has not yet been ported to diffusers."
+ assert original_config["target"] in PORTED_VQVAES, (
+ f"{original_config['target']} has not yet been ported to diffusers."
+ )
original_config = original_config["params"]
@@ -464,15 +464,15 @@ def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_p
def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config
):
- assert (
- original_diffusion_config["target"] in PORTED_DIFFUSIONS
- ), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
- assert (
- original_transformer_config["target"] in PORTED_TRANSFORMERS
- ), f"{original_transformer_config['target']} has not yet been ported to diffusers."
- assert (
- original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
- ), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
+ assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
+ f"{original_diffusion_config['target']} has not yet been ported to diffusers."
+ )
+ assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
+ f"{original_transformer_config['target']} has not yet been ported to diffusers."
+ )
+ assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
+ f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
+ )
original_diffusion_config = original_diffusion_config["params"]
original_transformer_config = original_transformer_config["params"]
diff --git a/setup.py b/setup.py
index fdc166a81ecf..7c15d650c78d 100644
--- a/setup.py
+++ b/setup.py
@@ -122,7 +122,7 @@
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
- "ruff==0.1.5",
+ "ruff==0.9.10",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
"GitPython<3.1.19",
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 8ec95ed6fc8d..520815d122de 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -29,7 +29,7 @@
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
- "ruff": "ruff==0.1.5",
+ "ruff": "ruff==0.9.10",
"safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"GitPython": "GitPython<3.1.19",
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index 21a1a70ff79b..025f52521485 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -295,8 +295,7 @@ def set_ip_adapter_scale(self, scale):
):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
- f"Cannot assign {len(scale_configs)} scale_configs to "
- f"{len(attn_processor.scale)} IP-Adapter."
+ f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 5ec16ff299eb..791b7ae9b14f 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -184,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
# Store DoRA scale if present.
if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
- unet_state_dict[
- diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ )
# Handle text encoder LoRAs.
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
@@ -206,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
)
if lora_name.startswith(("lora_te_", "lora_te1_")):
- te_state_dict[
- diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ )
elif lora_name.startswith("lora_te2_"):
- te2_state_dict[
- diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ )
# Store alpha if present.
if lora_name_alpha in state_dict:
@@ -1020,21 +1020,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
- ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
+ original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+ )
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
- ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
+ original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+ )
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
- ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
+ original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+ )
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
- ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
+ original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+ )
## time_text_embed.text_embedder <- vector_in
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
@@ -1056,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
# guidance
has_guidance = any("guidance" in k for k in original_state_dict)
if has_guidance:
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
- ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
+ original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+ )
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
- ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
+ original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+ )
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
- ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
+ original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+ )
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
- ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
+ original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+ )
# context_embedder
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 99a2f871c837..218394af2843 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -26,6 +26,7 @@
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+ _import_structure["auto_model"] = ["AutoModel"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
@@ -41,7 +42,6 @@
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
- _import_structure["auto_model"] = ["AutoModel"]
_import_structure["cache_utils"] = ["CacheMixin"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 741f7075d76d..ebc7d79aeb28 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -205,7 +205,7 @@ def load_state_dict(
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
- f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
)
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index a88ee6c9c9b8..5515a7885098 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -211,9 +211,9 @@ def _init_continuous_input(self, norm_type):
def _init_vectorized_inputs(self, norm_type):
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
- assert (
- self.config.num_vector_embeds is not None
- ), "Transformer2DModel over discrete input must provide num_embed"
+ assert self.config.num_vector_embeds is not None, (
+ "Transformer2DModel over discrete input must provide num_embed"
+ )
self.height = self.config.sample_size
self.width = self.config.sample_size
diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index 1616d94ff1ff..f80771381b50 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -791,7 +791,7 @@ def check_inputs(
if transcription is None:
if self.text_encoder_2.config.model_type == "vits":
- raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription")
+ raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
elif transcription is not None and (
not isinstance(transcription, str) and not isinstance(transcription, list)
):
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 16d3529ed38a..2c63aedd966b 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -657,7 +657,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -665,7 +665,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
index 15745ecca3f0..aaec454cc723 100755
--- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
@@ -1130,7 +1130,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.transformer` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index bff625367bc9..64cc8e13f33f 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -507,7 +507,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -515,7 +515,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index 27b9e0cd45fa..ecd5a8967b4f 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -574,7 +574,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -582,7 +582,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py
index dc0071a494e3..8ea5eb7dd575 100644
--- a/src/diffusers/pipelines/free_noise_utils.py
+++ b/src/diffusers/pipelines/free_noise_utils.py
@@ -341,9 +341,9 @@ def _encode_prompt_free_noise(
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
- negative_prompt_interpolation_embeds[
- start_frame : end_frame + 1
- ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
+ negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = (
+ self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
+ )
prompt_embeds = prompt_interpolation_embeds
negative_prompt_embeds = negative_prompt_interpolation_embeds
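The free_noise_utils hunk above moves the line break out of the subscript and parenthesizes the right-hand side instead. A toy sketch, with plain tensors and a trivial stand-in for `_free_noise_prompt_interpolation_callback`, showing both layouts write the same slice:

    import torch

    def interpolate(start_frame, end_frame, start_tensor, end_tensor):
        # trivial stand-in for the pipeline's interpolation callback
        n = end_frame - start_frame + 1
        return start_tensor.repeat(n, 1)

    embeds = torch.zeros(8, 4)
    start_frame, end_frame = 2, 5
    start_tensor = torch.ones(1, 4)

    embeds[start_frame : end_frame + 1] = (
        interpolate(start_frame, end_frame, start_tensor, start_tensor)
    )
    print(embeds.sum())  # 16.0 -- four rows of ones were written into the slice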
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index e653b8266f19..5f8db26eef54 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -360,7 +360,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
"""
_load_connected_pipes = True
- model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
+ model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
_exclude_from_cpu_offload = ["prior_prior"]
def __init__(
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index cce5f0b3d5bc..769c834ec3cc 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -579,7 +579,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py
index 75d272ac5140..40fac01f8f8a 100644
--- a/src/diffusers/pipelines/omnigen/processor_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py
@@ -95,13 +95,13 @@ def process_multi_modal_prompt(self, text, input_images):
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
unique_image_ids = sorted(set(image_ids))
- assert unique_image_ids == list(
- range(1, len(unique_image_ids) + 1)
- ), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+ assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
+ f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+ )
# total images must be the same as the number of image tags
- assert (
- len(unique_image_ids) == len(input_images)
- ), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+ assert len(unique_image_ids) == len(input_images), (
+ f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+ )
input_images = [input_images[x - 1] for x in image_ids]
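In the omnigen hunks above (and in many test hunks later in this diff), the formatter keeps the assert condition on one line and wraps only the message in parentheses. A minimal sketch with made-up `image_ids` showing that the parenthesized message is still just a grouped string expression, so the raised AssertionError reads the same:

    image_ids = [1, 3]  # deliberately non-continuous to trigger the assert
    unique_image_ids = sorted(set(image_ids))
    try:
        assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
            f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
        )
    except AssertionError as err:
        print(err)  # identical message to the unparenthesized form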
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
index bc7a4b57affd..6d89f16765a3 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
@@ -604,7 +604,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -612,7 +612,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
@@ -1340,7 +1340,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
index 33abfb0be89f..db652989cfc1 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
@@ -683,7 +683,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -691,7 +691,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1191,7 +1191,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
index fdf3df2f4d6a..8b06bdc9c969 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
@@ -737,7 +737,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -745,7 +745,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1509,7 +1509,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 55a9f47145a2..288f269a6563 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -575,7 +575,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index f5b430564ca1..89a403df8d65 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -323,9 +323,7 @@ def maybe_raise_or_warn(
model_cls = unwrapped_sub_model.__class__
if not issubclass(model_cls, expected_class_obj):
- raise ValueError(
- f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
- )
+ raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py
index 9d9f9d9b2ab1..dd25945590cd 100644
--- a/src/diffusers/pipelines/shap_e/renderer.py
+++ b/src/diffusers/pipelines/shap_e/renderer.py
@@ -983,9 +983,9 @@ def decode_to_mesh(
fields = torch.cat(fields, dim=1)
fields = fields.float()
- assert (
- len(fields.shape) == 3 and fields.shape[-1] == 1
- ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
+ assert len(fields.shape) == 3 and fields.shape[-1] == 1, (
+ f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
+ )
fields = fields.reshape(1, *([grid_size] * 3))
@@ -1039,9 +1039,9 @@ def decode_to_mesh(
textures = textures.float()
# 3.3 augment the mesh with texture data
- assert len(textures.shape) == 3 and textures.shape[-1] == len(
- texture_channels
- ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
+ assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), (
+ f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
+ )
for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)]
diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
index 5d773b614a5c..1b87c02df029 100644
--- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
@@ -584,7 +584,7 @@ def __call__(
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
raise ValueError(
- f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
+ f"The total audio length requested ({audio_end_in_s - audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
)
waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
index abcba926160a..dd659306e002 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
@@ -335,7 +335,7 @@ def _generate(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index ddd2e27dedaf..f2e1d87be87e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -475,7 +475,7 @@ def __call__(
"Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 6f4e7f358952..0f7be1a1bbcd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -660,7 +660,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -668,7 +668,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1226,7 +1226,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index 7857bc58a8ad..e0748943ffff 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -401,7 +401,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index c6967bc393b5..42db88b03049 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -600,7 +600,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index dae4540ebe00..f9b6dcbf5ad2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -740,7 +740,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index c69fb90a4c5e..cac305a87f00 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -1258,7 +1258,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.transformer` or your `mask_image` or `image` input."
)
elif num_channels_transformer != 16:
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 920caf4d24a1..835c0af800da 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -741,7 +741,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -749,7 +749,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1509,7 +1509,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index 487ad2d80ac6..775d86dcfc54 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -334,7 +334,7 @@ def check_inputs(
"Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
)
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
- raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
index 1c75b5bef933..fa9ba98e6d0d 100644
--- a/src/diffusers/quantizers/base.py
+++ b/src/diffusers/quantizers/base.py
@@ -215,19 +215,15 @@ def _dequantize(self, model):
)
@abstractmethod
- def _process_model_before_weight_loading(self, model, **kwargs):
- ...
+ def _process_model_before_weight_loading(self, model, **kwargs): ...
@abstractmethod
- def _process_model_after_weight_loading(self, model, **kwargs):
- ...
+ def _process_model_after_weight_loading(self, model, **kwargs): ...
@property
@abstractmethod
- def is_serializable(self):
- ...
+ def is_serializable(self): ...
@property
@abstractmethod
- def is_trainable(self):
- ...
+ def is_trainable(self): ...
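The quantizer hunks above fold the bare `...` body of each abstract stub onto the `def` line. A self-contained sketch (class and method names are illustrative, not the real base class) showing the one-line form still behaves as an abstract placeholder:

    from abc import ABC, abstractmethod

    class QuantizerBase(ABC):
        @abstractmethod
        def _process_model_before_weight_loading(self, model, **kwargs): ...

        @property
        @abstractmethod
        def is_serializable(self): ...

    class DummyQuantizer(QuantizerBase):
        def _process_model_before_weight_loading(self, model, **kwargs):
            return model

        @property
        def is_serializable(self):
            return True

    print(DummyQuantizer().is_serializable)  # True; the one-line "..." stubs act like "pass"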
diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py
index 653171638ccf..c946fa1681c0 100644
--- a/src/diffusers/schedulers/scheduling_consistency_models.py
+++ b/src/diffusers/schedulers/scheduling_consistency_models.py
@@ -203,8 +203,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 624d5a5cd4f3..f9eb9c365acd 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -279,8 +279,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
index 20ad7a4c927d..64195be141f6 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -289,8 +289,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index 686b686f6870..2a0cce7bf146 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -413,8 +413,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py
index 5d60383142a4..77770ab2066c 100644
--- a/src/diffusers/schedulers/scheduling_tcd.py
+++ b/src/diffusers/schedulers/scheduling_tcd.py
@@ -431,8 +431,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index c570bac733db..b98c4e33f862 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -241,7 +241,7 @@ def _set_state_dict_into_text_encoder(
"""
text_encoder_state_dict = {
- f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix)
+ f"{k.replace(prefix, '')}": v for k, v in lora_state_dict.items() if k.startswith(prefix)
}
text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
@@ -583,7 +583,7 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
if self.temp_stored_params is None:
- raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
if self.foreach:
torch._foreach_copy_(
[param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
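The training_utils hunk above rewrites the f-string so the outer quotes are double and the nested `replace` arguments use single quotes. A small sketch with a made-up LoRA state dict showing both spellings produce identical keys (the inner quotes only have to differ from the outer ones on Pythons before 3.12):

    prefix = "text_encoder."
    lora_state_dict = {"text_encoder.lora_A.weight": 1, "unet.lora_A.weight": 2}

    old_style = {f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix)}
    new_style = {f"{k.replace(prefix, '')}": v for k, v in lora_state_dict.items() if k.startswith(prefix)}
    assert old_style == new_style == {"lora_A.weight": 1}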
diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py
index f482deddd2f4..4f001b3047d6 100644
--- a/src/diffusers/utils/deprecation_utils.py
+++ b/src/diffusers/utils/deprecation_utils.py
@@ -40,7 +40,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn
line_number = call_frame.lineno
function = call_frame.function
key, value = next(iter(deprecated_kwargs.items()))
- raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
+ raise TypeError(f"{function} in {filename} line {line_number - 1} got an unexpected keyword argument `{key}`")
if len(values) == 0:
return
diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py
index 6f93450c410c..b96e0e222cb1 100644
--- a/src/diffusers/utils/logging.py
+++ b/src/diffusers/utils/logging.py
@@ -60,8 +60,7 @@ def _get_default_logging_level() -> int:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
- f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
- f"has to be one of: { ', '.join(log_levels.keys()) }"
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}"
)
return _default_log_level
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index f23fddd28694..3682c5bfacd6 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -334,7 +334,7 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
kohya_ss_state_dict[kohya_key] = weight
if "lora_down" in kohya_key:
- alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
return kohya_ss_state_dict
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 51e7e640fb02..4ba6f7c25eac 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -882,7 +882,7 @@ def pytest_terminal_summary_main(tr, id):
f.write("slowest durations\n")
for i, rep in enumerate(dlist):
if rep.duration < durations_min:
- f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
+ f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
break
f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
@@ -1027,7 +1027,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
process.join(timeout=timeout)
if results["error"] is not None:
- test_case.fail(f'{results["error"]}')
+ test_case.fail(f"{results['error']}")
class CaptureLogger:
diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py
index 74bd43c52315..53aafc9d2eb6 100644
--- a/tests/hooks/test_hooks.py
+++ b/tests/hooks/test_hooks.py
@@ -168,9 +168,7 @@ def test_hook_registry(self):
registry.register_hook(MultiplyHook(2), "multiply_hook")
registry_repr = repr(registry)
- expected_repr = (
- "HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")"
- )
+ expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"
self.assertEqual(len(registry.hooks), 2)
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
@@ -285,12 +283,7 @@ def test_invocation_order_stateful_first(self):
self.model(input)
output = cap_logger.out.replace(" ", "").replace("\n", "")
expected_invocation_order_log = (
- (
- "MultiplyHook pre_forward\n"
- "AddHook pre_forward\n"
- "AddHook post_forward\n"
- "MultiplyHook post_forward\n"
- )
+ ("MultiplyHook pre_forward\nAddHook pre_forward\nAddHook post_forward\nMultiplyHook post_forward\n")
.replace(" ", "")
.replace("\n", "")
)
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 6155ac2e39fd..f82a2407f333 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -299,9 +299,9 @@ def test_one_request_upon_cached(self):
)
download_requests = [r.method for r in m.request_history]
- assert (
- download_requests.count("HEAD") == 3
- ), "3 HEAD requests one for config, one for model, and one for shard index file."
+ assert download_requests.count("HEAD") == 3, (
+ "3 HEAD requests one for config, one for model, and one for shard index file."
+ )
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m:
@@ -313,9 +313,9 @@ def test_one_request_upon_cached(self):
)
cache_requests = [r.method for r in m.request_history]
- assert (
- "HEAD" == cache_requests[0] and len(cache_requests) == 2
- ), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
+ assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
+ "We should call only `model_info` to check for commit hash and knowing if shard index is present."
+ )
def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py
index 659d9a82fd76..bfef1fc4f09b 100644
--- a/tests/models/transformers/test_models_transformer_sd3.py
+++ b/tests/models/transformers/test_models_transformer_sd3.py
@@ -92,9 +92,9 @@ def test_xformers_enable_works(self):
model.enable_xformers_memory_efficient_attention()
- assert (
- model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
- ), "xformers is not enabled"
+ assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
+ "xformers is not enabled"
+ )
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
@@ -167,9 +167,9 @@ def test_xformers_enable_works(self):
model.enable_xformers_memory_efficient_attention()
- assert (
- model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
- ), "xformers is not enabled"
+ assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
+ "xformers is not enabled"
+ )
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 8e1187f11468..d01a0b493520 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -654,22 +654,22 @@ def test_model_xattn_mask(self, mask_dtype):
keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype)
full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample
- assert full_cond_keepallmask_out.allclose(
- full_cond_out, rtol=1e-05, atol=1e-05
- ), "a 'keep all' mask should give the same result as no mask"
+ assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
+ "a 'keep all' mask should give the same result as no mask"
+ )
trunc_cond = cond[:, :-1, :]
trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample
- assert not trunc_cond_out.allclose(
- full_cond_out, rtol=1e-05, atol=1e-05
- ), "discarding the last token from our cond should change the result"
+ assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
+ "discarding the last token from our cond should change the result"
+ )
batch, tokens, _ = cond.shape
mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype)
masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample
- assert masked_cond_out.allclose(
- trunc_cond_out, rtol=1e-05, atol=1e-05
- ), "masking the last token from our cond should be equivalent to truncating that token out of the condition"
+ assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), (
+ "masking the last token from our cond should be equivalent to truncating that token out of the condition"
+ )
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
@@ -697,9 +697,9 @@ def test_model_xattn_padding(self):
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
- assert trunc_mask_out.allclose(
- keeplast_out
- ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
+ assert trunc_mask_out.allclose(keeplast_out), (
+ "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
+ )
def test_custom_diffusion_processors(self):
# enable deterministic behavior for gradient checkpointing
@@ -1114,12 +1114,12 @@ def test_load_attn_procs_raise_warning(self):
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
- assert not torch.allclose(
- non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
- ), "LoRA injected UNet should produce different results."
- assert torch.allclose(
- lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
- ), "Loading from a saved checkpoint should produce identical results."
+ assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
+ "LoRA injected UNet should produce different results."
+ )
+ assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
+ "Loading from a saved checkpoint should produce identical results."
+ )
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
diff --git a/tests/others/test_image_processor.py b/tests/others/test_image_processor.py
index 3397ca9e394a..071194c59ead 100644
--- a/tests/others/test_image_processor.py
+++ b/tests/others/test_image_processor.py
@@ -65,9 +65,9 @@ def test_vae_image_processor_pt(self):
)
out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np
- assert (
- np.abs(in_np - out_np).max() < 1e-6
- ), f"decoded output does not match input for output_type {output_type}"
+ assert np.abs(in_np - out_np).max() < 1e-6, (
+ f"decoded output does not match input for output_type {output_type}"
+ )
def test_vae_image_processor_np(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
@@ -78,9 +78,9 @@ def test_vae_image_processor_np(self):
out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np
- assert (
- np.abs(in_np - out_np).max() < 1e-6
- ), f"decoded output does not match input for output_type {output_type}"
+ assert np.abs(in_np - out_np).max() < 1e-6, (
+ f"decoded output does not match input for output_type {output_type}"
+ )
def test_vae_image_processor_pil(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
@@ -93,9 +93,9 @@ def test_vae_image_processor_pil(self):
for i, o in zip(input_pil, out):
in_np = np.array(i)
out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round()
- assert (
- np.abs(in_np - out_np).max() < 1e-6
- ), f"decoded output does not match input for output_type {output_type}"
+ assert np.abs(in_np - out_np).max() < 1e-6, (
+ f"decoded output does not match input for output_type {output_type}"
+ )
def test_preprocess_input_3d(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
@@ -293,9 +293,9 @@ def test_vae_image_processor_resize_pt(self):
scale = 2
out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale)
exp_pt_shape = (b, c, h // scale, w // scale)
- assert (
- out_pt.shape == exp_pt_shape
- ), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
+ assert out_pt.shape == exp_pt_shape, (
+ f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
+ )
def test_vae_image_processor_resize_np(self):
image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
@@ -305,6 +305,6 @@ def test_vae_image_processor_resize_np(self):
input_np = self.to_np(input_pt)
out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale)
exp_np_shape = (b, h // scale, w // scale, c)
- assert (
- out_np.shape == exp_np_shape
- ), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
+ assert out_np.shape == exp_np_shape, (
+ f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
+ )
diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py
index a0fbc5df1c28..ac579bbf2be2 100644
--- a/tests/pipelines/amused/test_amused.py
+++ b/tests/pipelines/amused/test_amused.py
@@ -126,8 +126,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self):
- ...
+ def test_inference_batch_single_identical(self): ...
@slow
diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py
index 2699bbe7f56f..942735f15707 100644
--- a/tests/pipelines/amused/test_amused_img2img.py
+++ b/tests/pipelines/amused/test_amused_img2img.py
@@ -126,8 +126,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self):
- ...
+ def test_inference_batch_single_identical(self): ...
@slow
diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py
index 645379a7eab1..541b988f1798 100644
--- a/tests/pipelines/amused/test_amused_inpaint.py
+++ b/tests/pipelines/amused/test_amused_inpaint.py
@@ -130,8 +130,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self):
- ...
+ def test_inference_batch_single_identical(self): ...
@slow
diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
index c56aeb905ac3..1eb9d1035c33 100644
--- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
+++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
@@ -106,9 +106,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -122,15 +122,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@unittest.skip("xformers attention processor does not exist for AuraFlow")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
index e073f55aec9e..db8d36b23a4b 100644
--- a/tests/pipelines/blipdiffusion/test_blipdiffusion.py
+++ b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
@@ -195,9 +195,9 @@ def test_blipdiffusion(self):
[0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
+ )
@unittest.skip("Test not supported because of complexities in deriving query_embeds.")
def test_encode_prompt_works_in_isolation(self):
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index 388dc9ef7ec4..a9de0ff05fe8 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -299,9 +299,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -315,15 +315,15 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@slow
diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
index 2e962bd247b9..4f32da7ac4ae 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
@@ -299,9 +299,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -315,12 +315,12 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
index cac47f1a83d4..ec4e51bd1bad 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
@@ -317,9 +317,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -333,15 +333,15 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@slow
diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
index 4d836cb5e2a4..b1ac8cbd90ed 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
@@ -298,9 +298,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -314,12 +314,12 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
index eedda4e21722..a5768cb51fbf 100644
--- a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
+++ b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
@@ -219,9 +219,9 @@ def test_blipdiffusion_controlnet(self):
assert image.shape == (1, 16, 16, 4)
expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
@unittest.skip("Test not supported because of complexities in deriving query_embeds.")
def test_encode_prompt_works_in_isolation(self):
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index 9a270c2bbf07..9ce62cde9fe4 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -178,9 +178,9 @@ def test_controlnet_flux(self):
[0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
index 59ccb9237819..8d63619c402b 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
@@ -170,9 +170,9 @@ def test_fused_qkv_projections(self):
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -186,15 +186,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
index f7b3db05c8af..4bd7f59dc0a8 100644
--- a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
+++ b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
@@ -162,9 +162,9 @@ def test_controlnet_hunyuandit(self):
[0.6953125, 0.89208984, 0.59375, 0.5078125, 0.5786133, 0.6035156, 0.5839844, 0.53564453, 0.52246094]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
index 2cd57ce56d52..d9f5dcad7d61 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
@@ -194,9 +194,9 @@ def test_controlnet_inpaint_sd3(self):
[0.51708984, 0.7421875, 0.4580078, 0.6435547, 0.65625, 0.43603516, 0.5151367, 0.65722656, 0.60839844]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 84ce09acbe1a..1be15645efd7 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -202,9 +202,9 @@ def run_pipe(self, components, use_sd35=False):
else:
expected_slice = np.array([1.0000, 0.9072, 0.4209, 0.2744, 0.5737, 0.3840, 0.6113, 0.6250, 0.6328])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
def test_controlnet_sd3(self):
components = self.get_dummy_components()
diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py
index 30883ac4a63d..18732c0058de 100644
--- a/tests/pipelines/dit/test_dit.py
+++ b/tests/pipelines/dit/test_dit.py
@@ -149,8 +149,7 @@ def test_dit_512(self):
for word, image in zip(words, images):
expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- f"/dit/{word}_512.npy"
+ f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}_512.npy"
)
assert np.abs((expected_image - image).max()) < 1e-1
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index 6a560367a5b8..646ad928ec05 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -170,9 +170,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -186,15 +186,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py
index d8293952adcb..d8d0774e1e32 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control.py
@@ -140,9 +140,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -156,15 +156,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
index 44ce2a4dedfc..a2f7c9171082 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
@@ -134,9 +134,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -150,15 +150,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/hunyuandit/test_hunyuan_dit.py b/tests/pipelines/hunyuandit/test_hunyuan_dit.py
index 5b1a82eda227..66453b73b0b3 100644
--- a/tests/pipelines/hunyuandit/test_hunyuan_dit.py
+++ b/tests/pipelines/hunyuandit/test_hunyuan_dit.py
@@ -21,12 +21,7 @@
import torch
from transformers import AutoTokenizer, BertModel, T5EncoderModel
-from diffusers import (
- AutoencoderKL,
- DDPMScheduler,
- HunyuanDiT2DModel,
- HunyuanDiTPipeline,
-)
+from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -179,9 +174,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -197,15 +192,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@unittest.skip(
"Test not supported as `encode_prompt` is called two times separately which deivates from about 99% of the pipelines we have."
diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py
index 30144e37a9d4..f4de6f3a5338 100644
--- a/tests/pipelines/kandinsky/test_kandinsky.py
+++ b/tests/pipelines/kandinsky/test_kandinsky.py
@@ -240,12 +240,12 @@ def test_kandinsky(self):
expected_slice = np.array([1.0000, 1.0000, 0.2766, 1.0000, 0.5447, 0.1737, 1.0000, 0.4316, 0.9024])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py
index c5f27a9cc9a9..f14a741d7dc1 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py
@@ -98,12 +98,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.2893, 0.1464, 0.4603, 0.3529, 0.4612, 0.7701, 0.4027, 0.3051, 0.5155])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -206,12 +206,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.4852, 0.4136, 0.4539, 0.4781, 0.4680, 0.5217, 0.4973, 0.4089, 0.4977])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -318,12 +318,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.0320, 0.0860, 0.4013, 0.0518, 0.2484, 0.5847, 0.4411, 0.2321, 0.4593])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
index 26361ce18b82..169709978042 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
@@ -261,12 +261,12 @@ def test_kandinsky_img2img(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5816, 0.5872, 0.4634, 0.5982, 0.4767, 0.4710, 0.4669, 0.4717, 0.4966])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -321,7 +321,7 @@ def test_kandinsky_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
prompt = "A red cartoon frog, 4k"
@@ -387,7 +387,7 @@ def test_kandinsky_img2img_ddpm(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/frog.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/frog.png"
)
prompt = "A red cartoon frog, 4k"
diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
index e30c601b6011..d4d5c4e48f78 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
@@ -256,12 +256,12 @@ def test_kandinsky_inpaint(self):
expected_slice = np.array([0.8222, 0.8896, 0.4373, 0.8088, 0.4905, 0.2609, 0.6816, 0.4291, 0.5129])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -319,7 +319,7 @@ def test_kandinsky_inpaint(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky.py b/tests/pipelines/kandinsky2_2/test_kandinsky.py
index fea49d47b7bb..aa17f6fc5d6b 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky.py
@@ -210,13 +210,13 @@ def test_kandinsky(self):
expected_slice = np.array([0.3420, 0.9505, 0.3919, 1.0000, 0.5188, 0.3109, 0.6139, 0.5624, 0.6811])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
index 90f8b2034109..17ef3dc2601e 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
@@ -103,12 +103,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.3076, 0.2729, 0.5668, 0.0522, 0.3384, 0.7028, 0.4908, 0.3659, 0.6243])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -227,12 +227,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.4445, 0.4287, 0.4596, 0.3919, 0.3730, 0.5039, 0.4834, 0.4269, 0.5521])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -350,12 +350,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.5039, 0.4926, 0.4898, 0.4978, 0.4838, 0.4942, 0.4738, 0.4702, 0.4816])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
index 1f3219e0d69e..10a95d6177b2 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
@@ -210,13 +210,13 @@ def test_kandinsky_controlnet(self):
[0.6959826, 0.868279, 0.7558092, 0.68769467, 0.85805804, 0.65977496, 0.44885302, 0.5959111, 0.4251595]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
index 20944aa3d6f8..58fbbecc0569 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
@@ -218,12 +218,12 @@ def test_kandinsky_controlnet_img2img(self):
expected_slice = np.array(
[0.54985034, 0.55509365, 0.52561504, 0.5570494, 0.5593818, 0.5263979, 0.50285643, 0.5069846, 0.51196736]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1.75e-3)
@@ -254,7 +254,7 @@ def test_kandinsky_controlnet_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
init_image = init_image.resize((512, 512))
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
index 4702f473a992..aa7589a212eb 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
@@ -228,12 +228,12 @@ def test_kandinsky_img2img(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5712, 0.5443, 0.4725, 0.6195, 0.5184, 0.4651, 0.4473, 0.4590, 0.5016])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=2e-1)
@@ -261,7 +261,7 @@ def test_kandinsky_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
prompt = "A red cartoon frog, 4k"
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
index 9a7f659e533c..d7ac69820761 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
@@ -234,12 +234,12 @@ def test_kandinsky_inpaint(self):
[0.50775903, 0.49527195, 0.48824543, 0.50192237, 0.48644906, 0.49373814, 0.4780598, 0.47234827, 0.48327848]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -314,7 +314,7 @@ def test_kandinsky_inpaint(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py
index af1d45ff8975..c54b91f024af 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3.py
@@ -157,9 +157,9 @@ def test_kandinsky3(self):
expected_slice = np.array([0.3768, 0.4373, 0.4865, 0.4890, 0.4299, 0.5122, 0.4921, 0.4924, 0.5599])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
index e00948621a06..088c32e2860e 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
@@ -181,9 +181,9 @@ def test_kandinsky3_img2img(self):
[0.576259, 0.6132097, 0.41703486, 0.603196, 0.62062526, 0.4655338, 0.5434324, 0.5660727, 0.65433365]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py
index 6fa96275406f..b9ce29c70bdf 100644
--- a/tests/pipelines/pag/test_pag_animatediff.py
+++ b/tests/pipelines/pag/test_pag_animatediff.py
@@ -450,9 +450,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).frames[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd.py b/tests/pipelines/pag/test_pag_controlnet_sd.py
index ee97b0507a34..02232c7379bd 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd.py
@@ -169,9 +169,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
index 25ef5d253d68..cfc0b218d2e4 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
@@ -165,9 +165,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl.py b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
index 0588e26286a8..10adff7fe0a6 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
@@ -187,9 +187,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
index 63c7d9fbee2d..fe4b615f646b 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
@@ -189,9 +189,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py
index 31cd9aa666de..d6cfbbed9e95 100644
--- a/tests/pipelines/pag/test_pag_hunyuan_dit.py
+++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py
@@ -177,15 +177,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -198,9 +198,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py
index 9a4f1daa2c05..c9f197b703ef 100644
--- a/tests/pipelines/pag/test_pag_kolors.py
+++ b/tests/pipelines/pag/test_pag_kolors.py
@@ -140,9 +140,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py
index 63f42416dbca..624b57844390 100644
--- a/tests/pipelines/pag/test_pag_pixart_sigma.py
+++ b/tests/pipelines/pag/test_pag_pixart_sigma.py
@@ -120,9 +120,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
+ )
out = pipe(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_sana.py b/tests/pipelines/pag/test_pag_sana.py
index a2c657297860..ee1e359383e9 100644
--- a/tests/pipelines/pag/test_pag_sana.py
+++ b/tests/pipelines/pag/test_pag_sana.py
@@ -268,9 +268,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py
index d4cf00b034ff..bc20226873f6 100644
--- a/tests/pipelines/pag/test_pag_sd.py
+++ b/tests/pipelines/pag/test_pag_sd.py
@@ -154,9 +154,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -328,9 +328,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -345,6 +345,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sd3.py b/tests/pipelines/pag/test_pag_sd3.py
index 41ff0c3c09f4..737e238e5fbf 100644
--- a/tests/pipelines/pag/test_pag_sd3.py
+++ b/tests/pipelines/pag/test_pag_sd3.py
@@ -170,9 +170,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -186,15 +186,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -207,9 +207,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_sd3_img2img.py b/tests/pipelines/pag/test_pag_sd3_img2img.py
index 2fe988929185..fe593d47dc75 100644
--- a/tests/pipelines/pag/test_pag_sd3_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd3_img2img.py
@@ -149,9 +149,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
@@ -254,9 +254,9 @@ def test_pag_cfg(self):
0.17822266,
]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(
@@ -272,6 +272,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.1508789, 0.16210938, 0.17138672, 0.16210938, 0.17089844, 0.16137695, 0.16235352, 0.16430664, 0.16455078]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sd_img2img.py b/tests/pipelines/pag/test_pag_sd_img2img.py
index d000493d6bd1..ef70985571c9 100644
--- a/tests/pipelines/pag/test_pag_sd_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd_img2img.py
@@ -161,9 +161,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -267,9 +267,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -285,6 +285,6 @@ def test_pag_uncond(self):
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sd_inpaint.py b/tests/pipelines/pag/test_pag_sd_inpaint.py
index 06682c111d37..04ec8b216551 100644
--- a/tests/pipelines/pag/test_pag_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sd_inpaint.py
@@ -302,9 +302,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.38793945, 0.4111328, 0.47924805, 0.39208984, 0.4165039, 0.41674805, 0.37060547, 0.36791992, 0.40625]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -319,6 +319,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.3876953, 0.40356445, 0.4934082, 0.39697266, 0.41674805, 0.41015625, 0.375, 0.36914062, 0.40649414]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py
index b35b2b1d2f7e..fc4ce1067f76 100644
--- a/tests/pipelines/pag/test_pag_sdxl.py
+++ b/tests/pipelines/pag/test_pag_sdxl.py
@@ -167,9 +167,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -331,9 +331,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.3123679, 0.31725878, 0.32026544, 0.327533, 0.3266391, 0.3303998, 0.33544615, 0.34181812, 0.34102726]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -348,6 +348,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.47400922, 0.48650584, 0.4839625, 0.4724013, 0.4890427, 0.49544555, 0.51707107, 0.54299414, 0.5224372]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sdxl_img2img.py b/tests/pipelines/pag/test_pag_sdxl_img2img.py
index c94a6836de7f..0e5c2cc7f93a 100644
--- a/tests/pipelines/pag/test_pag_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_sdxl_img2img.py
@@ -215,9 +215,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -316,9 +316,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.20301354, 0.21078318, 0.2021082, 0.20277798, 0.20681083, 0.19562206, 0.20121682, 0.21562952, 0.21277016]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -333,6 +333,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.21303111, 0.22188407, 0.2124992, 0.21365267, 0.18823743, 0.17569828, 0.21113116, 0.19419771, 0.18919235]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
index cca5292288b0..854c65cbc761 100644
--- a/tests/pipelines/pag/test_pag_sdxl_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
@@ -220,9 +220,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -322,9 +322,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.41385046, 0.39608297, 0.4360491, 0.26872507, 0.32187328, 0.4242474, 0.2603805, 0.34167895, 0.46561807]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -339,6 +339,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.41597816, 0.39302617, 0.44287828, 0.2687074, 0.28315824, 0.40582314, 0.20877528, 0.2380802, 0.39447647]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py
index b220afcfc25a..7084fc9bcec8 100644
--- a/tests/pipelines/pixart_sigma/test_pixart.py
+++ b/tests/pipelines/pixart_sigma/test_pixart.py
@@ -260,9 +260,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -276,15 +276,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@slow
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index ac7096874b31..72eee3e35eb1 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -266,7 +266,7 @@ def tearDown(self):
def test_shap_e_img2img(self):
input_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/shap_e/corgi.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/corgi.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
index 1765f3a02242..d433a461bd9d 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
@@ -198,12 +198,12 @@ def test_stable_cascade(self):
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 6e17b86639ea..3b5c7a24b4ca 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -293,15 +293,15 @@ def test_stable_diffusion_ays(self):
inputs["sigmas"] = sigma_schedule
output_sigmas = sd_pipe(**inputs).images
- assert (
- np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
- ), "ays timesteps and ays sigmas should have the same outputs"
- assert (
- np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
- ), "use ays timesteps should have different outputs"
- assert (
- np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
- ), "use ays sigmas should have different outputs"
+ assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, (
+ "ays timesteps and ays sigmas should have the same outputs"
+ )
+ assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, (
+ "use ays timesteps should have different outputs"
+ )
+ assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, (
+ "use ays sigmas should have different outputs"
+ )
def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
@@ -656,9 +656,9 @@ def test_freeu_enabled(self):
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
output_freeu = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
- assert not np.allclose(
- output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
- ), "Enabling of FreeU should lead to different results."
+ assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
+ "Enabling of FreeU should lead to different results."
+ )
def test_freeu_disabled(self):
components = self.get_dummy_components()
@@ -681,9 +681,9 @@ def test_freeu_disabled(self):
prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)
).images
- assert np.allclose(
- output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
- ), "Disabling of FreeU should lead to results similar to the default pipeline results."
+ assert np.allclose(output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]), (
+ "Disabling of FreeU should lead to results similar to the default pipeline results."
+ )
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -706,15 +706,15 @@ def test_fused_qkv_projections(self):
image = sd_pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_pipeline_interrupt(self):
components = self.get_dummy_components()
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 38ef6143f4c0..8e2fa77fc083 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -171,9 +171,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -187,15 +187,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_skip_guidance_layers(self):
components = self.get_dummy_components()
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index c68cdf67036a..a41e7dc7f342 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -242,15 +242,15 @@ def test_stable_diffusion_ays(self):
inputs["sigmas"] = sigma_schedule
output_sigmas = sd_pipe(**inputs).images
- assert (
- np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
- ), "ays timesteps and ays sigmas should have the same outputs"
- assert (
- np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
- ), "use ays timesteps should have different outputs"
- assert (
- np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
- ), "use ays sigmas should have different outputs"
+ assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, (
+ "ays timesteps and ays sigmas should have the same outputs"
+ )
+ assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, (
+ "use ays timesteps should have different outputs"
+ )
+ assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, (
+ "use ays sigmas should have different outputs"
+ )
def test_ip_adapter(self):
expected_pipe_slice = None
@@ -742,9 +742,9 @@ def new_step(self, *args, **kwargs):
inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
latents = pipe_1(**inputs_1).images[0]
- assert (
- expected_steps_1 == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps_1 == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
with self.assertRaises(ValueError) as cm:
inputs_2 = {
@@ -771,9 +771,9 @@ def new_step(self, *args, **kwargs):
pipe_3(**inputs_3).images[0]
assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
- assert (
- expected_steps == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
for steps in [7, 11, 20]:
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index 66ae581a0529..729c6981d2b5 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -585,9 +585,9 @@ def new_step(self, *args, **kwargs):
inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
latents = pipe_1(**inputs_1).images[0]
- assert (
- expected_steps_1 == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps_1 == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
inputs_2 = {
**inputs,
@@ -601,9 +601,9 @@ def new_step(self, *args, **kwargs):
pipe_3(**inputs_3).images[0]
assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
- assert (
- expected_steps == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
for steps in [7, 11, 20]:
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index ae5a12e04ba8..00c7636ed9fd 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -167,9 +167,9 @@ def test_one_request_upon_cached(self):
download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 15, "15 calls to files"
assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
- assert (
- len(download_requests) == 32
- ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
+ assert len(download_requests) == 32, (
+ "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
+ )
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
@@ -179,9 +179,9 @@ def test_one_request_upon_cached(self):
cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
- assert (
- len(cache_requests) == 2
- ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ assert len(cache_requests) == 2, (
+ "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ )
def test_less_downloads_passed_object(self):
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -217,9 +217,9 @@ def test_less_downloads_passed_object_calls(self):
assert download_requests.count("HEAD") == 13, "13 calls to files"
# 17 - 2 because no call to config or model file for `safety_checker`
assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json"
- assert (
- len(download_requests) == 28
- ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
+ assert len(download_requests) == 28, (
+ "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
+ )
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
@@ -229,9 +229,9 @@ def test_less_downloads_passed_object_calls(self):
cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
- assert (
- len(cache_requests) == 2
- ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ assert len(cache_requests) == 2, (
+ "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ )
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index d3e39e363f91..a950de142740 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -191,12 +191,12 @@ def test_freeu(self):
inputs["output_type"] = "np"
output_no_freeu = pipe(**inputs)[0]
- assert not np.allclose(
- output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
- ), "Enabling of FreeU should lead to different results."
- assert np.allclose(
- output, output_no_freeu, atol=1e-2
- ), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
+ assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
+ "Enabling of FreeU should lead to different results."
+ )
+ assert np.allclose(output, output_no_freeu, atol=1e-2), (
+ f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
+ )
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -217,12 +217,12 @@ def test_fused_qkv_projections(self):
and hasattr(component, "original_attn_processors")
and component.original_attn_processors is not None
):
- assert check_qkv_fusion_processors_exist(
- component
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- assert check_qkv_fusion_matches_attn_procs_length(
- component, component.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
+ assert check_qkv_fusion_processors_exist(component), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
+ assert check_qkv_fusion_matches_attn_procs_length(component, component.original_attn_processors), (
+ "Something wrong with the attention processors concerning the fused QKV projections."
+ )
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
@@ -235,15 +235,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
class IPAdapterTesterMixin:
@@ -909,9 +909,9 @@ def test_from_pipe_consistent_forward_pass(self, expected_max_diff=1e-3):
for component in pipe_original.components.values():
if hasattr(component, "attn_processors"):
- assert all(
- type(proc) == AttnProcessor for proc in component.attn_processors.values()
- ), "`from_pipe` changed the attention processor in original pipeline."
+ assert all(type(proc) == AttnProcessor for proc in component.attn_processors.values()), (
+ "`from_pipe` changed the attention processor in original pipeline."
+ )
@require_accelerator
@require_accelerate_version_greater("0.14.0")
@@ -2569,12 +2569,12 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2)
image_slice_pab_disabled = output.flatten()
image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
- assert np.allclose(
- original_image_slice, image_slice_pab_enabled, atol=expected_atol
- ), "PAB outputs should not differ much in specified timestep range."
- assert np.allclose(
- original_image_slice, image_slice_pab_disabled, atol=1e-4
- ), "Outputs from normal inference and after disabling cache should not differ."
+ assert np.allclose(original_image_slice, image_slice_pab_enabled, atol=expected_atol), (
+ "PAB outputs should not differ much in specified timestep range."
+ )
+ assert np.allclose(original_image_slice, image_slice_pab_disabled, atol=1e-4), (
+ "Outputs from normal inference and after disabling cache should not differ."
+ )
class FasterCacheTesterMixin:
@@ -2639,12 +2639,12 @@ def run_forward(pipe):
output = run_forward(pipe).flatten()
image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))
- assert np.allclose(
- original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
- ), "FasterCache outputs should not differ much in specified timestep range."
- assert np.allclose(
- original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
- ), "Outputs from normal inference and after disabling cache should not differ."
+ assert np.allclose(original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol), (
+ "FasterCache outputs should not differ much in specified timestep range."
+ )
+ assert np.allclose(original_image_slice, image_slice_faster_cache_disabled, atol=1e-4), (
+ "Outputs from normal inference and after disabling cache should not differ."
+ )
def test_faster_cache_state(self):
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
index 084d62a8c613..fa544c91f2d9 100644
--- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
@@ -191,12 +191,12 @@ def test_wuerstchen(self):
expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index 55b3202ad0be..28c354709dc9 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -357,9 +357,9 @@ def test_custom_timesteps(self):
prediction_type=prediction_type,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py
index 7cbaa5cc5e8d..0756a5ed71ff 100644
--- a/tests/schedulers/test_scheduler_dpm_single.py
+++ b/tests/schedulers/test_scheduler_dpm_single.py
@@ -345,9 +345,9 @@ def test_custom_timesteps(self):
lower_order_final=lower_order_final,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
index e97d64ec5f1d..8525ce61c40d 100644
--- a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
+++ b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
@@ -188,9 +188,9 @@ def test_solver_order_and_type(self):
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
- assert (
- not torch.isnan(sample).any()
- ), f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}"
+ assert not torch.isnan(sample).any(), (
+ f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}"
+ )
def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py
index 4c7e02442cd0..01e173a631cd 100644
--- a/tests/schedulers/test_scheduler_euler.py
+++ b/tests/schedulers/test_scheduler_euler.py
@@ -245,9 +245,9 @@ def test_custom_timesteps(self):
interpolation_type=interpolation_type,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_custom_sigmas(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
@@ -260,9 +260,9 @@ def test_custom_sigmas(self):
prediction_type=prediction_type,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_heun.py b/tests/schedulers/test_scheduler_heun.py
index 9e060c6d476f..90012f5525ab 100644
--- a/tests/schedulers/test_scheduler_heun.py
+++ b/tests/schedulers/test_scheduler_heun.py
@@ -216,9 +216,9 @@ def test_custom_timesteps(self):
prediction_type=prediction_type,
timestep_spacing=timestep_spacing,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py
index 4e7bc0af6842..4e1713c9ceb1 100644
--- a/tests/single_file/single_file_testing_utils.py
+++ b/tests/single_file/single_file_testing_utils.py
@@ -72,9 +72,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
continue
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
- assert isinstance(
- component, pipe.components[component_name].__class__
- ), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ assert isinstance(component, pipe.components[component_name].__class__), (
+ f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ )
for param_name, param_value in component.config.items():
if param_name in PARAMS_TO_IGNORE:
@@ -85,9 +85,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
pipe.components[component_name].config[param_name] = param_value
- assert (
- pipe.components[component_name].config[param_name] == param_value
- ), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ assert pipe.components[component_name].config[param_name] == param_value, (
+ f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ )
def test_single_file_components(self, pipe=None, single_file_pipe=None):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -253,9 +253,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
continue
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
- assert isinstance(
- component, pipe.components[component_name].__class__
- ), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ assert isinstance(component, pipe.components[component_name].__class__), (
+ f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ )
for param_name, param_value in component.config.items():
if param_name in PARAMS_TO_IGNORE:
@@ -266,9 +266,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
pipe.components[component_name].config[param_name] = param_value
- assert (
- pipe.components[component_name].config[param_name] == param_value
- ), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ assert pipe.components[component_name].config[param_name] == param_value, (
+ f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ )
def test_single_file_components(self, pipe=None, single_file_pipe=None):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
diff --git a/tests/single_file/test_lumina2_transformer.py b/tests/single_file/test_lumina2_transformer.py
index 78e68c4c2df0..d3ffd4fc3a55 100644
--- a/tests/single_file/test_lumina2_transformer.py
+++ b/tests/single_file/test_lumina2_transformer.py
@@ -60,9 +60,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
diff --git a/tests/single_file/test_model_autoencoder_dc_single_file.py b/tests/single_file/test_model_autoencoder_dc_single_file.py
index b1faeb78776b..31b2eb6e36b0 100644
--- a/tests/single_file/test_model_autoencoder_dc_single_file.py
+++ b/tests/single_file/test_model_autoencoder_dc_single_file.py
@@ -87,9 +87,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_in_type_variant_components(self):
# `in` variant checkpoints require passing in a `config` parameter
@@ -106,9 +106,9 @@ def test_single_file_in_type_variant_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_mix_type_variant_components(self):
repo_id = "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"
@@ -121,6 +121,6 @@ def test_single_file_mix_type_variant_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
diff --git a/tests/single_file/test_model_controlnet_single_file.py b/tests/single_file/test_model_controlnet_single_file.py
index bfcb802380a6..3580d73531a3 100644
--- a/tests/single_file/test_model_controlnet_single_file.py
+++ b/tests/single_file/test_model_controlnet_single_file.py
@@ -58,9 +58,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path)
diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py
index 0ec97db26a9e..bf11faaa9c0e 100644
--- a/tests/single_file/test_model_flux_transformer_single_file.py
+++ b/tests/single_file/test_model_flux_transformer_single_file.py
@@ -58,9 +58,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
diff --git a/tests/single_file/test_model_motion_adapter_single_file.py b/tests/single_file/test_model_motion_adapter_single_file.py
index b195f25d094b..a747f16dc1db 100644
--- a/tests/single_file/test_model_motion_adapter_single_file.py
+++ b/tests/single_file/test_model_motion_adapter_single_file.py
@@ -40,9 +40,9 @@ def test_single_file_components_version_v1_5(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_components_version_v1_5_2(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt"
@@ -55,9 +55,9 @@ def test_single_file_components_version_v1_5_2(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_components_version_v1_5_3(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt"
@@ -70,9 +70,9 @@ def test_single_file_components_version_v1_5_3(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_components_version_sdxl_beta(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt"
@@ -85,6 +85,6 @@ def test_single_file_components_version_sdxl_beta(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
diff --git a/tests/single_file/test_model_sd_cascade_unet_single_file.py b/tests/single_file/test_model_sd_cascade_unet_single_file.py
index 08b04e3cd7e8..92b371c3fb41 100644
--- a/tests/single_file/test_model_sd_cascade_unet_single_file.py
+++ b/tests/single_file/test_model_sd_cascade_unet_single_file.py
@@ -60,9 +60,9 @@ def test_single_file_components_stage_b(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_components_stage_b_lite(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -77,9 +77,9 @@ def test_single_file_components_stage_b_lite(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_components_stage_c(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -94,9 +94,9 @@ def test_single_file_components_stage_c(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_components_stage_c_lite(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -111,6 +111,6 @@ def test_single_file_components_stage_c_lite(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py
index 9db4cddb3c9d..bba1726ae380 100644
--- a/tests/single_file/test_model_vae_single_file.py
+++ b/tests/single_file/test_model_vae_single_file.py
@@ -91,9 +91,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
diff --git a/tests/single_file/test_model_wan_autoencoder_single_file.py b/tests/single_file/test_model_wan_autoencoder_single_file.py
index f5720ddd3964..7f0e1c1a4b0b 100644
--- a/tests/single_file/test_model_wan_autoencoder_single_file.py
+++ b/tests/single_file/test_model_wan_autoencoder_single_file.py
@@ -56,6 +56,6 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py
index 9b938aa1754c..36f0919cacb5 100644
--- a/tests/single_file/test_model_wan_transformer3d_single_file.py
+++ b/tests/single_file/test_model_wan_transformer3d_single_file.py
@@ -57,9 +57,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
@require_big_gpu_with_torch_cuda
@@ -88,6 +88,6 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py
index 7695e1577711..802ca37abfc3 100644
--- a/tests/single_file/test_sana_transformer.py
+++ b/tests/single_file/test_sana_transformer.py
@@ -47,9 +47,9 @@ def test_single_file_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
diff --git a/utils/log_reports.py b/utils/log_reports.py
index dd1b258519d7..5575c9ba8415 100644
--- a/utils/log_reports.py
+++ b/utils/log_reports.py
@@ -35,7 +35,7 @@ def main(slack_channel_name=None):
if line.get("nodeid", "") != "":
test = line["nodeid"]
if line.get("duration", None) is not None:
- duration = f'{line["duration"]:.4f}'
+ duration = f"{line['duration']:.4f}"
if line.get("outcome", "") == "failed":
section_num_failed += 1
failed.append([test, duration, log.name.split("_")[0]])
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index a97e65801c5f..4fde581d4170 100644
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -104,8 +104,7 @@ def update_metadata(commit_sha: str):
if commit_sha is not None:
commit_message = (
- f"Update with commit {commit_sha}\n\nSee: "
- f"https://github.com/huggingface/diffusers/commit/{commit_sha}"
+ f"Update with commit {commit_sha}\n\nSee: https://github.com/huggingface/diffusers/commit/{commit_sha}"
)
else:
commit_message = "Update"