@@ -14,6 +14,7 @@
 # limitations under the License.

 import inspect
+import tempfile
 import unittest

 import numpy as np
@@ -27,9 +28,7 @@
     HunyuanDiTPAGPipeline,
     HunyuanDiTPipeline,
 )
-from diffusers.utils.testing_utils import (
-    enable_full_determinism,
-)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -269,3 +268,96 @@ def test_pag_applied_layers(self):
     )
     def test_encode_prompt_works_in_isolation(self):
         pass
+
+    def test_save_load_optional_components(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        prompt = inputs["prompt"]
+        generator = inputs["generator"]
+        num_inference_steps = inputs["num_inference_steps"]
+        output_type = inputs["output_type"]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_attention_mask,
+        ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0)
+
+        (
+            prompt_embeds_2,
+            negative_prompt_embeds_2,
+            prompt_attention_mask_2,
+            negative_prompt_attention_mask_2,
+        ) = pipe.encode_prompt(
+            prompt,
+            device=torch_device,
+            dtype=torch.float32,
+            text_encoder_index=1,
+        )
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "prompt_attention_mask": prompt_attention_mask,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "negative_prompt_attention_mask": negative_prompt_attention_mask,
+            "prompt_embeds_2": prompt_embeds_2,
+            "prompt_attention_mask_2": prompt_attention_mask_2,
+            "negative_prompt_embeds_2": negative_prompt_embeds_2,
+            "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "output_type": output_type,
+            "use_resolution_binning": False,
+        }
+
+        # set all optional components to None
+        for optional_component in pipe._optional_components:
+            setattr(pipe, optional_component, None)
+
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for optional_component in pipe._optional_components:
+            self.assertTrue(
+                getattr(pipe_loaded, optional_component) is None,
+                f"`{optional_component}` did not stay set to None after loading.",
+            )
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        generator = inputs["generator"]
+        num_inference_steps = inputs["num_inference_steps"]
+        output_type = inputs["output_type"]
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "prompt_attention_mask": prompt_attention_mask,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "negative_prompt_attention_mask": negative_prompt_attention_mask,
+            "prompt_embeds_2": prompt_embeds_2,
+            "prompt_attention_mask_2": prompt_attention_mask_2,
+            "negative_prompt_embeds_2": negative_prompt_embeds_2,
+            "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "output_type": output_type,
+            "use_resolution_binning": False,
+        }
+
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+        self.assertLess(max_diff, 1e-4)