File tree Expand file tree Collapse file tree 2 files changed +13
-2
lines changed
src/diffusers/models/transformers Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -88,13 +88,18 @@ def main(args):
8888 # y norm
8989 converted_state_dict ["caption_norm.weight" ] = state_dict .pop ("attention_y_norm.weight" )
9090
91+ # scheduler
9192 flow_shift = 3.0
93+
94+ # model config
9295 if args .model_type == "SanaMS_1600M_P1_D20" :
9396 layer_num = 20
9497 elif args .model_type == "SanaMS_600M_P1_D28" :
9598 layer_num = 28
9699 else :
97100 raise ValueError (f"{ args .model_type } is not supported." )
101+ # Positional embedding interpolation scale.
102+ interpolation_scale = {512 : None , 1024 : None , 2048 : 1.0 }
98103
99104 for depth in range (layer_num ):
100105 # Transformer blocks.
@@ -176,6 +181,7 @@ def main(args):
176181 patch_size = 1 ,
177182 norm_elementwise_affine = False ,
178183 norm_eps = 1e-6 ,
184+ interpolation_scale = interpolation_scale [args .image_size ],
179185 )
180186
181187 if is_accelerate_available ():
Original file line number Diff line number Diff line change @@ -242,21 +242,26 @@ def __init__(
242242 patch_size : int = 1 ,
243243 norm_elementwise_affine : bool = False ,
244244 norm_eps : float = 1e-6 ,
245+ interpolation_scale : Optional [int ] = None ,
245246 ) -> None :
246247 super ().__init__ ()
247248
248249 out_channels = out_channels or in_channels
249250 inner_dim = num_attention_heads * attention_head_dim
250251
251252 # 1. Patch Embedding
253+ interpolation_scale = (
254+ interpolation_scale
255+ if interpolation_scale is not None
256+ else max (sample_size // 64 , 1 )
257+ )
252258 self .patch_embed = PatchEmbed (
253259 height = sample_size ,
254260 width = sample_size ,
255261 patch_size = patch_size ,
256262 in_channels = in_channels ,
257263 embed_dim = inner_dim ,
258- interpolation_scale = None ,
259- pos_embed_type = None ,
264+ interpolation_scale = interpolation_scale ,
260265 )
261266
262267 # 2. Additional condition embeddings
You can’t perform that action at this time.
0 commit comments