Skip to content

Commit 790244d

Browse files
authored
add saving safety_checker (#990)
* add saving safety_checker during conversion * add safety_checker to save_pretrained * add test * Update modeling_diffusion.py
1 parent b3cbc95 commit 790244d

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

optimum/exporters/openvino/convert.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,9 @@ def export_from_model(
740740
tokenizer_3 = getattr(model, "tokenizer_3", None)
741741
if tokenizer_3 is not None:
742742
tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))
743+
safety_checker = getattr(model, "safety_checker", None)
744+
if safety_checker is not None:
745+
safety_checker.save_pretrained(output.joinpath("safety_checker"))
743746

744747
model.save_config(output)
745748

optimum/intel/openvino/modeling_diffusion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
295295
self.tokenizer_3.save_pretrained(save_directory / "tokenizer_3")
296296
if self.feature_extractor is not None:
297297
self.feature_extractor.save_pretrained(save_directory / "feature_extractor")
298+
if getattr(self, "safety_checker", None) is not None:
299+
self.safety_checker.save_pretrained(save_directory / "safety_checker")
298300

299301
self._save_openvino_config(save_directory)
300302

tests/openvino/test_diffusion.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import unittest
16+
from pathlib import Path
1617

1718
import numpy as np
1819
import pytest
@@ -35,6 +36,7 @@
3536
OVPipelineForInpainting,
3637
OVPipelineForText2Image,
3738
)
39+
from optimum.intel.openvino.utils import TemporaryDirectory
3840
from optimum.intel.utils.import_utils import is_transformers_version
3941
from optimum.utils.testing_utils import require_diffusers
4042

@@ -309,6 +311,31 @@ def test_safety_checker(self, model_arch: str):
309311

310312
np.testing.assert_allclose(ov_images, diffusers_images, atol=1e-4, rtol=1e-2)
311313

314+
@require_diffusers
315+
def test_load_and_save_pipeline_with_safety_checker(self):
316+
model_id = "katuni4ka/tiny-random-stable-diffusion-with-safety-checker"
317+
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(model_id)
318+
self.assertTrue(ov_pipeline.safety_checker is not None)
319+
self.assertIsInstance(ov_pipeline.safety_checker, StableDiffusionSafetyChecker)
320+
with TemporaryDirectory() as tmpdirname:
321+
ov_pipeline.save_pretrained(tmpdirname)
322+
for subdir in [
323+
"text_encoder",
324+
"tokenizer",
325+
"unet",
326+
"vae_encoder",
327+
"vae_decoder",
328+
"scheduler",
329+
"feature_extractor",
330+
]:
331+
subdir_path = Path(tmpdirname) / subdir
332+
self.assertTrue(subdir_path.is_dir())
333+
loaded_pipeline = self.OVMODEL_CLASS.from_pretrained(tmpdirname)
334+
self.assertTrue(loaded_pipeline.safety_checker is not None)
335+
self.assertIsInstance(loaded_pipeline.safety_checker, StableDiffusionSafetyChecker)
336+
del loaded_pipeline
337+
del ov_pipeline
338+
312339
@parameterized.expand(SUPPORTED_ARCHITECTURES)
313340
def test_height_width_properties(self, model_arch: str):
314341
batch_size, height, width, num_images_per_prompt = 2, 128, 64, 4

0 commit comments

Comments
 (0)