Refactor module / parameter matching logic #406


Open

wants to merge 10 commits into base: main
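This PR replaces the `expand_target_names` helper, which resolved a config's `targets`/`ignore` lists into a `Set[str]` of module names, with `match_named_modules`, which yields `(name, module)` pairs directly, so callers that need the module no longer re-resolve it from the name. A minimal sketch of the difference on a toy model (the `match_named_modules` signature is taken from the diff below; the old call is shown only for contrast):

    import torch
    from compressed_tensors.utils.match import match_named_modules

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

    # Before: expand_target_names(model=model, targets=["Linear"], ignore=[])
    # returned only a set of matching module names, e.g. {"0"}.

    # After: match_named_modules yields (name, module) pairs; callers that
    # only need names unpack and discard the module.
    for name, module in match_named_modules(
        model, targets=["Linear"], ignore=[], warn_on_fail=True
    ):
        print(name, type(module).__name__)  # e.g. "0 Linear"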
239 changes: 130 additions & 109 deletions src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -41,8 +41,8 @@
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
-from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import is_module_quantized
+from compressed_tensors.utils.match import match_named_modules
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -292,13 +292,15 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
             self.sparsity_compressor
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            sparse_targets = expand_target_names(
+            sparse_targets = match_named_modules(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
+
             missing_keys.update(
-                merge_names(target, "weight") for target in sparse_targets
+                merge_names(target_name, "weight")
+                for target_name, _module in sparse_targets
             )

         # Determine missing keys due to pack quantization
@@ -308,13 +310,14 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
             == CompressionFormat.pack_quantized.value
         ):
             for scheme in self.quantization_config.config_groups.values():
-                quant_targets = expand_target_names(
+                quant_targets = match_named_modules(
                     model=model,
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
                 missing_keys.update(
-                    merge_names(target, "weight") for target in quant_targets
+                    merge_names(target_name, "weight")
+                    for target_name, _module in quant_targets
                 )

         return list(missing_keys)
@@ -345,28 +348,28 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
             self.sparsity_compressor
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            sparse_targets: Set[str] = expand_target_names(
+            sparse_targets = match_named_modules(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
             unexpected_keys.update(
-                merge_names(target, param)
-                for target in sparse_targets
+                merge_names(target_name, param)
+                for target_name, _module in sparse_targets
                 for param in self.sparsity_compressor.compression_param_names
             )

         # Identify unexpected keys from quantization compression
         if self.quantization_compressor:
             for scheme in self.quantization_config.config_groups.values():
-                quant_targets: Set[str] = expand_target_names(
+                quant_targets = match_named_modules(
                     model=model,
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
                 unexpected_keys.update(
-                    merge_names(target, param)
-                    for target in quant_targets
+                    merge_names(target_name, param)
+                    for target_name, _module in quant_targets
                     for param in self.quantization_compressor.compression_param_names
                     if param != "weight"
                 )
@@ -383,58 +386,65 @@ def compress_model(self, model: Module):
         :param model: model containing parameters to compress
         """
         module_to_scheme = map_module_to_scheme(model)
-        sparse_compression_targets: Set[str] = expand_target_names(
-            model=model,
-            targets=self.sparsity_config.targets if self.sparsity_config else [],
-            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
-        )
-
-        for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
-
-            if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                module_device = get_execution_device(module)
-                is_meta = module_device.type == "meta"
-
-                exec_device = "meta" if is_meta else "cpu"
-                onloading_device = "meta" if is_meta else module_device
-
-                # in the future, support compression on same device
-                with align_module_device(module, execution_device=exec_device):
-                    state_dict = {
-                        f"{prefix}.{name}": param
-                        for name, param in module.named_parameters(recurse=False)
-                    }
-
-                    # quantization first
-                    if prefix in module_to_scheme:
-                        state_dict = self.quantization_compressor.compress(
-                            state_dict,
-                            names_to_scheme=module_to_scheme,
-                            show_progress=False,
-                            compression_device=exec_device,
-                        )
+        sparse_compression_targets = [
+            module_name
+            for module_name, _module in match_named_modules(
+                model=model,
+                targets=self.sparsity_config.targets if self.sparsity_config else [],
+                ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+            )
+        ]
+        for prefix, module in tqdm(
+            match_named_modules(
+                model,
+                [*sparse_compression_targets, *module_to_scheme.keys()],
+                warn_on_fail=True,
+            ),
+            desc="Compressing model",
+        ):
+            module_device = get_execution_device(module)
+            is_meta = module_device.type == "meta"
+
+            exec_device = "meta" if is_meta else "cpu"
+            onloading_device = "meta" if is_meta else module_device
+
+            # in the future, support compression on same device
+            with align_module_device(module, execution_device=exec_device):
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }
+
+                # quantization first
+                if prefix in module_to_scheme:
+                    state_dict = self.quantization_compressor.compress(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                        show_progress=False,
+                        compression_device=exec_device,
+                    )

-                    # sparsity second
-                    if prefix in sparse_compression_targets:
-                        state_dict = self.sparsity_compressor.compress(
-                            state_dict,
-                            compression_targets=sparse_compression_targets,
-                            show_progress=False,
-                        )
+                # sparsity second
+                if prefix in sparse_compression_targets:
+                    state_dict = self.sparsity_compressor.compress(
+                        state_dict,
+                        compression_targets=sparse_compression_targets,
+                        show_progress=False,
+                    )

-                # remove any existing parameters
-                offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters(recurse=False)):
-                    delete_offload_parameter(module, name)
+            # remove any existing parameters
+            offload_device = get_offloaded_device(module)
+            for name, _ in list(module.named_parameters(recurse=False)):
+                delete_offload_parameter(module, name)

-                # replace with compressed parameters
-                for name, value in state_dict.items():
-                    name = name.removeprefix(f"{prefix}.")
-                    value = value.to(onloading_device)
-                    param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param, offload_device)
+            # replace with compressed parameters
+            for name, value in state_dict.items():
+                name = name.removeprefix(f"{prefix}.")
+                value = value.to(onloading_device)
+                param = torch.nn.Parameter(value, requires_grad=False)
+                register_offload_parameter(module, name, param, offload_device)

-                module.quantization_status = QuantizationStatus.COMPRESSED
+            module.quantization_status = QuantizationStatus.COMPRESSED

         # TODO: consider sparse compression to also be compression
         if (
@@ -451,55 +461,64 @@ def decompress_model(self, model: Module):
         :param model: model containing parameters to compress
         """
         module_to_scheme = map_module_to_scheme(model)
-        sparse_compression_targets: Set[str] = expand_target_names(
-            model=model,
-            targets=self.sparsity_config.targets if self.sparsity_config else [],
-            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
-        )
-
-        for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
-            if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                # in the future, support decompression on same device
-                with align_module_device(module, execution_device="cpu"):
-                    state_dict = {
-                        f"{prefix}.{name}": param
-                        for name, param in module.named_parameters(recurse=False)
-                    }
-
-                    # sparsity first
-                    if prefix in sparse_compression_targets:
-                        # sparse_compression_targets are automatically inferred by this fn
-                        generator = self.sparsity_compressor.decompress_from_state_dict(
-                            state_dict,
-                        )
-                        # generates (param_path, param_val)
-                        # of compressed and unused params
-                        state_dict = {key: value for key, value in generator}
-
-                    # quantization second
-                    if prefix in module_to_scheme:
-                        state_dict = (
-                            self.quantization_compressor.decompress_module_from_state_dict(
-                                prefix,
-                                state_dict,
-                                scheme=module_to_scheme[prefix],
-                            )
-                        )
+        sparse_compression_targets = [
+            module_name
+            for module_name, _module in match_named_modules(
+                model=model,
+                targets=self.sparsity_config.targets if self.sparsity_config else [],
+                ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+            )
+        ]
+
+        for prefix, module in tqdm(
+            match_named_modules(
+                model,
+                [*sparse_compression_targets, *module_to_scheme.keys()],
+                warn_on_fail=True,
+            ),
+            desc="Decompressing model",
+        ):
+            # in the future, support decompression on same device
+            with align_module_device(module, execution_device="cpu"):
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }
+
+                # sparsity first
+                if prefix in sparse_compression_targets:
+                    # sparse_compression_targets are automatically inferred by this fn
+                    generator = self.sparsity_compressor.decompress_from_state_dict(
+                        state_dict,
+                    )
+                    # generates (param_path, param_val)
+                    # of compressed and unused params
+                    state_dict = {key: value for key, value in generator}
+
+                # quantization second
+                if prefix in module_to_scheme:
+                    state_dict = (
+                        self.quantization_compressor.decompress_module_from_state_dict(
+                            prefix,
+                            state_dict,
+                            scheme=module_to_scheme[prefix],
+                        )
+                    )

-                # remove any existing parameters
-                exec_device = get_execution_device(module)
-                offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters(recurse=False)):
-                    delete_offload_parameter(module, name)
+            # remove any existing parameters
+            exec_device = get_execution_device(module)
+            offload_device = get_offloaded_device(module)
+            for name, _ in list(module.named_parameters(recurse=False)):
+                delete_offload_parameter(module, name)

-                # replace with decompressed parameters
-                for name, value in state_dict.items():
-                    name = name.removeprefix(f"{prefix}.")
-                    value = value.to(exec_device)
-                    param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param, offload_device)
+            # replace with decompressed parameters
+            for name, value in state_dict.items():
+                name = name.removeprefix(f"{prefix}.")
+                value = value.to(exec_device)
+                param = torch.nn.Parameter(value, requires_grad=False)
+                register_offload_parameter(module, name, param, offload_device)

-                module.quantization_status = QuantizationStatus.FROZEN
+            module.quantization_status = QuantizationStatus.FROZEN

     # ----- state dict compression pathways ----- #

@@ -535,11 +554,14 @@ def compress(
         )

         if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = expand_target_names(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
+            sparse_compression_targets: Set[str] = {
+                module_name
+                for module_name, _module in match_named_modules(
+                    model=model,
+                    targets=self.sparsity_config.targets,
+                    ignore=self.sparsity_config.ignore,
+                )
+            }
             state_dict = self.sparsity_compressor.compress(
                 state_dict,
                 compression_targets=sparse_compression_targets,
@@ -598,7 +620,6 @@ def decompress(self, model_path: str, model: Module):
         with override_quantization_status(
             self.quantization_config, QuantizationStatus.FROZEN
         ):
-
             names_to_scheme = apply_quantization_config(
                 model, self.quantization_config
             )
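For context, a hedged sketch of how the refactored entry points are typically driven. The checkpoint path is a placeholder, not a tested recipe; `from_pretrained`, `compress_model`, and `decompress_model` are the `ModelCompressor` APIs touched in the diff above:

    from transformers import AutoModelForCausalLM
    from compressed_tensors.compressors import ModelCompressor

    model_path = "path/to/quantized-model"  # placeholder path
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # from_pretrained returns None when the checkpoint carries no
    # compression config, so guard before using the compressor
    compressor = ModelCompressor.from_pretrained(model_path)
    if compressor is not None:
        # with this PR, both walks visit only modules matched by the sparsity
        # and quantization configs (warn_on_fail=True flags unmatched targets)
        compressor.decompress_model(model)  # unpack compressed params in place
        compressor.compress_model(model)    # re-compress module by module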