Skip to content

Commit d99d38b

Browse files
authored
Refactor module / parameter matching logic (#406)
* Update `apply_quantization_config` to use `match_named_modules` Signed-off-by: Fynn Schmitt-Ulms <[email protected]> * Refactor usages of `expand_target_names`, `is_target`, and `find_name_or_class_matches` Signed-off-by: Fynn Schmitt-Ulms <[email protected]> * Small fixes Signed-off-by: Fynn Schmitt-Ulms <[email protected]> * Simplify signature of `match_named_modules` Removed `yield_matched_targets` and `warn_on_unmatched_ignores` and updated rest of code Signed-off-by: Fynn Schmitt-Ulms <[email protected]> * Ensure `match_targets` doesn't return duplicates Signed-off-by: Fynn Schmitt-Ulms <[email protected]> * Remove `preprocess_name` parameter from `match_named_modules` Signed-off-by: Fynn Schmitt-Ulms <[email protected]> * Update match.py util fn signatures and small fixes Signed-off-by: Fynn Schmitt-Ulms <[email protected]> * Restore `find_name_or_class_matches` as a deprecated function This function is currently used by llm-compressor so adding it back with a deprecation warning for now. Signed-off-by: Fynn Schmitt-Ulms <[email protected]> * Use deprecated decorator instead of manual deprecation warning Signed-off-by: Fynn Schmitt-Ulms <[email protected]> * Update syntax of optional types * Remove default None target value in match utils --------- Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent b78eb8f commit d99d38b

File tree

5 files changed

+314
-355
lines changed

5 files changed

+314
-355
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 157 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
apply_quantization_config,
4343
load_pretrained_quantization_parameters,
4444
)
45-
from compressed_tensors.quantization.lifecycle import expand_target_names
4645
from compressed_tensors.quantization.utils import is_module_quantized
4746
from compressed_tensors.transform import TransformConfig
4847
from compressed_tensors.utils import (
@@ -60,6 +59,7 @@
6059
fix_fsdp_module_name,
6160
is_compressed_tensors_config,
6261
)
62+
from compressed_tensors.utils.match import match_named_modules
6363
from torch import Tensor
6464
from torch.nn import Module
6565
from tqdm import tqdm
@@ -342,13 +342,15 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
342342
self.sparsity_compressor
343343
and self.sparsity_config.format != CompressionFormat.dense.value
344344
):
345-
sparse_targets = expand_target_names(
345+
sparse_targets = match_named_modules(
346346
model=model,
347347
targets=self.sparsity_config.targets,
348348
ignore=self.sparsity_config.ignore,
349349
)
350+
350351
missing_keys.update(
351-
merge_names(target, "weight") for target in sparse_targets
352+
merge_names(target_name, "weight")
353+
for target_name, _module in sparse_targets
352354
)
353355

354356
# Determine missing keys due to pack quantization
@@ -358,13 +360,14 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
358360
== CompressionFormat.pack_quantized.value
359361
):
360362
for scheme in self.quantization_config.config_groups.values():
361-
quant_targets = expand_target_names(
363+
quant_targets = match_named_modules(
362364
model=model,
363365
targets=scheme.targets,
364366
ignore=self.quantization_config.ignore,
365367
)
366368
missing_keys.update(
367-
merge_names(target, "weight") for target in quant_targets
369+
merge_names(target_name, "weight")
370+
for target_name, _module in quant_targets
368371
)
369372

370373
return list(missing_keys)
@@ -395,29 +398,29 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
395398
self.sparsity_compressor
396399
and self.sparsity_config.format != CompressionFormat.dense.value
397400
):
398-
sparse_targets: Set[str] = expand_target_names(
401+
sparse_targets = match_named_modules(
399402
model=model,
400403
targets=self.sparsity_config.targets,
401404
ignore=self.sparsity_config.ignore,
402405
)
403406
unexpected_keys.update(
404-
merge_names(target, param)
405-
for target in sparse_targets
407+
merge_names(target_name, param)
408+
for target_name, _module in sparse_targets
406409
for param in self.sparsity_compressor.compression_param_names
407410
)
408411

409412
# Identify unexpected keys from quantization compression
410413
if self.quantization_compressor:
411414
for scheme in self.quantization_config.config_groups.values():
412-
quant_targets: Set[str] = expand_target_names(
415+
quant_targets = match_named_modules(
413416
model=model,
414417
targets=scheme.targets,
415418
ignore=self.quantization_config.ignore,
416419
)
417420
for quant_compressor in self.quantization_compressor.values():
418421
unexpected_keys.update(
419-
merge_names(target, param)
420-
for target in quant_targets
422+
merge_names(target_name, param)
423+
for target_name, _module in quant_targets
421424
for param in quant_compressor.compression_param_names
422425
if param != "weight"
423426
)
@@ -434,73 +437,79 @@ def compress_model(self, model: Module):
434437
:param model: model containing parameters to compress
435438
"""
436439
module_to_scheme = map_module_to_scheme(model)
437-
sparse_compression_targets: Set[str] = expand_target_names(
438-
model=model,
439-
targets=self.sparsity_config.targets if self.sparsity_config else [],
440-
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
441-
)
442-
443-
for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
444-
445-
if prefix in module_to_scheme or prefix in sparse_compression_targets:
446-
module_device = get_execution_device(module)
447-
is_meta = module_device.type == "meta"
448-
449-
exec_device = "meta" if is_meta else "cpu"
450-
onloading_device = "meta" if is_meta else module_device
451-
452-
# in the future, support compression on same device
453-
with align_module_device(module, execution_device=exec_device):
454-
state_dict = {
455-
f"{prefix}.{name}": param
456-
for name, param in module.named_parameters(recurse=False)
457-
}
458-
459-
# quantization first
460-
if prefix in module_to_scheme:
461-
if (
462-
not hasattr(module.quantization_scheme, "format")
463-
or module.quantization_scheme.format is None
464-
):
465-
if len(self.compression_formats) > 1:
466-
raise ValueError(
467-
"Applying multiple compressors without defining "
468-
"per module formats is not supported "
469-
)
470-
format = self.compression_formats[0]
471-
else:
472-
format = module.quantization_scheme.format
473-
474-
quant_compressor = self.quantization_compressor.get(format)
475-
state_dict = quant_compressor.compress(
476-
state_dict,
477-
names_to_scheme=module_to_scheme,
478-
show_progress=False,
479-
compression_device=exec_device,
480-
)
481-
482-
# sparsity second
483-
if prefix in sparse_compression_targets:
484-
state_dict = self.sparsity_compressor.compress(
485-
state_dict,
486-
compression_targets=sparse_compression_targets,
487-
show_progress=False,
488-
)
440+
sparse_compression_targets = [
441+
module_name
442+
for module_name, _module in match_named_modules(
443+
model=model,
444+
targets=self.sparsity_config.targets if self.sparsity_config else [],
445+
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
446+
)
447+
]
448+
for prefix, module in tqdm(
449+
match_named_modules(
450+
model,
451+
[*sparse_compression_targets, *module_to_scheme.keys()],
452+
warn_on_fail=True,
453+
),
454+
desc="Compressing model",
455+
):
456+
module_device = get_execution_device(module)
457+
is_meta = module_device.type == "meta"
458+
459+
exec_device = "meta" if is_meta else "cpu"
460+
onloading_device = "meta" if is_meta else module_device
461+
462+
# in the future, support compression on same device
463+
with align_module_device(module, execution_device=exec_device):
464+
state_dict = {
465+
f"{prefix}.{name}": param
466+
for name, param in module.named_parameters(recurse=False)
467+
}
468+
469+
# quantization first
470+
if prefix in module_to_scheme:
471+
if (
472+
not hasattr(module.quantization_scheme, "format")
473+
or module.quantization_scheme.format is None
474+
):
475+
if len(self.compression_formats) > 1:
476+
raise ValueError(
477+
"Applying multiple compressors without defining "
478+
"per module formats is not supported "
479+
)
480+
format = self.compression_formats[0]
481+
else:
482+
format = module.quantization_scheme.format
483+
484+
quant_compressor = self.quantization_compressor.get(format)
485+
state_dict = quant_compressor.compress(
486+
state_dict,
487+
names_to_scheme=module_to_scheme,
488+
show_progress=False,
489+
compression_device=exec_device,
490+
)
489491

490-
# remove any existing parameters
491-
offload_device = get_offloaded_device(module)
492-
for name, _ in list(module.named_parameters(recurse=False)):
493-
delete_offload_parameter(module, name)
492+
# sparsity second
493+
if prefix in sparse_compression_targets:
494+
state_dict = self.sparsity_compressor.compress(
495+
state_dict,
496+
compression_targets=sparse_compression_targets,
497+
show_progress=False,
498+
)
494499

495-
# replace with compressed parameters
496-
for name, value in state_dict.items():
497-
name = name.removeprefix(f"{prefix}.")
498-
value = value.to(onloading_device)
499-
param = torch.nn.Parameter(value, requires_grad=False)
500-
register_offload_parameter(module, name, param, offload_device)
500+
# remove any existing parameters
501+
offload_device = get_offloaded_device(module)
502+
for name, _ in list(module.named_parameters(recurse=False)):
503+
delete_offload_parameter(module, name)
501504

502-
module.quantization_status = QuantizationStatus.COMPRESSED
505+
# replace with compressed parameters
506+
for name, value in state_dict.items():
507+
name = name.removeprefix(f"{prefix}.")
508+
value = value.to(onloading_device)
509+
param = torch.nn.Parameter(value, requires_grad=False)
510+
register_offload_parameter(module, name, param, offload_device)
503511

512+
module.quantization_status = QuantizationStatus.COMPRESSED
504513
# TODO: consider sparse compression to also be compression
505514
if (
506515
self.quantization_config is not None
@@ -516,67 +525,75 @@ def decompress_model(self, model: Module):
516525
:param model: model containing parameters to compress
517526
"""
518527
module_to_scheme = map_module_to_scheme(model)
519-
sparse_compression_targets: Set[str] = expand_target_names(
520-
model=model,
521-
targets=self.sparsity_config.targets if self.sparsity_config else [],
522-
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
523-
)
524-
525-
for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
526-
if prefix in module_to_scheme or prefix in sparse_compression_targets:
527-
# in the future, support decompression on same device
528-
with align_module_device(module, execution_device="cpu"):
529-
state_dict = {
530-
f"{prefix}.{name}": param
531-
for name, param in module.named_parameters(recurse=False)
532-
}
533-
534-
# sparsity first
535-
if prefix in sparse_compression_targets:
536-
# sparse_compression_targets are automatically inferred by this fn
537-
generator = self.sparsity_compressor.decompress_from_state_dict(
538-
state_dict,
539-
)
540-
# generates (param_path, param_val)
541-
# of compressed and unused params
542-
state_dict = {key: value for key, value in generator}
543-
544-
# quantization second
545-
if prefix in module_to_scheme:
546-
547-
if (
548-
not hasattr(module.quantization_scheme, "format")
549-
or module.quantization_scheme.format is None
550-
):
551-
if len(self.compression_formats) > 1:
552-
raise ValueError(
553-
"Applying multiple compressors without defining "
554-
"per module formats is not supported "
555-
)
556-
format = self.compression_formats[0]
557-
else:
558-
format = module.quantization_scheme.format
559-
quant_compressor = self.quantization_compressor.get(format)
560-
state_dict = quant_compressor.decompress_module_from_state_dict(
561-
prefix,
562-
state_dict,
563-
scheme=module_to_scheme[prefix],
564-
)
528+
sparse_compression_targets = [
529+
module_name
530+
for module_name, _module in match_named_modules(
531+
model=model,
532+
targets=self.sparsity_config.targets if self.sparsity_config else [],
533+
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
534+
)
535+
]
536+
537+
for prefix, module in tqdm(
538+
match_named_modules(
539+
model,
540+
[*sparse_compression_targets, *module_to_scheme.keys()],
541+
warn_on_fail=True,
542+
),
543+
desc="Decompressing model",
544+
):
545+
# in the future, support decompression on same device
546+
with align_module_device(module, execution_device="cpu"):
547+
state_dict = {
548+
f"{prefix}.{name}": param
549+
for name, param in module.named_parameters(recurse=False)
550+
}
551+
552+
# sparsity first
553+
if prefix in sparse_compression_targets:
554+
# sparse_compression_targets are automatically inferred by this fn
555+
generator = self.sparsity_compressor.decompress_from_state_dict(
556+
state_dict,
557+
)
558+
# generates (param_path, param_val)
559+
# of compressed and unused params
560+
state_dict = {key: value for key, value in generator}
561+
562+
# quantization second
563+
if prefix in module_to_scheme:
564+
if (
565+
not hasattr(module.quantization_scheme, "format")
566+
or module.quantization_scheme.format is None
567+
):
568+
if len(self.compression_formats) > 1:
569+
raise ValueError(
570+
"Applying multiple compressors without defining "
571+
"per module formats is not supported "
572+
)
573+
format = self.compression_formats[0]
574+
else:
575+
format = module.quantization_scheme.format
576+
quant_compressor = self.quantization_compressor.get(format)
577+
state_dict = quant_compressor.decompress_module_from_state_dict(
578+
prefix,
579+
state_dict,
580+
scheme=module_to_scheme[prefix],
581+
)
565582

566-
# remove any existing parameters
567-
exec_device = get_execution_device(module)
568-
offload_device = get_offloaded_device(module)
569-
for name, _ in list(module.named_parameters(recurse=False)):
570-
delete_offload_parameter(module, name)
583+
# remove any existing parameters
584+
exec_device = get_execution_device(module)
585+
offload_device = get_offloaded_device(module)
586+
for name, _ in list(module.named_parameters(recurse=False)):
587+
delete_offload_parameter(module, name)
571588

572-
# replace with decompressed parameters
573-
for name, value in state_dict.items():
574-
name = name.removeprefix(f"{prefix}.")
575-
value = value.to(exec_device)
576-
param = torch.nn.Parameter(value, requires_grad=False)
577-
register_offload_parameter(module, name, param, offload_device)
589+
# replace with decompressed parameters
590+
for name, value in state_dict.items():
591+
name = name.removeprefix(f"{prefix}.")
592+
value = value.to(exec_device)
593+
param = torch.nn.Parameter(value, requires_grad=False)
594+
register_offload_parameter(module, name, param, offload_device)
578595

579-
module.quantization_status = QuantizationStatus.FROZEN
596+
module.quantization_status = QuantizationStatus.FROZEN
580597

581598
# ----- state dict compression pathways ----- #
582599

@@ -614,11 +631,14 @@ def compress(
614631
)
615632

616633
if self.sparsity_compressor is not None:
617-
sparse_compression_targets: Set[str] = expand_target_names(
618-
model=model,
619-
targets=self.sparsity_config.targets,
620-
ignore=self.sparsity_config.ignore,
621-
)
634+
sparse_compression_targets: Set[str] = {
635+
module_name
636+
for module_name, _module in match_named_modules(
637+
model=model,
638+
targets=self.sparsity_config.targets,
639+
ignore=self.sparsity_config.ignore,
640+
)
641+
}
622642
state_dict = self.sparsity_compressor.compress(
623643
state_dict,
624644
compression_targets=sparse_compression_targets,
@@ -683,7 +703,6 @@ def decompress(self, model_path: str, model: Module):
683703
with override_quantization_status(
684704
self.quantization_config, QuantizationStatus.FROZEN
685705
):
686-
687706
names_to_scheme = apply_quantization_config(
688707
model, self.quantization_config
689708
)

0 commit comments

Comments
 (0)