# Accepted formats for a pipeline image argument: a single PIL image, NumPy array or
# PyTorch tensor, or a list of any one of those. `torch.Tensor` is used rather than
# `torch.FloatTensor` so the alias does not name one specific tensor type.
PipelineImageInput = Union[
    PIL.Image.Image,
    np.ndarray,
    torch.Tensor,
    List[PIL.Image.Image],
    List[np.ndarray],
    List[torch.Tensor],
]

# Depth inputs accept exactly the same formats as image inputs.
PipelineDepthInput = PipelineImageInput
3939
4040
def is_valid_image(image):
    """
    Check whether `image` is a single image in one of the supported formats.

    Args:
        image: The object to check.

    Returns:
        `bool`: `True` when `image` is a `PIL.Image.Image`, or when it is an `np.ndarray` or `torch.Tensor` whose
        `ndim` is 2 or 3; `False` otherwise.
    """
    # PIL images are accepted as-is; no dimensionality check applies to them.
    if isinstance(image, PIL.Image.Image):
        return True
    # Arrays and tensors only count as a single image when they are 2-d or 3-d.
    if isinstance(image, (np.ndarray, torch.Tensor)):
        return image.ndim in (2, 3)
    return False
43+
44+
def is_valid_image_imagelist(images):
    """
    Check whether `images` is in one of the formats supported for an image or a list of images.

    The input is accepted when it is any one of:
        1. a 4-d `np.ndarray` or `torch.Tensor`;
        2. a single valid image, as decided by `is_valid_image`;
        3. a `list` in which every element is a valid image.

    Args:
        images: The object to check.

    Returns:
        `bool`: `True` if `images` matches one of the formats above, `False` otherwise.
    """
    is_array_like = isinstance(images, (np.ndarray, torch.Tensor))
    if (is_array_like and images.ndim == 4) or is_valid_image(images):
        return True

    if not isinstance(images, list):
        return False

    # Every element must be a valid image on its own; stop at the first one that is not.
    # An empty list has no failing element and is therefore reported as valid.
    for candidate in images:
        if not is_valid_image(candidate):
            return False
    return True
58+
59+
4160class VaeImageProcessor (ConfigMixin ):
4261 """
4362 Image processor for VAE.
@@ -110,7 +129,7 @@ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.nd
110129 return images
111130
112131 @staticmethod
113- def numpy_to_pt (images : np .ndarray ) -> torch .FloatTensor :
132+ def numpy_to_pt (images : np .ndarray ) -> torch .Tensor :
114133 """
115134 Convert a NumPy image to a PyTorch tensor.
116135 """
@@ -121,7 +140,7 @@ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
121140 return images
122141
123142 @staticmethod
124- def pt_to_numpy (images : torch .FloatTensor ) -> np .ndarray :
143+ def pt_to_numpy (images : torch .Tensor ) -> np .ndarray :
125144 """
126145 Convert a PyTorch tensor to a NumPy image.
127146 """
@@ -497,12 +516,27 @@ def preprocess(
497516 else :
498517 image = np .expand_dims (image , axis = - 1 )
499518
500- if isinstance (image , supported_formats ):
501- image = [image ]
502- elif not (isinstance (image , list ) and all (isinstance (i , supported_formats ) for i in image )):
519+ if isinstance (image , list ) and isinstance (image [0 ], np .ndarray ) and image [0 ].ndim == 4 :
520+ warnings .warn (
521+ "Passing `image` as a list of 4d np.ndarray is deprecated."
522+ "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray" ,
523+ FutureWarning ,
524+ )
525+ image = np .concatenate (image , axis = 0 )
526+ if isinstance (image , list ) and isinstance (image [0 ], torch .Tensor ) and image [0 ].ndim == 4 :
527+ warnings .warn (
528+ "Passing `image` as a list of 4d torch.Tensor is deprecated."
529+ "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor" ,
530+ FutureWarning ,
531+ )
532+ image = torch .cat (image , axis = 0 )
533+
534+ if not is_valid_image_imagelist (image ):
503535 raise ValueError (
504- f"Input is in incorrect format: { [ type ( i ) for i in image ] } . Currently, we only support { ', ' .join (supported_formats )} "
536+ f"Input is in incorrect format. Currently, we only support { ', ' .join (supported_formats )} "
505537 )
538+ if not isinstance (image , list ):
539+ image = [image ]
506540
507541 if isinstance (image [0 ], PIL .Image .Image ):
508542 if crops_coords is not None :
@@ -561,15 +595,15 @@ def preprocess(
561595
562596 def postprocess (
563597 self ,
564- image : torch .FloatTensor ,
598+ image : torch .Tensor ,
565599 output_type : str = "pil" ,
566600 do_denormalize : Optional [List [bool ]] = None ,
567- ) -> Union [PIL .Image .Image , np .ndarray , torch .FloatTensor ]:
601+ ) -> Union [PIL .Image .Image , np .ndarray , torch .Tensor ]:
568602 """
569603 Postprocess the image output from tensor to `output_type`.
570604
571605 Args:
572- image (`torch.FloatTensor `):
606+ image (`torch.Tensor `):
573607 The image input, should be a pytorch tensor with shape `B x C x H x W`.
574608 output_type (`str`, *optional*, defaults to `pil`):
575609 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -578,7 +612,7 @@ def postprocess(
578612 `VaeImageProcessor` config.
579613
580614 Returns:
581- `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor `:
615+ `PIL.Image.Image`, `np.ndarray` or `torch.Tensor `:
582616 The postprocessed image.
583617 """
584618 if not isinstance (image , torch .Tensor ):
@@ -738,15 +772,15 @@ def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
738772
739773 def postprocess (
740774 self ,
741- image : torch .FloatTensor ,
775+ image : torch .Tensor ,
742776 output_type : str = "pil" ,
743777 do_denormalize : Optional [List [bool ]] = None ,
744- ) -> Union [PIL .Image .Image , np .ndarray , torch .FloatTensor ]:
778+ ) -> Union [PIL .Image .Image , np .ndarray , torch .Tensor ]:
745779 """
746780 Postprocess the image output from tensor to `output_type`.
747781
748782 Args:
749- image (`torch.FloatTensor `):
783+ image (`torch.Tensor `):
750784 The image input, should be a pytorch tensor with shape `B x C x H x W`.
751785 output_type (`str`, *optional*, defaults to `pil`):
752786 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -755,7 +789,7 @@ def postprocess(
755789 `VaeImageProcessor` config.
756790
757791 Returns:
758- `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor `:
792+ `PIL.Image.Image`, `np.ndarray` or `torch.Tensor `:
759793 The postprocessed image.
760794 """
761795 if not isinstance (image , torch .Tensor ):
@@ -793,8 +827,8 @@ def postprocess(
793827
794828 def preprocess (
795829 self ,
796- rgb : Union [torch .FloatTensor , PIL .Image .Image , np .ndarray ],
797- depth : Union [torch .FloatTensor , PIL .Image .Image , np .ndarray ],
830+ rgb : Union [torch .Tensor , PIL .Image .Image , np .ndarray ],
831+ depth : Union [torch .Tensor , PIL .Image .Image , np .ndarray ],
798832 height : Optional [int ] = None ,
799833 width : Optional [int ] = None ,
800834 target_res : Optional [int ] = None ,
@@ -933,13 +967,13 @@ def __init__(
933967 )
934968
935969 @staticmethod
936- def downsample (mask : torch .FloatTensor , batch_size : int , num_queries : int , value_embed_dim : int ):
970+ def downsample (mask : torch .Tensor , batch_size : int , num_queries : int , value_embed_dim : int ):
937971 """
938972 Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
939973 aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
940974
941975 Args:
942- mask (`torch.FloatTensor `):
976+ mask (`torch.Tensor `):
943977 The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
944978 batch_size (`int`):
945979 The batch size.
@@ -949,7 +983,7 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value
949983 The dimensionality of the value embeddings.
950984
951985 Returns:
952- `torch.FloatTensor `:
986+ `torch.Tensor `:
953987 The downsampled mask tensor.
954988
955989 """
0 commit comments