Commit f702292

move the __call__ func to the end of file
1 parent dd9dfd8 commit f702292

File tree

1 file changed: +178 −178 lines changed


src/diffusers/pipelines/bria/pipeline_bria.py

Lines changed: 178 additions & 178 deletions
@@ -267,6 +267,184 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt

+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
+    ):
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+            )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        if max_sequence_length is not None and max_sequence_length > 512:
+            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 128,
+        device: Optional[torch.device] = None,
+    ):
+        tokenizer = self.tokenizer
+        text_encoder = self.text_encoder
+        device = device or text_encoder.device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+        prompt_embeds_list = []
+        for p in prompt:
+            text_inputs = tokenizer(
+                p,
+                # padding="max_length",
+                max_length=max_sequence_length,
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+                logger.warning(
+                    "The following part of your input was truncated because `max_sequence_length` is set to "
+                    f" {max_sequence_length} tokens: {removed_text}"
+                )
+
+            prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+            # Concat zeros to max_sequence
+            b, seq_len, dim = prompt_embeds.shape
+            if seq_len < max_sequence_length:
+                padding = torch.zeros(
+                    (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
+                )
+                prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
+            prompt_embeds_list.append(prompt_embeds)
+
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=0)
+        prompt_embeds = prompt_embeds.to(device=device)
+
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1)
+        prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
+        return prompt_embeds
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // self.vae_scale_factor)
+        width = 2 * (int(width) // self.vae_scale_factor)
+
+        shape = (batch_size, num_channels_latents, height, width)
+
+        if latents is not None:
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+            return latents.to(device=device, dtype=dtype), latent_image_ids
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        return latents, latent_image_ids
+
+    @staticmethod
+    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+        latents = latents.permute(0, 2, 4, 1, 3, 5)
+        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+        return latents
+
+    @staticmethod
+    def _unpack_latents(latents, height, width, vae_scale_factor):
+        batch_size, num_patches, channels = latents.shape
+
+        height = height // vae_scale_factor
+        width = width // vae_scale_factor
+
+        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+        return latents
+
+    @staticmethod
+    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+        latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
+        latent_image_ids = latent_image_ids.reshape(
+            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+        )
+
+        return latent_image_ids.to(device=device, dtype=dtype)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -549,181 +727,3 @@ def __call__(
             return (image,)

         return BriaPipelineOutput(images=image)
-
-    def check_inputs(
-        self,
-        prompt,
-        height,
-        width,
-        negative_prompt=None,
-        prompt_embeds=None,
-        negative_prompt_embeds=None,
-        callback_on_step_end_tensor_inputs=None,
-        max_sequence_length=None,
-    ):
-        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
-            logger.warning(
-                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
-            )
-        if callback_on_step_end_tensor_inputs is not None and not all(
-            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
-        ):
-            raise ValueError(
-                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
-            )
-
-        if prompt is not None and prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                " only forward one of the two."
-            )
-        elif prompt is None and prompt_embeds is None:
-            raise ValueError(
-                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
-            )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if negative_prompt is not None and negative_prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
-                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-            )
-
-        if prompt_embeds is not None and negative_prompt_embeds is not None:
-            if prompt_embeds.shape != negative_prompt_embeds.shape:
-                raise ValueError(
-                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
-                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
-                    f" {negative_prompt_embeds.shape}."
-                )
-
-        if max_sequence_length is not None and max_sequence_length > 512:
-            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
-
-    def _get_t5_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        num_images_per_prompt: int = 1,
-        max_sequence_length: int = 128,
-        device: Optional[torch.device] = None,
-    ):
-        tokenizer = self.tokenizer
-        text_encoder = self.text_encoder
-        device = device or text_encoder.device
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
-        prompt_embeds_list = []
-        for p in prompt:
-            text_inputs = tokenizer(
-                p,
-                # padding="max_length",
-                max_length=max_sequence_length,
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
-                removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
-                logger.warning(
-                    "The following part of your input was truncated because `max_sequence_length` is set to "
-                    f" {max_sequence_length} tokens: {removed_text}"
-                )
-
-            prompt_embeds = text_encoder(text_input_ids.to(device))[0]
-
-            # Concat zeros to max_sequence
-            b, seq_len, dim = prompt_embeds.shape
-            if seq_len < max_sequence_length:
-                padding = torch.zeros(
-                    (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
-                )
-                prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
-            prompt_embeds_list.append(prompt_embeds)
-
-        prompt_embeds = torch.concat(prompt_embeds_list, dim=0)
-        prompt_embeds = prompt_embeds.to(device=device)
-
-        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1)
-        prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
-        return prompt_embeds
-
-    def prepare_latents(
-        self,
-        batch_size,
-        num_channels_latents,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        latents=None,
-    ):
-        # VAE applies 8x compression on images but we must also account for packing which requires
-        # latent height and width to be divisible by 2.
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
-
-        shape = (batch_size, num_channels_latents, height, width)
-
-        if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
-
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, latent_image_ids
-
-    @staticmethod
-    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-        latents = latents.permute(0, 2, 4, 1, 3, 5)
-        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
-        return latents
-
-    @staticmethod
-    def _unpack_latents(latents, height, width, vae_scale_factor):
-        batch_size, num_patches, channels = latents.shape
-
-        height = height // vae_scale_factor
-        width = width // vae_scale_factor
-
-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
-        latents = latents.permute(0, 3, 1, 4, 2, 5)
-
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
-
-        return latents
-
-    @staticmethod
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
-        latent_image_ids = latent_image_ids.reshape(
-            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
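
Note: the following is a small, hypothetical sketch, not part of this commit, illustrating that the _pack_latents / _unpack_latents pair moved above are inverses of each other: packing folds each 2x2 patch of the latent grid into the channel dimension, and unpacking restores the (batch, channels, height, width) layout. The sizes used here (vae_scale_factor=8, 16 latent channels, a 256x256 image) are assumed for the example, not values read from the pipeline config.

import torch


def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # (B, C, H, W) -> (B, (H // 2) * (W // 2), C * 4), same steps as the static method above
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)


def unpack_latents(latents, height, width, vae_scale_factor):
    # (B, num_patches, C * 4) -> (B, C, H, W), same steps as the static method above;
    # height/width here are the pixel dimensions of the target image
    batch_size, num_patches, channels = latents.shape
    height = height // vae_scale_factor
    width = width // vae_scale_factor
    latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)


vae_scale_factor = 8        # assumed value for the example
image_height = image_width = 256
num_channels_latents = 16   # assumed value for the example

# same computation as prepare_latents: 8x VAE compression, kept divisible by 2 for packing
latent_height = 2 * (image_height // vae_scale_factor)  # 64
latent_width = 2 * (image_width // vae_scale_factor)    # 64

latents = torch.randn(1, num_channels_latents, latent_height, latent_width)
packed = pack_latents(latents, 1, num_channels_latents, latent_height, latent_width)
print(packed.shape)  # torch.Size([1, 1024, 64])

restored = unpack_latents(packed, image_height, image_width, vae_scale_factor)
assert torch.equal(restored, latents)  # the round trip is lossless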

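A second hypothetical check, also not part of this commit: the position ids built by _prepare_latent_image_ids index the (height // 2) x (width // 2) grid of 2x2 patches, so their sequence length matches the packed latent sequence returned by prepare_latents (1024 patches in the example above).

import torch


def prepare_latent_image_ids(batch_size, height, width, device, dtype):
    # same steps as the static method above: channel 1 carries the row index,
    # channel 2 the column index, channel 0 stays zero
    latent_image_ids = torch.zeros(height, width, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
    h, w, c = latent_image_ids.shape
    latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
    return latent_image_ids.reshape(batch_size, h * w, c).to(device=device, dtype=dtype)


latent_height = latent_width = 64  # assumed latent size, matching the sketch above
ids = prepare_latent_image_ids(2, latent_height // 2, latent_width // 2, "cpu", torch.float32)
print(ids.shape)  # torch.Size([2, 1024, 3]) -- one (0, row, col) id per packed patch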