Skip to content

Commit ab8f6c9

Browse files
update diffusion test with static shape
1 parent 1e37d3d commit ab8f6c9

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

optimum/intel/openvino/modeling_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,8 +1791,8 @@ def _get_ov_class(pipeline_class_name: str, throw_error_if_not_exist: bool = Tru
17911791
OV_TEXT2IMAGE_PIPELINES_MAPPING["sana-sprint"] = OVSanaSprintPipeline
17921792

17931793

1794-
if is_diffusers_version(">", "0.34.0"):
1795-
SUPPORTED_OV_PIPELINES.extend([OVFluxKontextPipeline])
1794+
if is_diffusers_version(">=", "0.34.0"):
1795+
SUPPORTED_OV_PIPELINES.append(OVFluxKontextPipeline)
17961796
OV_IMAGE2IMAGE_PIPELINES_MAPPING["flux-kontext"] = OVFluxKontextPipeline
17971797

17981798
SUPPORTED_OV_PIPELINES_MAPPINGS = [

tests/openvino/test_diffusion.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
AutoPipelineForInpainting,
2525
AutoPipelineForText2Image,
2626
DiffusionPipeline,
27-
FluxKontextPipeline,
2827
)
2928
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
3029
from diffusers.utils import load_image
@@ -485,7 +484,8 @@ class OVPipelineForImage2ImageTest(unittest.TestCase):
485484
if is_transformers_version(">=", "4.40.0"):
486485
SUPPORTED_ARCHITECTURES.append("stable-diffusion-3")
487486
SUPPORTED_ARCHITECTURES.append("flux")
488-
SUPPORTED_ARCHITECTURES.append("flux-kontext")
487+
if is_diffusers_version(">=", "0.35.0"):
488+
SUPPORTED_ARCHITECTURES.append("flux-kontext")
489489

490490
AUTOMODEL_CLASS = AutoPipelineForImage2Image
491491
OVMODEL_CLASS = OVPipelineForImage2Image
@@ -502,8 +502,9 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
502502
if model_type in ["flux", "stable-diffusion-3", "flux-kontext"]:
503503
inputs["height"] = height
504504
inputs["width"] = width
505-
506-
inputs["strength"] = 0.75
505+
506+
if model_type != "flux-kontext":
507+
inputs["strength"] = 0.75
507508

508509
return inputs
509510

@@ -535,7 +536,15 @@ def test_num_images_per_prompt(self, model_arch: str):
535536
height=height, width=width, batch_size=batch_size, model_type=model_arch
536537
)
537538
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
538-
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
539+
if model_arch != "flux-kontext":
540+
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
541+
else:
542+
if (height == width):
543+
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, 1024, 1024, 3))
544+
elif (height > width):
545+
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, 1448, 724, 3))
546+
else:
547+
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, 724, 1448, 3))
539548

540549
@parameterized.expand(["stable-diffusion", "stable-diffusion-xl", "latent-consistency"])
541550
@require_diffusers
@@ -568,8 +577,10 @@ def __call__(self, *args, **kwargs) -> None:
568577
@require_diffusers
569578
def test_shape(self, model_arch: str):
570579
pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
571-
572-
height, width, batch_size = 128, 64, 1
580+
if model_arch != "flux-kontext":
581+
height, width, batch_size = 128, 64, 1
582+
else:
583+
height, width, batch_size = 1448, 724, 1
573584

574585
for input_type in ["pil", "np", "pt"]:
575586
inputs = self.generate_inputs(
@@ -586,7 +597,7 @@ def test_shape(self, model_arch: str):
586597
elif output_type == "pt":
587598
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
588599
else:
589-
if model_arch != "flux" and model_arch != "flux-kontext":
600+
if not model_arch.startswith("flux"):
590601
out_channels = (
591602
pipeline.unet.config.out_channels
592603
if pipeline.unet is not None
@@ -611,9 +622,9 @@ def test_shape(self, model_arch: str):
611622
@require_diffusers
612623
def test_compare_to_diffusers_pipeline(self, model_arch: str):
613624
height, width, batch_size = 128, 128, 1
614-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
615-
616-
auto_cls = self.AUTOMODEL_CLASS if "flux-kontext" not in model_arch else FluxKontextPipeline
625+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
626+
auto_cls = self.AUTOMODEL_CLASS
627+
617628
diffusers_pipeline = auto_cls.from_pretrained(MODEL_NAMES[model_arch])
618629
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
619630

0 commit comments

Comments
 (0)