 # limitations under the License.
 
 import gc
+import tempfile
 import unittest
 
 import numpy as np
@@ -212,6 +213,99 @@ def test_fused_qkv_projections(self):
     def test_encode_prompt_works_in_isolation(self):
         pass
 
+    def test_save_load_optional_components(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        prompt = inputs["prompt"]
+        generator = inputs["generator"]
+        num_inference_steps = inputs["num_inference_steps"]
+        output_type = inputs["output_type"]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_attention_mask,
+        ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0)
+
+        (
+            prompt_embeds_2,
+            negative_prompt_embeds_2,
+            prompt_attention_mask_2,
+            negative_prompt_attention_mask_2,
+        ) = pipe.encode_prompt(
+            prompt,
+            device=torch_device,
+            dtype=torch.float32,
+            text_encoder_index=1,
+        )
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "prompt_attention_mask": prompt_attention_mask,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "negative_prompt_attention_mask": negative_prompt_attention_mask,
+            "prompt_embeds_2": prompt_embeds_2,
+            "prompt_attention_mask_2": prompt_attention_mask_2,
+            "negative_prompt_embeds_2": negative_prompt_embeds_2,
+            "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "output_type": output_type,
+            "use_resolution_binning": False,
+        }
+
+        # set all optional components to None
+        for optional_component in pipe._optional_components:
+            setattr(pipe, optional_component, None)
+
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for optional_component in pipe._optional_components:
+            self.assertTrue(
+                getattr(pipe_loaded, optional_component) is None,
+                f"`{optional_component}` did not stay set to None after loading.",
+            )
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        generator = inputs["generator"]
+        num_inference_steps = inputs["num_inference_steps"]
+        output_type = inputs["output_type"]
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "prompt_attention_mask": prompt_attention_mask,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "negative_prompt_attention_mask": negative_prompt_attention_mask,
+            "prompt_embeds_2": prompt_embeds_2,
+            "prompt_attention_mask_2": prompt_attention_mask_2,
+            "negative_prompt_embeds_2": negative_prompt_embeds_2,
+            "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "output_type": output_type,
+            "use_resolution_binning": False,
+        }
+
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+        self.assertLess(max_diff, 1e-4)
+
 
 @slow
 @require_torch_accelerator