1313# limitations under the License.
1414
1515import unittest
16-
16+ import pytest
1717import numpy as np
1818import torch
1919from PIL import Image
20- from transformers import Qwen2_5_VLConfig , Qwen2_5_VLForConditionalGeneration , Qwen2Tokenizer
20+ from transformers import Qwen2_5_VLConfig , Qwen2_5_VLForConditionalGeneration , Qwen2Tokenizer , Qwen2VLProcessor
2121
2222from diffusers import (
2323 AutoencoderKLQwenImage ,
2424 FlowMatchEulerDiscreteScheduler ,
25- QwenImagePipeline ,
25+ QwenImageEditPipeline ,
2626 QwenImageTransformer2DModel ,
2727)
2828from diffusers .utils .testing_utils import enable_full_determinism , torch_device
3434enable_full_determinism ()
3535
3636
37- class QwenImagePipelineFastTests (PipelineTesterMixin , unittest .TestCase ):
38- pipeline_class = QwenImagePipeline
37+ class QwenImageEditPipelineFastTests (PipelineTesterMixin , unittest .TestCase ):
38+ pipeline_class = QwenImageEditPipeline
3939 params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs" }
40- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
41- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
42- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
40+ batch_params = frozenset ([ "prompt" , "image" ])
41+ image_params = frozenset ([ "image" ])
42+ image_latents_params = frozenset ([ "latents" ])
4343 required_optional_params = frozenset (
4444 [
4545 "num_inference_steps" ,
@@ -56,6 +56,8 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
5656 test_group_offloading = True
5757
5858 def get_dummy_components (self ):
59+ tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
60+
5961 torch .manual_seed (0 )
6062 transformer = QwenImageTransformer2DModel (
6163 patch_size = 2 ,
@@ -77,10 +79,8 @@ def get_dummy_components(self):
7779 dim_mult = [1 , 2 , 4 ],
7880 num_res_blocks = 1 ,
7981 temperal_downsample = [False , True ],
80- # fmt: off
81- latents_mean = [0.0 ] * 4 ,
82- latents_std = [1.0 ] * 4 ,
83- # fmt: on
82+ latents_mean = [0.0 ] * z_dim ,
83+ latents_std = [1.0 ] * z_dim ,
8484 )
8585
8686 torch .manual_seed (0 )
@@ -115,14 +115,15 @@ def get_dummy_components(self):
115115 vision_token_id = 151654 ,
116116 )
117117 text_encoder = Qwen2_5_VLForConditionalGeneration (config )
118- tokenizer = Qwen2Tokenizer .from_pretrained ("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" )
118+ tokenizer = Qwen2Tokenizer .from_pretrained (tiny_ckpt_id )
119119
120120 components = {
121121 "transformer" : transformer ,
122122 "vae" : vae ,
123123 "scheduler" : scheduler ,
124124 "text_encoder" : text_encoder ,
125125 "tokenizer" : tokenizer ,
126+ "processor" : Qwen2VLProcessor .from_pretrained (tiny_ckpt_id ),
126127 }
127128 return components
128129
@@ -134,7 +135,7 @@ def get_dummy_inputs(self, device, seed=0):
134135
135136 inputs = {
136137 "prompt" : "dance monkey" ,
137- "image" : Image .new ("RGB" , (16 , 16 )),
138+ "image" : Image .new ("RGB" , (32 , 32 )),
138139 "negative_prompt" : "bad quality" ,
139140 "generator" : generator ,
140141 "num_inference_steps" : 2 ,
@@ -160,13 +161,13 @@ def test_inference(self):
160161 generated_image = image [0 ]
161162 self .assertEqual (generated_image .shape , (3 , 32 , 32 ))
162163
163- # fmt: off
164- expected_slice = torch .tensor ([0.56331 , 0.63677 , 0.6015 , 0.56369 , 0.58166 , 0.55277 , 0.57176 , 0.63261 , 0.41466 , 0.35561 , 0.56229 , 0.48334 , 0.49714 , 0.52622 , 0.40872 , 0.50208 ])
164+ expected_slice = torch .tensor (
165+ [[0.5637 , 0.6341 , 0.6001 , 0.5620 , 0.5794 , 0.5498 , 0.5757 , 0.6389 , 0.4174 ,
166+ 0.3597 , 0.5649 , 0.4894 , 0.4969 , 0.5255 , 0.4083 , 0.4986 ]])
165167 # fmt: on
166168
167169 generated_slice = generated_image .flatten ()
168170 generated_slice = torch .cat ([generated_slice [:8 ], generated_slice [- 8 :]])
169- print (f"{ generated_slice = } " )
170171 self .assertTrue (torch .allclose (generated_slice , expected_slice , atol = 1e-3 ))
171172
172173 def test_inference_batch_single_identical (self ):
@@ -236,3 +237,7 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
236237 expected_diff_max ,
237238 "VAE tiling should not affect the inference results" ,
238239 )
240+
@pytest.mark.xfail(
    condition=True,
    reason="Preconfigured embeddings need to be revisited.",
    strict=True,
)
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
    """Run the inherited prompt-encoding isolation check, currently marked as an expected failure.

    All three arguments are forwarded positionally and unchanged to the parent
    class implementation. The xfail marker is strict, so an unexpected pass is
    reported as a failure rather than silently accepted.
    """
    forwarded_args = (extra_required_param_value_dict, atol, rtol)
    super().test_encode_prompt_works_in_isolation(*forwarded_args)
0 commit comments