Skip to content

Commit e9303a0

Browse files
committed
update
1 parent 9310035 commit e9303a0

File tree

2 files changed

+132
-48
lines changed

2 files changed

+132
-48
lines changed

src/diffusers/loaders/single_file_utils.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,14 @@
8181
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
8282
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
8383
"stable_cascade_stage_c": "clip_txt_mapper.weight",
84-
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
85-
"sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
84+
"sd3": [
85+
"joint_blocks.0.context_block.adaLN_modulation.1.bias",
86+
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
87+
],
88+
"sd35_large": [
89+
"joint_blocks.37.x_block.mlp.fc1.weight",
90+
"model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
91+
],
8692
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
8793
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
8894
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -529,13 +535,20 @@ def infer_diffusers_model_type(checkpoint):
529535
):
530536
model_type = "stable_cascade_stage_b"
531537

532-
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
533-
if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864:
538+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
539+
checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
540+
):
541+
if "model.diffusion_model.pos_embed" in checkpoint:
542+
key = "model.diffusion_model.pos_embed"
543+
else:
544+
key = "pos_embed"
545+
546+
if checkpoint[key].shape[1] == 36864:
534547
model_type = "sd3"
535-
elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456:
548+
elif checkpoint[key].shape[1] == 147456:
536549
model_type = "sd35_medium"
537550

538-
elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
551+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
539552
model_type = "sd35_large"
540553

541554
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:

tests/quantization/gguf/test_gguf.py

Lines changed: 113 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig
6+
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig, SD3Transformer2DModel
77
from diffusers.utils.testing_utils import (
88
is_gguf_available,
99
nightly,
@@ -22,45 +22,16 @@
2222
@require_big_gpu_with_torch_cuda
2323
@require_accelerate
2424
@require_gguf_version_greater_or_equal("0.10.0")
25-
class GGUFSingleFileTests(unittest.TestCase):
26-
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
25+
class GGUFSingleFileTesterMixin:
26+
ckpt_path = None
27+
model_cls = None
2728
torch_dtype = torch.bfloat16
28-
29-
def setUp(self):
30-
gc.collect()
31-
torch.cuda.empty_cache()
32-
33-
def tearDown(self):
34-
gc.collect()
35-
torch.cuda.empty_cache()
36-
37-
def get_dummy_inputs(self):
38-
return {
39-
"hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
40-
torch_device, self.torch_dtype
41-
),
42-
"encoder_hidden_states": torch.randn(
43-
(1, 512, 4096),
44-
generator=torch.Generator("cpu").manual_seed(0),
45-
).to(torch_device, self.torch_dtype),
46-
"pooled_projections": torch.randn(
47-
(1, 768),
48-
generator=torch.Generator("cpu").manual_seed(0),
49-
).to(torch_device, self.torch_dtype),
50-
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
51-
"img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
52-
torch_device, self.torch_dtype
53-
),
54-
"txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
55-
torch_device, self.torch_dtype
56-
),
57-
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
58-
}
29+
expected_memory_use_in_gb = 5
5930

6031
def test_gguf_parameters(self):
6132
quant_storage_type = torch.uint8
6233
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
63-
model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)
34+
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
6435

6536
for param_name, param in model.named_parameters():
6637
if isinstance(param, GGUFParameter):
@@ -69,7 +40,7 @@ def test_gguf_parameters(self):
6940

7041
def test_gguf_linear_layers(self):
7142
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
72-
model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)
43+
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
7344

7445
for name, module in model.named_modules():
7546
if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
@@ -78,29 +49,29 @@ def test_gguf_linear_layers(self):
7849
def test_gguf_memory_usage(self):
7950
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
8051

81-
model = FluxTransformer2DModel.from_single_file(
52+
model = self.model_cls.from_single_file(
8253
self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
8354
)
8455
model.to("cuda")
85-
assert (model.get_memory_footprint() / 1024**3) < 5
56+
assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
8657
inputs = self.get_dummy_inputs()
8758

8859
torch.cuda.reset_peak_memory_stats()
8960
torch.cuda.empty_cache()
9061
with torch.no_grad():
9162
model(**inputs)
9263
max_memory = torch.cuda.max_memory_allocated()
93-
assert (max_memory / 1024**3) < 5
64+
assert (max_memory / 1024**3) < self.expected_memory_use_in_gb
9465

9566
def test_keep_modules_in_fp32(self):
9667
r"""
9768
A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
9869
Also ensures if inference works.
9970
"""
100-
FluxTransformer2DModel._keep_in_fp32_modules = ["proj_out"]
71+
self.model_cls._keep_in_fp32_modules = ["proj_out"]
10172

10273
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
103-
model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)
74+
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
10475

10576
for name, module in model.named_modules():
10677
if isinstance(module, torch.nn.Linear):
@@ -109,7 +80,7 @@ def test_keep_modules_in_fp32(self):
10980

11081
def test_dtype_assignment(self):
11182
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
112-
model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)
83+
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
11384

11485
with self.assertRaises(ValueError):
11586
# Tries with a `dtype`
@@ -129,3 +100,103 @@ def test_dtype_assignment(self):
129100

130101
# This should work
131102
model.to("cuda")
103+
104+
105+
class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
106+
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
107+
torch_dtype = torch.bfloat16
108+
model_cls = FluxTransformer2DModel
109+
expected_memory_use_in_gb = 5
110+
111+
def setUp(self):
112+
gc.collect()
113+
torch.cuda.empty_cache()
114+
115+
def tearDown(self):
116+
gc.collect()
117+
torch.cuda.empty_cache()
118+
119+
def get_dummy_inputs(self):
120+
return {
121+
"hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
122+
torch_device, self.torch_dtype
123+
),
124+
"encoder_hidden_states": torch.randn(
125+
(1, 512, 4096),
126+
generator=torch.Generator("cpu").manual_seed(0),
127+
).to(torch_device, self.torch_dtype),
128+
"pooled_projections": torch.randn(
129+
(1, 768),
130+
generator=torch.Generator("cpu").manual_seed(0),
131+
).to(torch_device, self.torch_dtype),
132+
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
133+
"img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
134+
torch_device, self.torch_dtype
135+
),
136+
"txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
137+
torch_device, self.torch_dtype
138+
),
139+
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
140+
}
141+
142+
143+
class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
144+
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
145+
torch_dtype = torch.bfloat16
146+
model_cls = SD3Transformer2DModel
147+
expected_memory_use_in_gb = 5
148+
149+
def setUp(self):
150+
gc.collect()
151+
torch.cuda.empty_cache()
152+
153+
def tearDown(self):
154+
gc.collect()
155+
torch.cuda.empty_cache()
156+
157+
def get_dummy_inputs(self):
158+
return {
159+
"hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
160+
torch_device, self.torch_dtype
161+
),
162+
"encoder_hidden_states": torch.randn(
163+
(1, 512, 4096),
164+
generator=torch.Generator("cpu").manual_seed(0),
165+
).to(torch_device, self.torch_dtype),
166+
"pooled_projections": torch.randn(
167+
(1, 2048),
168+
generator=torch.Generator("cpu").manual_seed(0),
169+
).to(torch_device, self.torch_dtype),
170+
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
171+
}
172+
173+
174+
class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
175+
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
176+
torch_dtype = torch.bfloat16
177+
model_cls = SD3Transformer2DModel
178+
expected_memory_use_in_gb = 2
179+
180+
def setUp(self):
181+
gc.collect()
182+
torch.cuda.empty_cache()
183+
184+
def tearDown(self):
185+
gc.collect()
186+
torch.cuda.empty_cache()
187+
188+
def get_dummy_inputs(self):
189+
return {
190+
"hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
191+
torch_device, self.torch_dtype
192+
),
193+
"encoder_hidden_states": torch.randn(
194+
(1, 512, 4096),
195+
generator=torch.Generator("cpu").manual_seed(0),
196+
).to(torch_device, self.torch_dtype),
197+
"pooled_projections": torch.randn(
198+
(1, 2048),
199+
generator=torch.Generator("cpu").manual_seed(0),
200+
).to(torch_device, self.torch_dtype),
201+
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
202+
}

0 commit comments

Comments
 (0)