|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import unittest |
| 16 | +from pathlib import Path |
16 | 17 |
|
17 | 18 | import numpy as np |
18 | 19 | import pytest |
|
35 | 36 | OVPipelineForInpainting, |
36 | 37 | OVPipelineForText2Image, |
37 | 38 | ) |
| 39 | +from optimum.intel.openvino.utils import TemporaryDirectory |
38 | 40 | from optimum.intel.utils.import_utils import is_transformers_version |
39 | 41 | from optimum.utils.testing_utils import require_diffusers |
40 | 42 |
|
@@ -309,6 +311,31 @@ def test_safety_checker(self, model_arch: str): |
309 | 311 |
|
310 | 312 | np.testing.assert_allclose(ov_images, diffusers_images, atol=1e-4, rtol=1e-2) |
311 | 313 |
|
| 314 | + @require_diffusers |
| 315 | + def test_load_and_save_pipeline_with_safety_checker(self): |
| 316 | + model_id = "katuni4ka/tiny-random-stable-diffusion-with-safety-checker" |
| 317 | + ov_pipeline = self.OVMODEL_CLASS.from_pretrained(model_id) |
| 318 | + self.assertTrue(ov_pipeline.safety_checker is not None) |
| 319 | + self.assertIsInstance(ov_pipeline.safety_checker, StableDiffusionSafetyChecker) |
| 320 | + with TemporaryDirectory() as tmpdirname: |
| 321 | + ov_pipeline.save_pretrained(tmpdirname) |
| 322 | + for subdir in [ |
| 323 | + "text_encoder", |
| 324 | + "tokenizer", |
| 325 | + "unet", |
| 326 | + "vae_encoder", |
| 327 | + "vae_decoder", |
| 328 | + "scheduler", |
| 329 | + "feature_extractor", |
| 330 | + ]: |
| 331 | + subdir_path = Path(tmpdirname) / subdir |
| 332 | + self.assertTrue(subdir_path.is_dir()) |
| 333 | + loaded_pipeline = self.OVMODEL_CLASS.from_pretrained(tmpdirname) |
| 334 | + self.assertTrue(loaded_pipeline.safety_checker is not None) |
| 335 | + self.assertIsInstance(loaded_pipeline.safety_checker, StableDiffusionSafetyChecker) |
| 336 | + del loaded_pipeline |
| 337 | + del ov_pipeline |
| 338 | + |
312 | 339 | @parameterized.expand(SUPPORTED_ARCHITECTURES) |
313 | 340 | def test_height_width_properties(self, model_arch: str): |
314 | 341 | batch_size, height, width, num_images_per_prompt = 2, 128, 64, 4 |
|
0 commit comments