@@ -218,8 +218,8 @@ def test_shape(self, model_arch: str):
218218 ),
219219 )
220220 else :
221- packed_height = height // pipeline .vae_scale_factor
222- packed_width = width // pipeline .vae_scale_factor
221+ packed_height = height // pipeline .vae_scale_factor // 2
222+ packed_width = width // pipeline .vae_scale_factor // 2
223223 channels = pipeline .transformer .config .in_channels
224224 self .assertEqual (outputs .shape , (batch_size , packed_height * packed_width , channels ))
225225
@@ -426,7 +426,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
426426 height = height , width = width , batch_size = batch_size , channel = channel , input_type = input_type
427427 )
428428
429- if "flux" == model_type :
429+ if model_type in [ "flux" , "stable-diffusion-3" ] :
430430 inputs ["height" ] = height
431431 inputs ["width" ] = width
432432
@@ -529,8 +529,8 @@ def test_shape(self, model_arch: str):
529529 ),
530530 )
531531 else :
532- packed_height = height // pipeline .vae_scale_factor
533- packed_width = width // pipeline .vae_scale_factor
532+ packed_height = height // pipeline .vae_scale_factor // 2
533+ packed_width = width // pipeline .vae_scale_factor // 2
534534 channels = pipeline .transformer .config .in_channels
535535 self .assertEqual (outputs .shape , (batch_size , packed_height * packed_width , channels ))
536536
@@ -780,8 +780,8 @@ def test_shape(self, model_arch: str):
780780 ),
781781 )
782782 else :
783- packed_height = height // pipeline .vae_scale_factor
784- packed_width = width // pipeline .vae_scale_factor
783+ packed_height = height // pipeline .vae_scale_factor // 2
784+ packed_width = width // pipeline .vae_scale_factor // 2
785785 channels = pipeline .transformer .config .in_channels
786786 self .assertEqual (outputs .shape , (batch_size , packed_height * packed_width , channels ))
787787
0 commit comments