Skip to content

Commit 56a681d

Browse files
authored
Revert "Refactor module / parameter matching logic (#406)" (#429)
This reverts commit d99d38b.
1 parent d99d38b commit 56a681d

File tree

5 files changed

+355
-314
lines changed

5 files changed

+355
-314
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 138 additions & 157 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@
4242
apply_quantization_config,
4343
load_pretrained_quantization_parameters,
4444
)
45+
from compressed_tensors.quantization.lifecycle import expand_target_names
4546
from compressed_tensors.quantization.utils import is_module_quantized
4647
from compressed_tensors.transform import TransformConfig
4748
from compressed_tensors.utils import (
@@ -59,7 +60,6 @@
5960
fix_fsdp_module_name,
6061
is_compressed_tensors_config,
6162
)
62-
from compressed_tensors.utils.match import match_named_modules
6363
from torch import Tensor
6464
from torch.nn import Module
6565
from tqdm import tqdm
@@ -342,15 +342,13 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
342342
self.sparsity_compressor
343343
and self.sparsity_config.format != CompressionFormat.dense.value
344344
):
345-
sparse_targets = match_named_modules(
345+
sparse_targets = expand_target_names(
346346
model=model,
347347
targets=self.sparsity_config.targets,
348348
ignore=self.sparsity_config.ignore,
349349
)
350-
351350
missing_keys.update(
352-
merge_names(target_name, "weight")
353-
for target_name, _module in sparse_targets
351+
merge_names(target, "weight") for target in sparse_targets
354352
)
355353

356354
# Determine missing keys due to pack quantization
@@ -360,14 +358,13 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
360358
== CompressionFormat.pack_quantized.value
361359
):
362360
for scheme in self.quantization_config.config_groups.values():
363-
quant_targets = match_named_modules(
361+
quant_targets = expand_target_names(
364362
model=model,
365363
targets=scheme.targets,
366364
ignore=self.quantization_config.ignore,
367365
)
368366
missing_keys.update(
369-
merge_names(target_name, "weight")
370-
for target_name, _module in quant_targets
367+
merge_names(target, "weight") for target in quant_targets
371368
)
372369

373370
return list(missing_keys)
@@ -398,29 +395,29 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
398395
self.sparsity_compressor
399396
and self.sparsity_config.format != CompressionFormat.dense.value
400397
):
401-
sparse_targets = match_named_modules(
398+
sparse_targets: Set[str] = expand_target_names(
402399
model=model,
403400
targets=self.sparsity_config.targets,
404401
ignore=self.sparsity_config.ignore,
405402
)
406403
unexpected_keys.update(
407-
merge_names(target_name, param)
408-
for target_name, _module in sparse_targets
404+
merge_names(target, param)
405+
for target in sparse_targets
409406
for param in self.sparsity_compressor.compression_param_names
410407
)
411408

412409
# Identify unexpected keys from quantization compression
413410
if self.quantization_compressor:
414411
for scheme in self.quantization_config.config_groups.values():
415-
quant_targets = match_named_modules(
412+
quant_targets: Set[str] = expand_target_names(
416413
model=model,
417414
targets=scheme.targets,
418415
ignore=self.quantization_config.ignore,
419416
)
420417
for quant_compressor in self.quantization_compressor.values():
421418
unexpected_keys.update(
422-
merge_names(target_name, param)
423-
for target_name, _module in quant_targets
419+
merge_names(target, param)
420+
for target in quant_targets
424421
for param in quant_compressor.compression_param_names
425422
if param != "weight"
426423
)
@@ -437,79 +434,73 @@ def compress_model(self, model: Module):
437434
:param model: model containing parameters to compress
438435
"""
439436
module_to_scheme = map_module_to_scheme(model)
440-
sparse_compression_targets = [
441-
module_name
442-
for module_name, _module in match_named_modules(
443-
model=model,
444-
targets=self.sparsity_config.targets if self.sparsity_config else [],
445-
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
446-
)
447-
]
448-
for prefix, module in tqdm(
449-
match_named_modules(
450-
model,
451-
[*sparse_compression_targets, *module_to_scheme.keys()],
452-
warn_on_fail=True,
453-
),
454-
desc="Compressing model",
455-
):
456-
module_device = get_execution_device(module)
457-
is_meta = module_device.type == "meta"
458-
459-
exec_device = "meta" if is_meta else "cpu"
460-
onloading_device = "meta" if is_meta else module_device
461-
462-
# in the future, support compression on same device
463-
with align_module_device(module, execution_device=exec_device):
464-
state_dict = {
465-
f"{prefix}.{name}": param
466-
for name, param in module.named_parameters(recurse=False)
467-
}
468-
469-
# quantization first
470-
if prefix in module_to_scheme:
471-
if (
472-
not hasattr(module.quantization_scheme, "format")
473-
or module.quantization_scheme.format is None
474-
):
475-
if len(self.compression_formats) > 1:
476-
raise ValueError(
477-
"Applying multiple compressors without defining "
478-
"per module formats is not supported "
479-
)
480-
format = self.compression_formats[0]
481-
else:
482-
format = module.quantization_scheme.format
483-
484-
quant_compressor = self.quantization_compressor.get(format)
485-
state_dict = quant_compressor.compress(
486-
state_dict,
487-
names_to_scheme=module_to_scheme,
488-
show_progress=False,
489-
compression_device=exec_device,
490-
)
437+
sparse_compression_targets: Set[str] = expand_target_names(
438+
model=model,
439+
targets=self.sparsity_config.targets if self.sparsity_config else [],
440+
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
441+
)
491442

492-
# sparsity second
493-
if prefix in sparse_compression_targets:
494-
state_dict = self.sparsity_compressor.compress(
495-
state_dict,
496-
compression_targets=sparse_compression_targets,
497-
show_progress=False,
498-
)
443+
for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
444+
445+
if prefix in module_to_scheme or prefix in sparse_compression_targets:
446+
module_device = get_execution_device(module)
447+
is_meta = module_device.type == "meta"
448+
449+
exec_device = "meta" if is_meta else "cpu"
450+
onloading_device = "meta" if is_meta else module_device
451+
452+
# in the future, support compression on same device
453+
with align_module_device(module, execution_device=exec_device):
454+
state_dict = {
455+
f"{prefix}.{name}": param
456+
for name, param in module.named_parameters(recurse=False)
457+
}
458+
459+
# quantization first
460+
if prefix in module_to_scheme:
461+
if (
462+
not hasattr(module.quantization_scheme, "format")
463+
or module.quantization_scheme.format is None
464+
):
465+
if len(self.compression_formats) > 1:
466+
raise ValueError(
467+
"Applying multiple compressors without defining "
468+
"per module formats is not supported "
469+
)
470+
format = self.compression_formats[0]
471+
else:
472+
format = module.quantization_scheme.format
473+
474+
quant_compressor = self.quantization_compressor.get(format)
475+
state_dict = quant_compressor.compress(
476+
state_dict,
477+
names_to_scheme=module_to_scheme,
478+
show_progress=False,
479+
compression_device=exec_device,
480+
)
481+
482+
# sparsity second
483+
if prefix in sparse_compression_targets:
484+
state_dict = self.sparsity_compressor.compress(
485+
state_dict,
486+
compression_targets=sparse_compression_targets,
487+
show_progress=False,
488+
)
499489

500-
# remove any existing parameters
501-
offload_device = get_offloaded_device(module)
502-
for name, _ in list(module.named_parameters(recurse=False)):
503-
delete_offload_parameter(module, name)
490+
# remove any existing parameters
491+
offload_device = get_offloaded_device(module)
492+
for name, _ in list(module.named_parameters(recurse=False)):
493+
delete_offload_parameter(module, name)
504494

505-
# replace with compressed parameters
506-
for name, value in state_dict.items():
507-
name = name.removeprefix(f"{prefix}.")
508-
value = value.to(onloading_device)
509-
param = torch.nn.Parameter(value, requires_grad=False)
510-
register_offload_parameter(module, name, param, offload_device)
495+
# replace with compressed parameters
496+
for name, value in state_dict.items():
497+
name = name.removeprefix(f"{prefix}.")
498+
value = value.to(onloading_device)
499+
param = torch.nn.Parameter(value, requires_grad=False)
500+
register_offload_parameter(module, name, param, offload_device)
501+
502+
module.quantization_status = QuantizationStatus.COMPRESSED
511503

512-
module.quantization_status = QuantizationStatus.COMPRESSED
513504
# TODO: consider sparse compression to also be compression
514505
if (
515506
self.quantization_config is not None
@@ -525,75 +516,67 @@ def decompress_model(self, model: Module):
525516
:param model: model containing parameters to compress
526517
"""
527518
module_to_scheme = map_module_to_scheme(model)
528-
sparse_compression_targets = [
529-
module_name
530-
for module_name, _module in match_named_modules(
531-
model=model,
532-
targets=self.sparsity_config.targets if self.sparsity_config else [],
533-
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
534-
)
535-
]
536-
537-
for prefix, module in tqdm(
538-
match_named_modules(
539-
model,
540-
[*sparse_compression_targets, *module_to_scheme.keys()],
541-
warn_on_fail=True,
542-
),
543-
desc="Decompressing model",
544-
):
545-
# in the future, support decompression on same device
546-
with align_module_device(module, execution_device="cpu"):
547-
state_dict = {
548-
f"{prefix}.{name}": param
549-
for name, param in module.named_parameters(recurse=False)
550-
}
551-
552-
# sparsity first
553-
if prefix in sparse_compression_targets:
554-
# sparse_compression_targets are automatically inferred by this fn
555-
generator = self.sparsity_compressor.decompress_from_state_dict(
556-
state_dict,
557-
)
558-
# generates (param_path, param_val)
559-
# of compressed and unused params
560-
state_dict = {key: value for key, value in generator}
561-
562-
# quantization second
563-
if prefix in module_to_scheme:
564-
if (
565-
not hasattr(module.quantization_scheme, "format")
566-
or module.quantization_scheme.format is None
567-
):
568-
if len(self.compression_formats) > 1:
569-
raise ValueError(
570-
"Applying multiple compressors without defining "
571-
"per module formats is not supported "
572-
)
573-
format = self.compression_formats[0]
574-
else:
575-
format = module.quantization_scheme.format
576-
quant_compressor = self.quantization_compressor.get(format)
577-
state_dict = quant_compressor.decompress_module_from_state_dict(
578-
prefix,
579-
state_dict,
580-
scheme=module_to_scheme[prefix],
581-
)
519+
sparse_compression_targets: Set[str] = expand_target_names(
520+
model=model,
521+
targets=self.sparsity_config.targets if self.sparsity_config else [],
522+
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
523+
)
524+
525+
for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
526+
if prefix in module_to_scheme or prefix in sparse_compression_targets:
527+
# in the future, support decompression on same device
528+
with align_module_device(module, execution_device="cpu"):
529+
state_dict = {
530+
f"{prefix}.{name}": param
531+
for name, param in module.named_parameters(recurse=False)
532+
}
533+
534+
# sparsity first
535+
if prefix in sparse_compression_targets:
536+
# sparse_compression_targets are automatically inferred by this fn
537+
generator = self.sparsity_compressor.decompress_from_state_dict(
538+
state_dict,
539+
)
540+
# generates (param_path, param_val)
541+
# of compressed and unused params
542+
state_dict = {key: value for key, value in generator}
543+
544+
# quantization second
545+
if prefix in module_to_scheme:
546+
547+
if (
548+
not hasattr(module.quantization_scheme, "format")
549+
or module.quantization_scheme.format is None
550+
):
551+
if len(self.compression_formats) > 1:
552+
raise ValueError(
553+
"Applying multiple compressors without defining "
554+
"per module formats is not supported "
555+
)
556+
format = self.compression_formats[0]
557+
else:
558+
format = module.quantization_scheme.format
559+
quant_compressor = self.quantization_compressor.get(format)
560+
state_dict = quant_compressor.decompress_module_from_state_dict(
561+
prefix,
562+
state_dict,
563+
scheme=module_to_scheme[prefix],
564+
)
582565

583-
# remove any existing parameters
584-
exec_device = get_execution_device(module)
585-
offload_device = get_offloaded_device(module)
586-
for name, _ in list(module.named_parameters(recurse=False)):
587-
delete_offload_parameter(module, name)
566+
# remove any existing parameters
567+
exec_device = get_execution_device(module)
568+
offload_device = get_offloaded_device(module)
569+
for name, _ in list(module.named_parameters(recurse=False)):
570+
delete_offload_parameter(module, name)
588571

589-
# replace with decompressed parameters
590-
for name, value in state_dict.items():
591-
name = name.removeprefix(f"{prefix}.")
592-
value = value.to(exec_device)
593-
param = torch.nn.Parameter(value, requires_grad=False)
594-
register_offload_parameter(module, name, param, offload_device)
572+
# replace with decompressed parameters
573+
for name, value in state_dict.items():
574+
name = name.removeprefix(f"{prefix}.")
575+
value = value.to(exec_device)
576+
param = torch.nn.Parameter(value, requires_grad=False)
577+
register_offload_parameter(module, name, param, offload_device)
595578

596-
module.quantization_status = QuantizationStatus.FROZEN
579+
module.quantization_status = QuantizationStatus.FROZEN
597580

598581
# ----- state dict compression pathways ----- #
599582

@@ -631,14 +614,11 @@ def compress(
631614
)
632615

633616
if self.sparsity_compressor is not None:
634-
sparse_compression_targets: Set[str] = {
635-
module_name
636-
for module_name, _module in match_named_modules(
637-
model=model,
638-
targets=self.sparsity_config.targets,
639-
ignore=self.sparsity_config.ignore,
640-
)
641-
}
617+
sparse_compression_targets: Set[str] = expand_target_names(
618+
model=model,
619+
targets=self.sparsity_config.targets,
620+
ignore=self.sparsity_config.ignore,
621+
)
642622
state_dict = self.sparsity_compressor.compress(
643623
state_dict,
644624
compression_targets=sparse_compression_targets,
@@ -703,6 +683,7 @@ def decompress(self, model_path: str, model: Module):
703683
with override_quantization_status(
704684
self.quantization_config, QuantizationStatus.FROZEN
705685
):
686+
706687
names_to_scheme = apply_quantization_config(
707688
model, self.quantization_config
708689
)

0 commit comments

Comments
 (0)