Skip to content

Commit 7230bb1

Browse files
committed
refactor: extract _process_tensors_microscale to reduce duplication
Signed-off-by: David Zheng <dqzheng1996@gmail.com>
1 parent c41d426 commit 7230bb1

File tree

1 file changed

+60
-79
lines changed
  • src/llmcompressor/entrypoints/model_free

1 file changed

+60
-79
lines changed

src/llmcompressor/entrypoints/model_free/process.py

Lines changed: 60 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,7 @@ def process_file_microscale_scheme(
122122
123123
:param file_path: safetensors file to process
124124
:param save_path: save path of file with quantized weights
125-
:param scheme: quantization scheme to apply to tensors
125+
:param scheme: microscale quantization scheme (NVFP4, MXFP4)
126126
:param ignore: modules to ignore. Modules ending with "norm" are automatically
127127
ignored
128128
:param device: device used to quantize and compress weights
@@ -135,64 +135,10 @@ def process_file_microscale_scheme(
135135
if converter is not None:
136136
converter.process(tensors)
137137

138-
fused_sets, unmatched_sets = get_fused_names(tensors)
138+
fused_sets, unmatched_sets = get_fused_names(list(tensors.keys()))
139139
assert len(unmatched_sets) <= 0 # should be caught by validate_safetensors_index
140140

141-
fused_name_to_fused_index: dict[str, int] # fused_name -> fused_index
142-
fused_modules: dict[int, dict[str, Module]] # fused_index -> named_modules
143-
144-
fused_name_to_fused_index = {
145-
name: index
146-
for index, matched_set in enumerate(fused_sets)
147-
for name in matched_set.values()
148-
}
149-
fused_modules = defaultdict(dict)
150-
151-
for module_name, name in match_quantizable_tensors(tensors, ignore, scheme.targets):
152-
validate_weight_for_quantization(tensors[name], scheme, name)
153-
154-
# 1. initialize module with qparams (on device)
155-
module = initialize_quantized_linear(tensors[name], scheme, device)
156-
157-
# 2. calibrate weight qparams. Delay scale/zp calibration for fused modules
158-
calibrate_global_scale(module)
159-
if name in fused_name_to_fused_index:
160-
fused_index = fused_name_to_fused_index[name]
161-
fused_modules[fused_index][name] = module
162-
continue
163-
164-
calibrate_scale_zp(module)
165-
166-
# 3. compress module using qparams
167-
compress_module(module)
168-
169-
# 4. save compressed data (on cpu)
170-
del tensors[name]
171-
prefix = module_name + "."
172-
for key, value in module.state_dict(prefix=prefix).items():
173-
tensors[key] = value.to("cpu")
174-
175-
# compress and save microscale fused modules
176-
for named_modules in fused_modules.values():
177-
# 2.1. fuse global scales
178-
global_scales = [m.weight_global_scale for m in named_modules.values()]
179-
fused_global_scale = torch.min(torch.cat(global_scales, dim=0))
180-
181-
for name, module in named_modules.items():
182-
module_name, _ = name.rsplit(".", 1)
183-
module.weight_global_scale.data.copy_(fused_global_scale)
184-
185-
# 2.2. finish calibration with fused global scales
186-
calibrate_scale_zp(module)
187-
188-
# 3. compress module using microscale qparams
189-
compress_module(module)
190-
191-
# 4. save compressed data (on cpu)
192-
del tensors[name]
193-
prefix = module_name + "."
194-
for key, value in module.state_dict(prefix=prefix).items():
195-
tensors[key] = value.to("cpu")
141+
tensors, _ = _process_tensors_microscale(tensors, scheme, ignore, device)
196142

197143
save_file(tensors, save_path)
198144
total_size = sum(tensor.nbytes for tensor in tensors.values())
@@ -254,6 +200,54 @@ def process_file_group_microscale_scheme(
254200
"This is a bug in group_files_by_fused_weights."
255201
)
256202

203+
tensors, tensor_to_shard = _process_tensors_microscale(
204+
tensors, scheme, ignore, device, tensor_to_shard
205+
)
206+
207+
# Re-shard: write each tensor back to its original output file
208+
output_shards: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
209+
for name, tensor in tensors.items():
210+
output_shards[tensor_to_shard[name]][name] = tensor
211+
212+
total_size = 0
213+
weight_map: dict[str, str] = {}
214+
for save_path in save_paths:
215+
shard_name = os.path.basename(save_path)
216+
shard_tensors = output_shards.get(shard_name, {})
217+
os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
218+
save_file(shard_tensors, save_path)
219+
total_size += sum(t.nbytes for t in shard_tensors.values())
220+
weight_map.update({k: shard_name for k in shard_tensors})
221+
222+
return total_size, weight_map
223+
224+
225+
def _process_tensors_microscale(
226+
tensors: dict[str, torch.Tensor],
227+
scheme: QuantizationScheme,
228+
ignore: Iterable[str],
229+
device: str | torch.device,
230+
tensor_to_shard: dict[str, str] | None = None,
231+
) -> tuple[dict[str, torch.Tensor], dict[str, str] | None]:
232+
"""
233+
Core microscale quantization logic shared by process_file_microscale_scheme
234+
and process_file_group_microscale_scheme.
235+
236+
Processes all quantizable tensors in the given dict in-place, handling
237+
global scale fusion for fused weight sets (q/k/v, gate/up). When
238+
tensor_to_shard is provided, shard assignments are updated to follow
239+
compressed tensor keys.
240+
241+
:param tensors: dict of tensor name -> tensor, modified in-place
242+
:param scheme: microscale quantization scheme (NVFP4, MXFP4)
243+
:param ignore: modules to ignore
244+
:param device: device used to quantize and compress weights
245+
:param tensor_to_shard: optional mapping of tensor name -> shard filename,
246+
updated in-place when compressed tensors produce new keys
247+
:return: (tensors, tensor_to_shard) tuple with updated contents
248+
"""
249+
fused_sets, _ = get_fused_names(list(tensors.keys()))
250+
257251
fused_name_to_fused_index: dict[str, int] = {
258252
name: index
259253
for index, matched_set in enumerate(fused_sets)
@@ -280,13 +274,14 @@ def process_file_group_microscale_scheme(
280274
# 3. compress module using qparams
281275
compress_module(module)
282276

283-
# 4. save compressed data back to cpu, preserving shard assignment
284-
original_shard = tensor_to_shard[name]
277+
# 4. save compressed data back to cpu
278+
original_shard = tensor_to_shard[name] if tensor_to_shard else None
285279
del tensors[name]
286280
prefix = module_name + "."
287281
for key, value in module.state_dict(prefix=prefix).items():
288282
tensors[key] = value.to("cpu")
289-
tensor_to_shard[key] = original_shard
283+
if tensor_to_shard is not None:
284+
tensor_to_shard[key] = original_shard
290285

291286
# compress and save microscale fused modules (with fused global scales)
292287
for named_modules in fused_modules.values():
@@ -304,27 +299,13 @@ def process_file_group_microscale_scheme(
304299
# 3. compress module using microscale qparams
305300
compress_module(module)
306301

307-
# 4. save compressed data back to cpu, preserving shard assignment
308-
original_shard = tensor_to_shard[name]
302+
# 4. save compressed data back to cpu
303+
original_shard = tensor_to_shard[name] if tensor_to_shard else None
309304
del tensors[name]
310305
prefix = module_name + "."
311306
for key, value in module.state_dict(prefix=prefix).items():
312307
tensors[key] = value.to("cpu")
313-
tensor_to_shard[key] = original_shard
314-
315-
# Re-shard: write each tensor back to its original output file
316-
output_shards: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
317-
for name, tensor in tensors.items():
318-
output_shards[tensor_to_shard[name]][name] = tensor
319-
320-
total_size = 0
321-
weight_map: dict[str, str] = {}
322-
for save_path in save_paths:
323-
shard_name = os.path.basename(save_path)
324-
shard_tensors = output_shards.get(shard_name, {})
325-
os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
326-
save_file(shard_tensors, save_path)
327-
total_size += sum(t.nbytes for t in shard_tensors.values())
328-
weight_map.update({k: shard_name for k in shard_tensors})
308+
if tensor_to_shard is not None:
309+
tensor_to_shard[key] = original_shard
329310

330-
return total_size, weight_map
311+
return tensors, tensor_to_shard

0 commit comments

Comments
 (0)