@@ -65,7 +65,7 @@ def __init__(self, network_info, launcher, models_args, delayed_model_loading=False):
     def create_pipeline(self, launcher, netowrk_info=None):
         tokenizer_config = self.config.get("tokenizer_id", "openai/clip-vit-large-patch14")
         tokenizer = AutoTokenizer.from_pretrained(tokenizer_config)
-        scheduler_config = self.config.get("sheduler_config", {})
+        scheduler_config = self.config.get("scheduler_config", {})
         scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
         netowrk_info = netowrk_info or self.network_info
         self.pipe = OVStableDiffusionPipeline(
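With the key typo fixed, a scheduler_config section supplied in the evaluator config actually reaches the scheduler instead of silently falling back to {}. A minimal sketch of what that lookup feeds into; the parameter values below are illustrative, not taken from this repo:

from diffusers import LMSDiscreteScheduler

# Illustrative SD v1.x-style settings; with an empty dict, from_config
# falls back to the scheduler's registered defaults.
scheduler = LMSDiscreteScheduler.from_config({
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "num_train_timesteps": 1000,
})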
@@ -164,7 +164,7 @@ def __init__(
         self,
         launcher: "BaseLauncher",  # noqa: F821
         tokenizer: "CLIPTokenizer",  # noqa: F821
-        scheduler: Union["DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler"],  # noqa: F821
+        scheduler: Union["LMSDiscreteScheduler"],  # noqa: F821
         model_info: Dict,
         seed=None,
         num_inference_steps=50
@@ -255,14 +255,24 @@ def __call__(
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

+        # lcm_dreamshaper-v7 has an extra unet input
+        is_extra_input = len(self.unet.inputs) == 4 and self.unet.inputs[3].any_name == 'timestep_cond'
+        if is_extra_input:
+            batch_size = len(prompt) if isinstance(prompt, list) else 1
+            w = torch.tensor(guidance_scale).repeat(batch_size)
+            w_embedding = self.get_w_embedding(w, embedding_dim=256)
+
         for t in self.progress_bar(timesteps):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

+            inputs = [latent_model_input, np.array(t, dtype=np.float32), text_embeddings]
+            if is_extra_input:
+                inputs.append(w_embedding)
+
             # predict the noise residual
-            noise_pred = self.unet(
-                [latent_model_input, np.array(t, dtype=np.float32), text_embeddings])[self._unet_output]
+            noise_pred = self.unet(inputs)[self._unet_output]

             # perform guidance
             if do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
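For context: a stock SD v1.x UNet exported to OpenVINO IR takes three inputs (latents, timestep, text embeddings), while the LCM-distilled lcm_dreamshaper-v7 UNet adds a fourth named timestep_cond that carries a guidance-scale embedding. A sketch of the same probe against a raw OpenVINO model; the IR path is illustrative:

import openvino.runtime as ov

core = ov.Core()
unet = core.read_model("unet.xml")  # illustrative path to a UNet IR
# same check as in the loop above: a 4th input named 'timestep_cond' marks an LCM UNet
is_lcm_unet = len(unet.inputs) == 4 and unet.inputs[3].any_name == "timestep_cond"
print(is_lcm_unet)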
@@ -455,3 +465,27 @@ def print_input_output_info(self):
         model = getattr(self, part_model_id, None)
         if model is not None:
             self.launcher.print_input_output_info(model, part)
+
+    @staticmethod
+    def get_w_embedding(w, embedding_dim=512, dtype=torch.float32):
+        """
+        see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+        Args:
+            w: torch.Tensor: guidance scale values to generate embeddings for
+            embedding_dim: int: dimension of the embeddings to generate
+            dtype: data type of the generated embeddings
+        Returns:
+            embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
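If the body above is lifted into a standalone function (assumed name get_w_embedding, same body), it can be sanity-checked in isolation: the first half of each row holds sines and the second half cosines of the same frequencies, so the two halves are pointwise Pythagorean:

import torch

w = torch.tensor([7.5, 8.0])                 # one guidance scale per prompt
emb = get_w_embedding(w, embedding_dim=256)  # assumed standalone copy of the method
assert emb.shape == (2, 256)
# sin(x)**2 + cos(x)**2 == 1 for matching frequency columns
assert torch.allclose(emb[:, :128] ** 2 + emb[:, 128:] ** 2, torch.ones(2, 128))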