import gc
import tempfile
import unittest

from diffusers import NVIDIAModelOptConfig, SanaPipeline, SanaTransformer2DModel
from diffusers.utils import is_nvidia_modelopt_available, is_torch_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_reset_peak_memory_stats,
    enable_full_determinism,
    nightly,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_big_accelerator,
    require_torch_cuda_compatibility,
    torch_device,
)

if is_nvidia_modelopt_available():
    import modelopt.torch.quantization as mtq

if is_torch_available():
    import torch

    from ..utils import LoRALayer, get_memory_consumption_stat

enable_full_determinism()


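# Shared test logic for the NVIDIA ModelOpt quantization backends exercised
# below. Subclasses pick a quant_type via get_dummy_init_kwargs() and set
# expected_memory_reduction for the memory-usage check.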
@nightly
@require_big_accelerator
@require_accelerate
class ModelOptBaseTesterMixin:
    model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
    model_cls = SanaTransformer2DModel
    pipeline_cls = SanaPipeline
    torch_dtype = torch.bfloat16
    expected_memory_reduction = 0.0
    keep_in_fp32_module = ""
    modules_to_not_convert = ""
    _test_torch_compile = False

    def setUp(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def tearDown(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def get_dummy_init_kwargs(self):
        return {"quant_type": "FP8"}

    def get_dummy_model_init_kwargs(self):
        return {
            "pretrained_model_name_or_path": self.model_id,
            "torch_dtype": self.torch_dtype,
            "quantization_config": NVIDIAModelOptConfig(**self.get_dummy_init_kwargs()),
            "subfolder": "transformer",
        }

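    # After loading with a quantization_config, every nn.Linear in the
    # transformer should have been replaced by a ModelOpt-quantized module.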
    def test_modelopt_layers(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        for _, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert mtq.utils.is_quantized(module)

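    # Peak-memory comparison against the unquantized bf16 baseline;
    # expected_memory_reduction is the minimum unquantized/quantized ratio
    # each subclass must reach.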
    def test_modelopt_memory_usage(self):
        inputs = self.get_dummy_inputs()
        inputs = {
            k: v.to(device=torch_device, dtype=torch.bfloat16)
            for k, v in inputs.items()
            if not isinstance(v, bool)
        }

        unquantized_model = self.model_cls.from_pretrained(
            self.model_id, torch_dtype=self.torch_dtype, subfolder="transformer"
        )
        unquantized_model.to(torch_device)
        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)

        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        quantized_model.to(torch_device)
        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)

        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction

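    # Temporarily override the class-level _keep_in_fp32_modules so the
    # quantizer keeps the named modules in float32, then restore it.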
    def test_keep_modules_in_fp32(self):
        _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
        self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        model.to(torch_device)

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    assert module.weight.dtype == torch.float32
        self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules

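    # Modules named in modules_to_not_convert must be skipped by the
    # quantizer and report as unquantized.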
    def test_modules_to_not_convert(self):
        init_kwargs = self.get_dummy_model_init_kwargs()
        quantization_config_kwargs = self.get_dummy_init_kwargs()
        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
        quantization_config = NVIDIAModelOptConfig(**quantization_config_kwargs)
        init_kwargs.update({"quantization_config": quantization_config})

        model = self.model_cls.from_pretrained(**init_kwargs)
        model.to(torch_device)

        for name, module in model.named_modules():
            if name in self.modules_to_not_convert:
                assert not mtq.utils.is_quantized(module)

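    # Re-casting a quantized model (to(dtype), float(), half()) must raise a
    # ValueError; a pure device move remains allowed.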
    def test_dtype_assignment(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())

        with self.assertRaises(ValueError):
            model.to(torch.float16)

        with self.assertRaises(ValueError):
            device_0 = f"{torch_device}:0"
            model.to(device=device_0, dtype=torch.float16)

        with self.assertRaises(ValueError):
            model.float()

        with self.assertRaises(ValueError):
            model.half()

        model.to(torch_device)

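    # Round-trip through save_pretrained/from_pretrained should reproduce the
    # original outputs within tight tolerances.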
    def test_serialization(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**inputs)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            saved_model = self.model_cls.from_pretrained(
                tmp_dir,
                torch_dtype=torch.bfloat16,
            )

        saved_model.to(torch_device)
        with torch.no_grad():
            saved_model_output = saved_model(**inputs)

        assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)

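    # Opt-in per subclass via _test_torch_compile: eager and compiled outputs
    # should agree up to a small cosine distance.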
    def test_torch_compile(self):
        if not self._test_torch_compile:
            return

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**self.get_dummy_inputs()).sample

        compiled_model.to(torch_device)
        with torch.no_grad():
            compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

        model_output = model_output.detach().float().cpu().numpy()
        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()

        max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
        assert max_diff < 1e-3

    def test_device_map_error(self):
        with self.assertRaises(ValueError):
            _ = self.model_cls.from_pretrained(
                **self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"}
            )

    def get_dummy_inputs(self):
        batch_size = 1
        seq_len = 16
        height = width = 32
        num_latent_channels = 4
        caption_channels = 8

        torch.manual_seed(0)
        hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(
            torch_device, dtype=torch.bfloat16
        )
        encoder_hidden_states = torch.randn((batch_size, seq_len, caption_channels)).to(
            torch_device, dtype=torch.bfloat16
        )
        timestep = torch.tensor([1.0]).to(torch_device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
        }

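    # Smoke test: the quantized transformer should run end-to-end inside the
    # pipeline with model CPU offload enabled.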
    def test_model_cpu_offload(self):
        init_kwargs = self.get_dummy_init_kwargs()
        transformer = self.model_cls.from_pretrained(
            self.model_id,
            quantization_config=NVIDIAModelOptConfig(**init_kwargs),
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        pipe = self.pipeline_cls.from_pretrained(
            self.model_id, transformer=transformer, torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)

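    # LoRA-on-quantized-base training: freeze the quantized weights (keeping
    # 1D params such as norms in fp32 for stability), wrap the attention
    # projections in LoRA adapters, and check that gradients reach them.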
    def test_training(self):
        quantization_config = NVIDIAModelOptConfig(**self.get_dummy_init_kwargs())
        quantized_model = self.model_cls.from_pretrained(
            self.model_id,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for param in quantized_model.parameters():
            param.requires_grad = False
            if param.ndim == 1:
                param.data = param.data.to(torch.float32)

        for _, module in quantized_model.named_modules():
            if hasattr(module, "to_q"):
                module.to_q = LoRALayer(module.to_q, rank=4)
            if hasattr(module, "to_k"):
                module.to_k = LoRALayer(module.to_k, rank=4)
            if hasattr(module, "to_v"):
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_inputs()
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)


class SanaTransformerFP8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.6

    def get_dummy_init_kwargs(self):
        return {"quant_type": "FP8"}


class SanaTransformerINT8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.6
    _test_torch_compile = True

    def get_dummy_init_kwargs(self):
        return {"quant_type": "INT8"}


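# INT4 requires compute capability 8.0+ (Ampere or newer). On our reading of
# the NVIDIAModelOptConfig knobs, block_quantize=128 selects blockwise weight
# scaling with 128-element blocks and channel_quantize=-1 picks the axis; the
# exact semantics come from ModelOpt. Conv and patch-embed layers stay
# unquantized.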
@require_torch_cuda_compatibility(8.0)
class SanaTransformerINT4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "INT4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "modules_to_not_convert": ["conv", "patch_embed"],
        }


@require_torch_cuda_compatibility(8.0)
class SanaTransformerNF4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "NF4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "scale_block_quantize": 8,
            "scale_channel_quantize": -1,
            "modules_to_not_convert": ["conv"],
        }


@require_torch_cuda_compatibility(8.0)
class SanaTransformerNVFP4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "NVFP4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "modules_to_not_convert": ["conv"],
        }