# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import torch

from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device


@require_torch_gpu
@slow
class QuantCompileTests(unittest.TestCase):
    # Concrete subclasses override this with a backend-specific config.
    quantization_config = None
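
    # A minimal sketch of a backend-specific subclass, assuming diffusers'
    # `PipelineQuantizationConfig`; the backend name and kwargs below are
    # illustrative, not part of this file:
    #
    #     class BnB4BitCompileTests(QuantCompileTests):
    #         quantization_config = PipelineQuantizationConfig(
    #             quant_backend="bitsandbytes_4bit",
    #             quant_kwargs={"load_in_4bit": True},
    #             components_to_quantize=["transformer"],
    #         )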

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def _init_pipeline(self, quantization_config, torch_dtype):
        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers",
            quantization_config=quantization_config,
            torch_dtype=torch_dtype,
        )
        return pipe

    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
        # `fullgraph=True` makes compilation fail loudly if the quantized
        # transformer introduces graph breaks.
        pipe.transformer.compile(fullgraph=True)

        for _ in range(2):
            # Small resolution to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

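    # In a concrete subclass, this helper would typically be exposed as an
    # actual test, e.g. (illustrative only):
    #
    #     def test_torch_compile(self):
    #         self._test_torch_compile(self.quantization_config)
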
    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype)
        pipe.enable_model_cpu_offload()
        # No `fullgraph=True` here: offloading hooks move modules between
        # devices, which may introduce graph breaks.
        pipe.transformer.compile()

        for _ in range(2):
            # Small resolution to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        # Group offloading can trigger many recompilations; raise the cache
        # limit so Dynamo does not fall back to eager.
        torch._dynamo.config.cache_size_limit = 10000

        pipe = self._init_pipeline(quantization_config, torch_dtype)
        group_offload_kwargs = {
            "onload_device": torch.device("cuda"),
            "offload_device": torch.device("cpu"),
            "offload_type": "leaf_level",
            "use_stream": True,
            "non_blocking": True,
        }
        pipe.transformer.enable_group_offload(**group_offload_kwargs)
        pipe.transformer.compile()
        # Only the transformer is group-offloaded; keep every other module on the GPU.
        for name, component in pipe.components.items():
            if name != "transformer" and isinstance(component, torch.nn.Module):
                if torch.device(component.device).type == "cpu":
                    component.to("cuda")

        for _ in range(2):
            # Small resolution to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
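
    # The offload helpers plug in the same way; a subclass might
    # (illustratively) add:
    #
    #     def test_torch_compile_with_cpu_offload(self):
    #         self._test_torch_compile_with_cpu_offload(self.quantization_config)
    #
    #     def test_torch_compile_with_group_offload(self):
    #         self._test_torch_compile_with_group_offload(self.quantization_config)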