@@ -267,6 +267,184 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
+    ):
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+            )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        if max_sequence_length is not None and max_sequence_length > 512:
+            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 128,
+        device: Optional[torch.device] = None,
+    ):
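+        # Each prompt is tokenized and encoded by the text encoder separately, then
+        # zero-padded to `max_sequence_length` so the per-prompt embeddings can be
+        # concatenated into a single batch.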
+        tokenizer = self.tokenizer
+        text_encoder = self.text_encoder
+        device = device or text_encoder.device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+        prompt_embeds_list = []
+        for p in prompt:
+            text_inputs = tokenizer(
+                p,
+                # padding="max_length",
+                max_length=max_sequence_length,
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = tokenizer(p, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+                logger.warning(
+                    "The following part of your input was truncated because `max_sequence_length` is set to"
+                    f" {max_sequence_length} tokens: {removed_text}"
+                )
+
+            prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+            # Pad with zeros up to max_sequence_length
+            b, seq_len, dim = prompt_embeds.shape
+            if seq_len < max_sequence_length:
+                padding = torch.zeros(
+                    (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
+                )
+                prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
+            prompt_embeds_list.append(prompt_embeds)
+
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=0)
+        prompt_embeds = prompt_embeds.to(device=device)
+
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1)
+        prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
+        return prompt_embeds
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // self.vae_scale_factor)
+        width = 2 * (int(width) // self.vae_scale_factor)
+
+        shape = (batch_size, num_channels_latents, height, width)
+
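+        # Latents supplied by the caller are assumed to already be packed, so only the
+        # per-patch image ids are rebuilt before returning.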
+        if latents is not None:
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+            return latents.to(device=device, dtype=dtype), latent_image_ids
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        return latents, latent_image_ids
+
+    @staticmethod
+    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
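+        # Rearrange (B, C, H, W) latents into a sequence of 2x2 spatial patches of shape
+        # (B, (H // 2) * (W // 2), C * 4).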
+        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+        latents = latents.permute(0, 2, 4, 1, 3, 5)
+        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+        return latents
+
+    @staticmethod
+    def _unpack_latents(latents, height, width, vae_scale_factor):
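+        # Inverse of `_pack_latents`: turn the (B, num_patches, C * 4) patch sequence back
+        # into (B, C, H, W) latents at the resolution implied by `height`, `width`, and
+        # `vae_scale_factor`.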
+        batch_size, num_patches, channels = latents.shape
+
+        height = height // vae_scale_factor
+        width = width // vae_scale_factor
+
+        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+        return latents
+
+    @staticmethod
+    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
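+        # Build a (row, column) index for every latent patch; channel 0 is left at zero.
+        # The ids are flattened to shape (batch_size, height * width, 3).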
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+        latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
+        latent_image_ids = latent_image_ids.reshape(
+            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+        )
+
+        return latent_image_ids.to(device=device, dtype=dtype)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -549,181 +727,3 @@ def __call__(
             return (image,)
 
         return BriaPipelineOutput(images=image)
-
-    def check_inputs(
-        self,
-        prompt,
-        height,
-        width,
-        negative_prompt=None,
-        prompt_embeds=None,
-        negative_prompt_embeds=None,
-        callback_on_step_end_tensor_inputs=None,
-        max_sequence_length=None,
-    ):
-        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
-            logger.warning(
-                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
-            )
-        if callback_on_step_end_tensor_inputs is not None and not all(
-            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
-        ):
-            raise ValueError(
-                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
-            )
-
-        if prompt is not None and prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                " only forward one of the two."
-            )
-        elif prompt is None and prompt_embeds is None:
-            raise ValueError(
-                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
-            )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if negative_prompt is not None and negative_prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
-                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-            )
-
-        if prompt_embeds is not None and negative_prompt_embeds is not None:
-            if prompt_embeds.shape != negative_prompt_embeds.shape:
-                raise ValueError(
-                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
-                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
-                    f" {negative_prompt_embeds.shape}."
-                )
-
-        if max_sequence_length is not None and max_sequence_length > 512:
-            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
-
-    def _get_t5_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        num_images_per_prompt: int = 1,
-        max_sequence_length: int = 128,
-        device: Optional[torch.device] = None,
-    ):
-        tokenizer = self.tokenizer
-        text_encoder = self.text_encoder
-        device = device or text_encoder.device
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
-        prompt_embeds_list = []
-        for p in prompt:
-            text_inputs = tokenizer(
-                p,
-                # padding="max_length",
-                max_length=max_sequence_length,
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
-                removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
-                logger.warning(
-                    "The following part of your input was truncated because `max_sequence_length` is set to "
-                    f" {max_sequence_length} tokens: {removed_text}"
-                )
-
-            prompt_embeds = text_encoder(text_input_ids.to(device))[0]
-
-            # Concat zeros to max_sequence
-            b, seq_len, dim = prompt_embeds.shape
-            if seq_len < max_sequence_length:
-                padding = torch.zeros(
-                    (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
-                )
-                prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
-            prompt_embeds_list.append(prompt_embeds)
-
-        prompt_embeds = torch.concat(prompt_embeds_list, dim=0)
-        prompt_embeds = prompt_embeds.to(device=device)
-
-        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1)
-        prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
-        return prompt_embeds
-
-    def prepare_latents(
-        self,
-        batch_size,
-        num_channels_latents,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        latents=None,
-    ):
-        # VAE applies 8x compression on images but we must also account for packing which requires
-        # latent height and width to be divisible by 2.
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
-
-        shape = (batch_size, num_channels_latents, height, width)
-
-        if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
-
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, latent_image_ids
-
-    @staticmethod
-    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-        latents = latents.permute(0, 2, 4, 1, 3, 5)
-        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
-        return latents
-
-    @staticmethod
-    def _unpack_latents(latents, height, width, vae_scale_factor):
-        batch_size, num_patches, channels = latents.shape
-
-        height = height // vae_scale_factor
-        width = width // vae_scale_factor
-
-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
-        latents = latents.permute(0, 3, 1, 4, 2, 5)
-
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
-
-        return latents
-
-    @staticmethod
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
-        latent_image_ids = latent_image_ids.reshape(
-            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)