@@ -78,7 +78,7 @@ class OVPipelineForText2ImageTest(unittest.TestCase):
7878 NEGATIVE_PROMPT_SUPPORT_ARCHITECTURES = ["stable-diffusion" , "stable-diffusion-xl" , "latent-consistency" ]
7979 if is_transformers_version (">=" , "4.40.0" ):
8080 SUPPORTED_ARCHITECTURES .extend (["stable-diffusion-3" , "flux" , "sana" ])
81- NEGATIVE_PROMPT_SUPPORT_ARCHITECTURES .append (["stable-diffusion-3" ])
81+ NEGATIVE_PROMPT_SUPPORT_ARCHITECTURES .extend (["stable-diffusion-3" ])
8282 CALLBACK_SUPPORT_ARCHITECTURES = ["stable-diffusion" , "stable-diffusion-xl" , "latent-consistency" ]
8383
8484 OVMODEL_CLASS = OVPipelineForText2Image
@@ -94,13 +94,6 @@ def generate_inputs(self, height=128, width=128, batch_size=1):
9494
9595 return inputs
9696
97- def get_auto_cls (self , model_arch ):
98- if model_arch == "sana" :
99- from diffusers import SanaPipeline
100-
101- return SanaPipeline
102- return self .AUTOMODEL_CLASS
103-
10497 @require_diffusers
10598 def test_load_vanilla_model_which_is_not_supported (self ):
10699 with self .assertRaises (Exception ) as context :
@@ -111,9 +104,7 @@ def test_load_vanilla_model_which_is_not_supported(self):
111104 @parameterized .expand (SUPPORTED_ARCHITECTURES )
112105 @require_diffusers
113106 def test_ov_pipeline_class_dispatch (self , model_arch : str ):
114- auto_cls = self .get_auto_cls (model_arch )
115- auto_pipeline = DiffusionPipeline if model_arch != "sana" else auto_cls
116- auto_pipeline = auto_cls .from_pretrained (MODEL_NAMES [model_arch ])
107+ auto_pipeline = DiffusionPipeline .from_pretrained (MODEL_NAMES [model_arch ])
117108 ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
118109
119110 self .assertEqual (ov_pipeline .auto_model_class , auto_pipeline .__class__ )
@@ -141,21 +132,19 @@ def test_num_images_per_prompt(self, model_arch: str):
141132 def test_compare_to_diffusers_pipeline (self , model_arch : str ):
142133 height , width , batch_size = 64 , 64 , 1
143134 inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
144- auto_cls = self .get_auto_cls (model_arch )
145135 ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
146- diffusers_pipeline = auto_cls .from_pretrained (MODEL_NAMES [model_arch ])
136+ diffusers_pipeline = DiffusionPipeline .from_pretrained (MODEL_NAMES [model_arch ])
147137
148- with torch .no_grad ():
149- for output_type in ["latent" , "np" , "pt" ]:
150- inputs ["output_type" ] = output_type
151- if model_arch == "sana" :
152- # resolution binning will lead to resize output to standard resolution and back that can interpolate floating-point deviations
153- inputs ["use_resolution_binning" ] = False
154- atol = 1e-4
138+ for output_type in ["latent" , "np" , "pt" ]:
139+ inputs ["output_type" ] = output_type
140+ if model_arch == "sana" :
141+ # resolution binning will lead to resize output to standard resolution and back that can interpolate floating-point deviations
142+ inputs ["use_resolution_binning" ] = False
143+ atol = 1e-4
155144
156- ov_output = ov_pipeline (** inputs , generator = get_generator ("pt" , SEED )).images
157- diffusers_output = diffusers_pipeline (** inputs , generator = get_generator ("pt" , SEED )).images
158- np .testing .assert_allclose (ov_output , diffusers_output , atol = atol , rtol = 1e-2 )
145+ ov_output = ov_pipeline (** inputs , generator = get_generator ("pt" , SEED )).images
146+ diffusers_output = diffusers_pipeline (** inputs , generator = get_generator ("pt" , SEED )).images
147+ np .testing .assert_allclose (ov_output , diffusers_output , atol = atol , rtol = 1e-2 )
159148
160149 # test on inputs nondivisible on 64
161150 height , width , batch_size = 96 , 96 , 1
@@ -191,8 +180,7 @@ def __call__(self, *args, **kwargs) -> None:
191180 auto_callback = Callback ()
192181
193182 ov_pipe = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
194- auto_cls = self .get_auto_cls (model_arch )
195- auto_pipe = auto_cls .from_pretrained (MODEL_NAMES [model_arch ])
183+ auto_pipe = DiffusionPipeline .from_pretrained (MODEL_NAMES [model_arch ])
196184
197185 # callback_steps=1 to trigger callback every step
198186 ov_pipe (** inputs , callback = ov_callback , callback_steps = 1 )
0 commit comments