107
107
108
108
>>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
109
109
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
110
- >>> custom_pipeline="matryoshka").to("cuda")
110
+ ... nesting_level=0,
111
+ ... trust_remote_code=False, # Must be set to True to give this pipeline's custom code permission to run
112
+ ... ).to("cuda")
111
113
112
114
>>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
113
115
>>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
114
- >>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
115
- >>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
116
+ >>> image = pipe(prompt, num_inference_steps=50).images
116
117
>>> make_image_grid(image, rows=1, cols=len(image))
117
118
118
- >>> pipe.change_nesting_level(<int>) # 0, 1, or 2
119
+ >>> # pipe.change_nesting_level(<int>) # 0, 1, or 2
119
120
>>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
120
121
```
121
122
"""
@@ -420,6 +421,7 @@ def __init__(
420
421
self .timesteps = torch .from_numpy (np .arange (0 , num_train_timesteps )[::- 1 ].copy ().astype (np .int64 ))
421
422
422
423
self .scales = None
424
+ self .schedule_shifted_power = 1.0
423
425
424
426
def scale_model_input (self , sample : torch .Tensor , timestep : Optional [int ] = None ) -> torch .Tensor :
425
427
"""
@@ -532,6 +534,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
532
534
533
535
def get_schedule_shifted (self , alpha_prod , scale_factor = None ):
534
536
if (scale_factor is not None ) and (scale_factor > 1 ): # rescale noise schedule
537
+ scale_factor = scale_factor ** self .schedule_shifted_power
535
538
snr = alpha_prod / (1 - alpha_prod )
536
539
scaled_snr = snr / scale_factor
537
540
alpha_prod = 1 / (1 + 1 / scaled_snr )
@@ -639,17 +642,14 @@ def step(
639
642
# 4. Clip or threshold "predicted x_0"
640
643
if self .config .thresholding :
641
644
if len (model_output ) > 1 :
642
- pred_original_sample = [
643
- self ._threshold_sample (p_o_s * scale ) / scale
644
- for p_o_s , scale in zip (pred_original_sample , self .scales )
645
- ]
645
+ pred_original_sample = [self ._threshold_sample (p_o_s ) for p_o_s in pred_original_sample ]
646
646
else :
647
647
pred_original_sample = self ._threshold_sample (pred_original_sample )
648
648
elif self .config .clip_sample :
649
649
if len (model_output ) > 1 :
650
650
pred_original_sample = [
651
- ( p_o_s * scale ) .clamp (- self .config .clip_sample_range , self .config .clip_sample_range ) / scale
652
- for p_o_s , scale in zip ( pred_original_sample , self . scales )
651
+ p_o_s .clamp (- self .config .clip_sample_range , self .config .clip_sample_range )
652
+ for p_o_s in pred_original_sample
653
653
]
654
654
else :
655
655
pred_original_sample = pred_original_sample .clamp (
@@ -3816,6 +3816,8 @@ def __init__(
3816
3816
3817
3817
if hasattr (unet , "nest_ratio" ):
3818
3818
scheduler .scales = unet .nest_ratio + [1 ]
3819
+ if nesting_level == 2 :
3820
+ scheduler .schedule_shifted_power = 2.0
3819
3821
3820
3822
self .register_modules (
3821
3823
text_encoder = text_encoder ,
@@ -3842,12 +3844,14 @@ def change_nesting_level(self, nesting_level: int):
3842
3844
).to (self .device )
3843
3845
self .config .nesting_level = 1
3844
3846
self .scheduler .scales = self .unet .nest_ratio + [1 ]
3847
+ self .scheduler .schedule_shifted_power = 1.0
3845
3848
elif nesting_level == 2 :
3846
3849
self .unet = NestedUNet2DConditionModel .from_pretrained (
3847
3850
"tolgacangoz/matryoshka-diffusion-models" , subfolder = "unet/nesting_level_2"
3848
3851
).to (self .device )
3849
3852
self .config .nesting_level = 2
3850
3853
self .scheduler .scales = self .unet .nest_ratio + [1 ]
3854
+ self .scheduler .schedule_shifted_power = 2.0
3851
3855
else :
3852
3856
raise ValueError ("Currently, nesting levels 0, 1, and 2 are supported." )
3853
3857
@@ -4627,8 +4631,8 @@ def __call__(
4627
4631
image = latents
4628
4632
4629
4633
if self .scheduler .scales is not None :
4630
- for i , ( img , scale ) in enumerate (zip ( image , self . scheduler . scales ) ):
4631
- image [i ] = self .image_processor .postprocess (img * scale , output_type = output_type )[0 ]
4634
+ for i , img in enumerate (image ):
4635
+ image [i ] = self .image_processor .postprocess (img , output_type = output_type )[0 ]
4632
4636
else :
4633
4637
image = self .image_processor .postprocess (image , output_type = output_type )
4634
4638
0 commit comments