11 changes: 5 additions & 6 deletions examples/community/README.md
@@ -4336,19 +4336,19 @@ The Abstract of the paper:

**64x64**
:-------------------------:
-| <img src="https://github.com/user-attachments/assets/9e7bb2cd-45a0-4bd1-adb8-23e283baed39" width="222" height="222" alt="bird_64"> |
+| <img src="https://github.com/user-attachments/assets/032738eb-c6cd-4fd9-b4d7-a7317b4b6528" width="222" height="222" alt="bird_64_64"> |

- `256×256, nesting_level=1`: 1.776 GiB. With `150` DDIM inference steps:

**64x64** | **256x256**
:-------------------------:|:-------------------------:
-| <img src="https://github.com/user-attachments/assets/6b724c2e-5e6a-4b63-9b65-c1182cbb67e0" width="222" height="222" alt="64x64"> | <img src="https://github.com/user-attachments/assets/7dbab2ad-bf40-4a73-ab04-f178347cb7d5" width="222" height="222" alt="256x256"> |
+| <img src="https://github.com/user-attachments/assets/21b9ad8b-eea6-4603-80a2-31180f391589" width="222" height="222" alt="bird_256_64"> | <img src="https://github.com/user-attachments/assets/fc411682-8a36-422c-9488-395b77d4406e" width="222" height="222" alt="bird_256_256"> |

-- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible. With `250` DDIM inference steps:
+- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible in this context! With `250` DDIM inference steps:

**64x64** | **256x256** | **1024x1024**
:-------------------------:|:-------------------------:|:-------------------------:
-| <img src="https://github.com/user-attachments/assets/4a9454e4-e20a-4736-a196-270e2ae796c0" width="222" height="222" alt="64x64"> | <img src="https://github.com/user-attachments/assets/4a96555d-0fda-4303-82b1-a4d886f770b9" width="222" height="222" alt="256x256"> | <img src="https://github.com/user-attachments/assets/e0239b7a-ab73-4d45-8f3e-b4e6b4b50abe" width="222" height="222" alt="1024x1024"> |
+| <img src="https://github.com/user-attachments/assets/febf4b98-3dee-4a8e-9946-fd42e1f232e6" width="222" height="222" alt="bird_1024_64"> | <img src="https://github.com/user-attachments/assets/c5f85b40-5d6d-4267-a92a-c89dff015b9b" width="222" height="222" alt="bird_1024_256"> | <img src="https://github.com/user-attachments/assets/ad66b913-4367-4cb9-889e-bc06f4d96148" width="222" height="222" alt="bird_1024_1024"> |

```py
from diffusers import DiffusionPipeline
@@ -4362,8 +4362,7 @@ pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-model

prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
-negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
-image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
+image = pipe(prompt, num_inference_steps=50).images
make_image_grid(image, rows=1, cols=len(image))

# pipe.change_nesting_level(<int>) # 0, 1, or 2
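Taken together, the README changes describe one load-and-generate flow per nesting level. The sketch below simply consolidates the updated snippet from the diff above; the repo id, `nesting_level` argument, and recommended step counts are all taken from this PR, and the `make_image_grid` import is assumed to come from `diffusers.utils`.

```py
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid

# nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 & 64x64; nesting_level=2 -> 1024x1024 & 256x256 & 64x64
pipe = DiffusionPipeline.from_pretrained(
    "tolgacangoz/matryoshka-diffusion-models",
    nesting_level=0,
    trust_remote_code=False,  # One needs to give permission for this code to run
).to("cuda")

prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"

# The negative prompt was dropped in this PR; the pipeline returns one image per active nesting level.
image = pipe(prompt, num_inference_steps=50).images
make_image_grid(image, rows=1, cols=len(image))

# pipe.change_nesting_level(<int>)  # 0, 1, or 2; 50+, 100+, and 250+ steps are recommended respectively
```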
28 changes: 16 additions & 12 deletions examples/community/matryoshka.py
@@ -107,15 +107,16 @@

>>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
->>> custom_pipeline="matryoshka").to("cuda")
+... nesting_level=0,
+... trust_remote_code=False, # One needs to give permission for this code to run
+... ).to("cuda")

>>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
>>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
->>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
->>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
+>>> image = pipe(prompt, num_inference_steps=50).images
>>> make_image_grid(image, rows=1, cols=len(image))

->>> pipe.change_nesting_level(<int>) # 0, 1, or 2
+>>> # pipe.change_nesting_level(<int>) # 0, 1, or 2
>>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
"""
@@ -420,6 +421,7 @@ def __init__(
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

self.scales = None
+self.schedule_shifted_power = 1.0

def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
@@ -532,6 +534,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

def get_schedule_shifted(self, alpha_prod, scale_factor=None):
if (scale_factor is not None) and (scale_factor > 1): # rescale noise schedule
+scale_factor = scale_factor**self.schedule_shifted_power
snr = alpha_prod / (1 - alpha_prod)
scaled_snr = snr / scale_factor
alpha_prod = 1 / (1 + 1 / scaled_snr)
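This hunk is the core of the change: before rescaling the noise schedule, the scale factor is raised to `schedule_shifted_power` (1.0 by default, 2.0 at nesting level 2). A standalone sketch of the SNR arithmetic follows; the helper name and the example values are hypothetical, only the formulas come from the diff.

```py
# Hypothetical standalone re-derivation of get_schedule_shifted's rescaling.
def shifted_alpha_prod(alpha_prod: float, scale_factor: float, power: float = 1.0) -> float:
    if scale_factor > 1:  # rescale noise schedule, as in the diff above
        scale_factor = scale_factor**power
        snr = alpha_prod / (1 - alpha_prod)    # signal-to-noise ratio at this step
        scaled_snr = snr / scale_factor        # larger scale factor -> noisier schedule
        alpha_prod = 1 / (1 + 1 / scaled_snr)  # map the shifted SNR back to alpha_prod
    return alpha_prod

# With power=2.0 the shift is quadratic in the scale factor, so the highest
# nesting level sees a substantially noisier schedule for the same alpha_prod:
print(shifted_alpha_prod(0.9, 4.0, power=1.0))  # ~0.692
print(shifted_alpha_prod(0.9, 4.0, power=2.0))  # ~0.360
```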
@@ -639,17 +642,14 @@ def step(
# 4. Clip or threshold "predicted x_0"
if self.config.thresholding:
if len(model_output) > 1:
-pred_original_sample = [
-self._threshold_sample(p_o_s * scale) / scale
-for p_o_s, scale in zip(pred_original_sample, self.scales)
-]
+pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample]
else:
pred_original_sample = self._threshold_sample(pred_original_sample)
elif self.config.clip_sample:
if len(model_output) > 1:
pred_original_sample = [
-(p_o_s * scale).clamp(-self.config.clip_sample_range, self.config.clip_sample_range) / scale
-for p_o_s, scale in zip(pred_original_sample, self.scales)
+p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
+for p_o_s in pred_original_sample
]
else:
pred_original_sample = pred_original_sample.clamp(
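Both branches of this hunk drop the same scale/unscale round trip: each per-level x0 prediction is now thresholded or clamped directly in its own space instead of being multiplied by the per-level scale first. A small runnable before/after comparison, where the scales and predictions are made-up stand-ins:

```py
import torch

clip_range = 1.0                          # stands in for config.clip_sample_range
scales = [4.0, 2.0, 1.0]                  # illustrative per-level scales
preds = [torch.randn(2) for _ in scales]  # stand-ins for per-level x0 predictions

# Before this PR: clamp in the scaled space, then undo the scaling.
before = [(p * s).clamp(-clip_range, clip_range) / s for p, s in zip(preds, scales)]

# After this PR: clamp each prediction directly.
after = [p.clamp(-clip_range, clip_range) for p in preds]
```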
@@ -3816,6 +3816,8 @@ def __init__(

if hasattr(unet, "nest_ratio"):
scheduler.scales = unet.nest_ratio + [1]
+if nesting_level == 2:
+scheduler.schedule_shifted_power = 2.0

self.register_modules(
text_encoder=text_encoder,
@@ -3842,12 +3844,14 @@ def change_nesting_level(self, nesting_level: int):
).to(self.device)
self.config.nesting_level = 1
self.scheduler.scales = self.unet.nest_ratio + [1]
+self.scheduler.schedule_shifted_power = 1.0
elif nesting_level == 2:
self.unet = NestedUNet2DConditionModel.from_pretrained(
"tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2"
).to(self.device)
self.config.nesting_level = 2
self.scheduler.scales = self.unet.nest_ratio + [1]
+self.scheduler.schedule_shifted_power = 2.0
else:
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")

@@ -4627,8 +4631,8 @@ def __call__(
image = latents

if self.scheduler.scales is not None:
-for i, (img, scale) in enumerate(zip(image, self.scheduler.scales)):
-image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0]
+for i, img in enumerate(image):
+image[i] = self.image_processor.postprocess(img, output_type=output_type)[0]
else:
image = self.image_processor.postprocess(image, output_type=output_type)
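Consistent with the scheduler changes above, the final postprocessing loop no longer multiplies each per-level latent by its scale before decoding. A minimal runnable sketch of the new per-level behavior using diffusers' `VaeImageProcessor`; the tensor shapes and values are illustrative dummies:

```py
import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
# Stand-ins for the per-level outputs, (batch, channels, H, W) in [-1, 1].
image = [torch.rand(1, 3, 64, 64) * 2 - 1, torch.rand(1, 3, 256, 256) * 2 - 1]

# As in the updated __call__: postprocess each level directly, with no `img * scale`.
image = [processor.postprocess(img, output_type="pil")[0] for img in image]
```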
