@@ -225,19 +225,28 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
+        latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
         self.image_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
+            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
         )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor,
-            vae_latent_channels=self.vae.config.latent_channels,
+            vae_latent_channels=latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
         )
-        self.tokenizer_max_length = self.tokenizer.model_max_length
-        self.default_sample_size = self.transformer.config.sample_size
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+        )
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128
+        )
         self.patch_size = (
             self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
         )
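
For context, the hunk replaces unconditional attribute lookups with guarded defaults, so the pipeline can still be constructed when optional components are passed as None. A minimal standalone sketch of that pattern, with hypothetical names (OptionalComponentExample and its fallback constants are illustrative, not part of the diff):

class OptionalComponentExample:
    """Illustrative only: mirrors the guarded-default pattern from the hunk."""

    def __init__(self, vae=None, tokenizer=None):
        self.vae = vae
        self.tokenizer = tokenizer
        # Derive the scale factor from the VAE when it is loaded; fall back to
        # 8, the value for a typical 4-block VAE, when it is omitted.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1)
            if self.vae is not None
            else 8
        )
        # Same pattern for the tokenizer: 77 is the common CLIP max length.
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if self.tokenizer is not None else 77
        )

# With no components supplied, construction still succeeds on the defaults:
pipe = OptionalComponentExample()
assert pipe.vae_scale_factor == 8 and pipe.tokenizer_max_length == 77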