@@ -225,19 +225,28 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
+        latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
         self.image_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
+            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
         )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor,
-            vae_latent_channels=self.vae.config.latent_channels,
+            vae_latent_channels=latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
         )
-        self.tokenizer_max_length = self.tokenizer.model_max_length
-        self.default_sample_size = self.transformer.config.sample_size
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+        )
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128
+        )
         self.patch_size = (
             self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
         )
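
The pattern this diff applies is guarded fallback: each config-derived attribute reads from the corresponding component when it was registered, and otherwise falls back to a fixed default (8, 16, 77, 128, 2). Below is a minimal, self-contained sketch of that pattern, not the actual diffusers pipeline: the class name and the direct attribute assignment are hypothetical stand-ins for `register_modules()`, and since the sketch always sets the attributes, a plain `is not None` check replaces the diff's `hasattr(self, ...) and ... is not None` guard.

from types import SimpleNamespace

class PipelineDefaultsSketch:
    # Hypothetical stand-in for a pipeline __init__; components are optional.
    def __init__(self, vae=None, tokenizer=None, transformer=None):
        self.vae = vae
        self.tokenizer = tokenizer
        self.transformer = transformer

        # 2 ** (n_blocks - 1) when a VAE is present; the diff falls back to 8,
        # the usual SD-style spatial downsampling factor.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1)
            if self.vae is not None
            else 8
        )
        self.latent_channels = (
            self.vae.config.latent_channels if self.vae is not None else 16
        )
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if self.tokenizer is not None else 77
        )
        self.default_sample_size = (
            self.transformer.config.sample_size if self.transformer is not None else 128
        )

# With no components registered, every attribute resolves to its fallback.
p = PipelineDefaultsSketch()
assert (p.vae_scale_factor, p.latent_channels) == (8, 16)
assert (p.tokenizer_max_length, p.default_sample_size) == (77, 128)

# With a VAE whose config declares four block_out_channels, the scale factor
# is 2 ** 3 = 8 and latent_channels is read from the config instead.
vae = SimpleNamespace(
    config=SimpleNamespace(block_out_channels=[128, 256, 512, 512], latent_channels=16)
)
p = PipelineDefaultsSketch(vae=vae)
assert p.vae_scale_factor == 8 and p.latent_channels == 16

The point of the change is that constructing the pipeline with some components set to None (for example, to run only the transformer stage) no longer raises an AttributeError during __init__.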