Commit 64fe6ef

Enable validation of lcm_dreamshaper-v7 with stable_diffusion_evaluator (#3934)

* Enable validation of lcm_dreamshaper-v7 with stable_diffusion_evaluator
* Small fixes

Parent: f8f76a2

1 file changed
tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/stable_diffusion_evaluator.py

Lines changed: 38 additions & 4 deletions
@@ -65,7 +65,7 @@ def __init__(self, network_info, launcher, models_args, delayed_model_loading=False):
     def create_pipeline(self, launcher, netowrk_info=None):
         tokenizer_config = self.config.get("tokenizer_id", "openai/clip-vit-large-patch14")
         tokenizer = AutoTokenizer.from_pretrained(tokenizer_config)
-        scheduler_config = self.config.get("sheduler_config", {})
+        scheduler_config = self.config.get("scheduler_config", {})
         scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
         netowrk_info = netowrk_info or self.network_info
         self.pipe = OVStableDiffusionPipeline(
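
The corrected key means a user-supplied scheduler configuration is actually forwarded to the scheduler instead of silently falling back to the empty default. A minimal sketch of the fixed lookup, with illustrative scheduler values that are not from the commit:

    # Sketch: the fixed "scheduler_config" key now reaches the scheduler;
    # the old misspelled "sheduler_config" lookup always returned {}.
    from diffusers import LMSDiscreteScheduler

    config = {"scheduler_config": {"beta_start": 0.00085, "beta_end": 0.012}}  # illustrative values
    scheduler = LMSDiscreteScheduler.from_config(config.get("scheduler_config", {}))
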
@@ -164,7 +164,7 @@ def __init__(
         self,
         launcher: "BaseLauncher",  # noqa: F821
         tokenizer: "CLIPTokenizer",  # noqa: F821
-        scheduler: Union["DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler"],  # noqa: F821
+        scheduler: Union["LMSDiscreteScheduler"],  # noqa: F821
         model_info: Dict,
         seed = None,
         num_inference_steps = 50
@@ -255,14 +255,24 @@ def __call__(
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

+        # lcm_dreamshaper-v7 has an extra unet input
+        is_extra_input = len(self.unet.inputs) == 4 and self.unet.inputs[3].any_name == 'timestep_cond'
+        if is_extra_input:
+            batch_size = len(prompt) if isinstance(prompt, list) else 1
+            w = torch.tensor(guidance_scale).repeat(batch_size)
+            w_embedding = self.get_w_embedding(w, embedding_dim=256)
+
         for t in self.progress_bar(timesteps):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

+            inputs = [latent_model_input, np.array(t, dtype=np.float32), text_embeddings]
+            if is_extra_input:
+                inputs.append(w_embedding)
+
             # predict the noise residual
-            noise_pred = self.unet(
-                [latent_model_input, np.array(t, dtype=np.float32), text_embeddings])[self._unet_output]
+            noise_pred = self.unet(inputs)[self._unet_output]
             # perform guidance
             if do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
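
This hunk probes the launcher-wrapped UNet for a fourth input named timestep_cond, which LCM-distilled models such as lcm_dreamshaper-v7 use to receive the guidance-scale embedding. A minimal standalone sketch of the same probe against a converted model with the plain OpenVINO API (the model path is a placeholder):

    # Sketch: detect the LCM-style "timestep_cond" input on a converted UNet.
    import openvino as ov

    core = ov.Core()
    model = core.read_model("unet.xml")  # placeholder path to a converted UNet
    has_timestep_cond = any(port.any_name == "timestep_cond" for port in model.inputs)
    print("extra guidance-embedding input:", has_timestep_cond)
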
@@ -455,3 +465,27 @@ def print_input_output_info(self):
             model = getattr(self, part_model_id, None)
             if model is not None:
                 self.launcher.print_input_output_info(model, part)
+
+    @staticmethod
+    def get_w_embedding(w, embedding_dim=512, dtype=torch.float32):
+        """
+        see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+        Args:
+            w: torch.Tensor: guidance scale values to embed, one per batch element
+            embedding_dim: int: dimension of the embeddings to generate
+            dtype: data type of the generated embeddings
+        Returns:
+            embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
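
get_w_embedding maps each guidance scale w to a sinusoidal feature vector, exactly like a timestep embedding: frequencies are geometrically spaced and sin/cos features are concatenated. A condensed, self-contained sketch (embedding_dim assumed even for brevity) showing the tensor that is fed to the timestep_cond input:

    # Condensed sketch of the guidance-scale embedding added above.
    import torch

    def w_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
        half_dim = embedding_dim // 2
        # 10000^(-j/(half_dim-1)) for j = 0..half_dim-1, as in the commit
        freqs = torch.exp(
            -torch.log(torch.tensor(10000.0))
            * torch.arange(half_dim, dtype=torch.float32) / (half_dim - 1)
        )
        args = (w * 1000.0)[:, None] * freqs[None, :]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=1)

    w = torch.tensor(7.5).repeat(2)  # guidance scale for a batch of two prompts
    emb = w_embedding(w)             # shape (2, 256), passed as "timestep_cond"
    assert emb.shape == (2, 256)
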
