
Commit 670202d

update

1 parent d48835d commit 670202d

File tree

2 files changed: +323 -0 lines changed

tests/quantization/modelopt/__init__.py

Whitespace-only changes.
Lines changed: 323 additions & 0 deletions
@@ -0,0 +1,323 @@
import gc
import tempfile
import unittest

from diffusers import SanaPipeline, SanaTransformer2DModel, NVIDIAModelOptConfig
from diffusers.utils import is_nvidia_modelopt_available, is_torch_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_reset_peak_memory_stats,
    enable_full_determinism,
    nightly,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_big_accelerator,
    require_torch_cuda_compatibility,
    torch_device,
)

if is_nvidia_modelopt_available():
    import modelopt.torch.quantization as mtq

if is_torch_available():
    import torch
    from ..utils import LoRALayer, get_memory_consumption_stat

enable_full_determinism()


@nightly
@require_big_accelerator
@require_accelerate
class ModelOptBaseTesterMixin:
    model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
    model_cls = SanaTransformer2DModel
    pipeline_cls = SanaPipeline
    torch_dtype = torch.bfloat16
    expected_memory_reduction = 0.0
    keep_in_fp32_module = ""
    modules_to_not_convert = ""
    _test_torch_compile = False

    def setUp(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def tearDown(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def get_dummy_init_kwargs(self):
        return {"quant_type": "FP8"}

    def get_dummy_model_init_kwargs(self):
        return {
            "pretrained_model_name_or_path": self.model_id,
            "torch_dtype": self.torch_dtype,
            "quantization_config": NVIDIAModelOptConfig(**self.get_dummy_init_kwargs()),
            "subfolder": "transformer",
        }

    def test_modelopt_layers(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert mtq.utils.is_quantized(module)

    def test_modelopt_memory_usage(self):
        inputs = self.get_dummy_inputs()
        inputs = {
            k: v.to(device=torch_device, dtype=torch.bfloat16)
            for k, v in inputs.items()
            if not isinstance(v, bool)
        }

        unquantized_model = self.model_cls.from_pretrained(
            self.model_id, torch_dtype=self.torch_dtype, subfolder="transformer"
        )
        unquantized_model.to(torch_device)
        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)

        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        quantized_model.to(torch_device)
        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)

        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction

    def test_keep_modules_in_fp32(self):
        _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
        self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        model.to(torch_device)

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    assert module.weight.dtype == torch.float32
        self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules

    def test_modules_to_not_convert(self):
        init_kwargs = self.get_dummy_model_init_kwargs()
        quantization_config_kwargs = self.get_dummy_init_kwargs()
        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
        quantization_config = NVIDIAModelOptConfig(**quantization_config_kwargs)
        init_kwargs.update({"quantization_config": quantization_config})

        model = self.model_cls.from_pretrained(**init_kwargs)
        model.to(torch_device)

        for name, module in model.named_modules():
            if name in self.modules_to_not_convert:
                assert not mtq.utils.is_quantized(module)

    def test_dtype_assignment(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())

        with self.assertRaises(ValueError):
            model.to(torch.float16)

        with self.assertRaises(ValueError):
            device_0 = f"{torch_device}:0"
            model.to(device=device_0, dtype=torch.float16)

        with self.assertRaises(ValueError):
            model.float()

        with self.assertRaises(ValueError):
            model.half()

        model.to(torch_device)

    def test_serialization(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**inputs)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            saved_model = self.model_cls.from_pretrained(
                tmp_dir,
                torch_dtype=torch.bfloat16,
            )

        saved_model.to(torch_device)
        with torch.no_grad():
            saved_model_output = saved_model(**inputs)

        assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)

    def test_torch_compile(self):
        if not self._test_torch_compile:
            return

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**self.get_dummy_inputs()).sample

        compiled_model.to(torch_device)
        with torch.no_grad():
            compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

        model_output = model_output.detach().float().cpu().numpy()
        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()

        max_diff = numpy_cosine_similarity_distance(
            model_output.flatten(), compiled_model_output.flatten()
        )
        assert max_diff < 1e-3

    def test_device_map_error(self):
        with self.assertRaises(ValueError):
            _ = self.model_cls.from_pretrained(
                **self.get_dummy_model_init_kwargs(),
                device_map={0: "8GB", "cpu": "16GB"},
            )

    def get_dummy_inputs(self):
        batch_size = 1
        seq_len = 16
        height = width = 32
        num_latent_channels = 4
        caption_channels = 8

        torch.manual_seed(0)
        hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(
            torch_device, dtype=torch.bfloat16
        )
        encoder_hidden_states = torch.randn((batch_size, seq_len, caption_channels)).to(
            torch_device, dtype=torch.bfloat16
        )
        timestep = torch.tensor([1.0]).to(torch_device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
        }

    def test_model_cpu_offload(self):
        init_kwargs = self.get_dummy_init_kwargs()
        transformer = self.model_cls.from_pretrained(
            self.model_id,
            quantization_config=NVIDIAModelOptConfig(**init_kwargs),
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        pipe = self.pipeline_cls.from_pretrained(
            self.model_id, transformer=transformer, torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)

    def test_training(self):
        quantization_config = NVIDIAModelOptConfig(**self.get_dummy_init_kwargs())
        quantized_model = self.model_cls.from_pretrained(
            self.model_id,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for param in quantized_model.parameters():
            param.requires_grad = False
            if param.ndim == 1:
                param.data = param.data.to(torch.float32)

        for _, module in quantized_model.named_modules():
            if hasattr(module, "to_q"):
                module.to_q = LoRALayer(module.to_q, rank=4)
            if hasattr(module, "to_k"):
                module.to_k = LoRALayer(module.to_k, rank=4)
            if hasattr(module, "to_v"):
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_inputs()
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)


class SanaTransformerFP8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.6

    def get_dummy_init_kwargs(self):
        return {"quant_type": "FP8"}


class SanaTransformerINT8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.6
    _test_torch_compile = True

    def get_dummy_init_kwargs(self):
        return {"quant_type": "INT8"}


@require_torch_cuda_compatibility(8.0)
class SanaTransformerINT4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "INT4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "modules_to_not_convert": ["conv", "patch_embed"],
        }


@require_torch_cuda_compatibility(8.0)
class SanaTransformerNF4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "NF4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "scale_block_quantize": 8,
            "scale_channel_quantize": -1,
            "modules_to_not_convert": ["conv"],
        }


@require_torch_cuda_compatibility(8.0)
class SanaTransformerNVFP4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {
            "quant_type": "NVFP4",
            "block_quantize": 128,
            "channel_quantize": -1,
            "modules_to_not_convert": ["conv"],
        }
