@@ -62,10 +62,12 @@ class VAEImageProcessorImpl : public torch::nn::Module {
6262 bool do_normalize = true ,
6363 bool do_binarize = false ,
6464 bool do_convert_rgb = false ,
65- bool do_convert_grayscale = false ) {
65+ bool do_convert_grayscale = false ,
66+ int64_t latent_channels = 4 ) {
6667 const auto & model_args = context.get_model_args ();
68+ dtype_ = context.get_tensor_options ().dtype ().toScalarType ();
6769 scale_factor_ = 1 << model_args.block_out_channels ().size ();
68- latent_channels_ = 4 ;
70+ latent_channels_ = latent_channels ;
6971 do_resize_ = do_resize;
7072 do_normalize_ = do_normalize;
7173 do_binarize_ = do_binarize;
@@ -86,8 +88,29 @@ class VAEImageProcessorImpl : public torch::nn::Module {
8688 std::optional<int64_t > width = std::nullopt ,
8789 const std::string& resize_mode = " default" ,
8890 std::optional<std::tuple<int64_t , int64_t , int64_t , int64_t >>
89- crop_coords = std::nullopt ) {
91+ crop_coords = std::nullopt ,
92+ const bool is_pil_image = false ) {
9093 torch::Tensor processed = image.clone ();
94+ if (is_pil_image == true ) {
95+ auto dims = processed.dim ();
96+ if (dims < 2 || dims > 4 ) {
97+ LOG (FATAL) << " Unsupported PIL image dimension: " << dims;
98+ }
99+ if (dims == 4 ) {
100+ if (processed.size (1 ) == 3 || processed.size (1 ) == 1 ) {
101+ processed = processed.permute ({0 , 2 , 3 , 1 });
102+ }
103+ processed = processed.squeeze (0 );
104+ dims = processed.dim ();
105+ }
106+ processed = processed.to (torch::kFloat );
107+ processed = processed / 255 .0f ;
108+ if (dims == 2 ) {
109+ processed = processed.unsqueeze (0 ).unsqueeze (0 );
110+ } else {
111+ processed = processed.permute ({2 , 0 , 1 }).unsqueeze (0 );
112+ }
113+ }
91114 if (processed.dtype () != torch::kFloat32 ) {
92115 processed = processed.to (torch::kFloat32 );
93116 }
@@ -116,7 +139,6 @@ class VAEImageProcessorImpl : public torch::nn::Module {
116139 if (channel == latent_channels_) {
117140 return image;
118141 }
119-
120142 auto [target_h, target_w] =
121143 get_default_height_width (processed, height, width);
122144 if (do_resize_) {
@@ -129,7 +151,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
129151 if (do_binarize_) {
130152 processed = (processed >= 0 .5f ).to (torch::kFloat32 );
131153 }
132- processed = processed.to (image. dtype () );
154+ processed = processed.to (dtype_ );
133155 return processed;
134156 }
135157
@@ -202,6 +224,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
202224 bool do_binarize_ = false ;
203225 bool do_convert_rgb_ = false ;
204226 bool do_convert_grayscale_ = false ;
227+ torch::ScalarType dtype_ = torch::kFloat32 ;
205228};
206229TORCH_MODULE (VAEImageProcessor);
207230
0 commit comments