@@ -3,7 +3,7 @@
 
 import torch
 
-from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig
+from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig, SD3Transformer2DModel
 from diffusers.utils.testing_utils import (
     is_gguf_available,
     nightly,
@@ -22,45 +22,16 @@
 @require_big_gpu_with_torch_cuda
 @require_accelerate
 @require_gguf_version_greater_or_equal("0.10.0")
-class GGUFSingleFileTests(unittest.TestCase):
-    ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+class GGUFSingleFileTesterMixin:
+    ckpt_path = None
+    model_cls = None
     torch_dtype = torch.bfloat16
-
-    def setUp(self):
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    def get_dummy_inputs(self):
-        return {
-            "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
-                torch_device, self.torch_dtype
-            ),
-            "encoder_hidden_states": torch.randn(
-                (1, 512, 4096),
-                generator=torch.Generator("cpu").manual_seed(0),
-            ).to(torch_device, self.torch_dtype),
-            "pooled_projections": torch.randn(
-                (1, 768),
-                generator=torch.Generator("cpu").manual_seed(0),
-            ).to(torch_device, self.torch_dtype),
-            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
-            "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
-                torch_device, self.torch_dtype
-            ),
-            "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
-                torch_device, self.torch_dtype
-            ),
-            "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
-        }
+    expected_memory_use_in_gb = 5
 
     def test_gguf_parameters(self):
         quant_storage_type = torch.uint8
         quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
-        model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)
+        model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
 
     for param_name, param in model.named_parameters():
         if isinstance(param, GGUFParameter):
@@ -69,7 +40,7 @@ def test_gguf_parameters(self):
 
     def test_gguf_linear_layers(self):
         quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
-        model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)
+        model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
 
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
@@ -78,29 +49,29 @@ def test_gguf_linear_layers(self):
     def test_gguf_memory_usage(self):
         quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
 
-        model = FluxTransformer2DModel.from_single_file(
+        model = self.model_cls.from_single_file(
             self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
         )
         model.to("cuda")
-        assert (model.get_memory_footprint() / 1024**3) < 5
+        assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
         inputs = self.get_dummy_inputs()
 
         torch.cuda.reset_peak_memory_stats()
         torch.cuda.empty_cache()
         with torch.no_grad():
             model(**inputs)
         max_memory = torch.cuda.max_memory_allocated()
-        assert (max_memory / 1024**3) < 5
+        assert (max_memory / 1024**3) < self.expected_memory_use_in_gb
 
     def test_keep_modules_in_fp32(self):
         r"""
         A simple test to check that the modules under `_keep_in_fp32_modules` are kept in fp32.
         Also ensures that inference works.
         """
-        FluxTransformer2DModel._keep_in_fp32_modules = ["proj_out"]
+        self.model_cls._keep_in_fp32_modules = ["proj_out"]
 
         quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
-        model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)
+        model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
 
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
@@ -109,7 +80,7 @@ def test_keep_modules_in_fp32(self):
 
     def test_dtype_assignment(self):
         quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
-        model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)
+        model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
 
         with self.assertRaises(ValueError):
             # Tries with a `dtype`
@@ -129,3 +100,103 @@ def test_dtype_assignment(self):
 
         # This should work
         model.to("cuda")
+
+
+class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = FluxTransformer2DModel
+    expected_memory_use_in_gb = 5
+
+    def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "pooled_projections": torch.randn(
+                (1, 768),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+            "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
+        }
+
+
+class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = SD3Transformer2DModel
+    expected_memory_use_in_gb = 5
+
+    def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "pooled_projections": torch.randn(
+                (1, 2048),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
+
+
+class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = SD3Transformer2DModel
+    expected_memory_use_in_gb = 2
+
+    def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "pooled_projections": torch.randn(
+                (1, 2048),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
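
The refactor above makes GGUF single-file coverage declarative: a new architecture only needs to pin the class attributes and supply get_dummy_inputs(). As a minimal sketch of what a follow-up class would look like (the checkpoint URL and memory bound are placeholders, not part of this diff; the input shapes simply mirror the SD3.5 classes above):

class MyModelGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    # Placeholder checkpoint -- substitute a real GGUF file and its matching transformer class.
    ckpt_path = "https://huggingface.co/some-org/some-model-gguf/blob/main/model-Q4_0.gguf"
    torch_dtype = torch.bfloat16
    model_cls = SD3Transformer2DModel
    expected_memory_use_in_gb = 4  # rough upper bound on the quantized footprint, in GiB

    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_inputs(self):
        # Shapes must match the forward() signature of model_cls.
        return {
            "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_projections": torch.randn(
                (1, 2048),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }

Pinning expected_memory_use_in_gb per class is also what lets the smaller Q3_K_M medium checkpoint assert a tighter 2 GB bound while the shared tests stay untouched.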
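
For context, the loading path these tests exercise is the same one end users hit. A minimal usage sketch, assuming the documented diffusers GGUF workflow and access to the gated black-forest-labs/FLUX.1-dev repo (none of this is part of the diff):

import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# Load the GGUF-quantized transformer from a single file, dequantizing to bf16 at compute time.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
# Drop the quantized transformer into the full pipeline.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # keeps peak VRAM low, in the spirit of the memory bounds asserted above
image = pipe("a cat holding a sign that says hello world", num_inference_steps=50).images[0]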