Skip to content

Commit 72a6c65

Browse files
committed
Refactor usages of expand_target_names, is_target, and find_name_or_class_matches
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent ad74d32 commit 72a6c65

File tree

5 files changed: +205 additions, -289 deletions

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 130 additions & 109 deletions
Original file line number | Diff line number | Diff line change
@@ -41,8 +41,8 @@
4141
apply_quantization_config,
4242
load_pretrained_quantization_parameters,
4343
)
44-
from compressed_tensors.quantization.lifecycle import expand_target_names
4544
from compressed_tensors.quantization.utils import is_module_quantized
45+
from compressed_tensors.utils.match import match_named_modules
4646
from compressed_tensors.utils import (
4747
align_module_device,
4848
delete_offload_parameter,
@@ -292,13 +292,15 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
292292
self.sparsity_compressor
293293
and self.sparsity_config.format != CompressionFormat.dense.value
294294
):
295-
sparse_targets = expand_target_names(
295+
sparse_targets = match_named_modules(
296296
model=model,
297297
targets=self.sparsity_config.targets,
298298
ignore=self.sparsity_config.ignore,
299299
)
300+
300301
missing_keys.update(
301-
merge_names(target, "weight") for target in sparse_targets
302+
merge_names(target_name, "weight")
303+
for target_name, _module in sparse_targets
302304
)
303305

304306
# Determine missing keys due to pack quantization
@@ -308,13 +310,14 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
308310
== CompressionFormat.pack_quantized.value
309311
):
310312
for scheme in self.quantization_config.config_groups.values():
311-
quant_targets = expand_target_names(
313+
quant_targets = match_named_modules(
312314
model=model,
313315
targets=scheme.targets,
314316
ignore=self.quantization_config.ignore,
315317
)
316318
missing_keys.update(
317-
merge_names(target, "weight") for target in quant_targets
319+
merge_names(target_name, "weight")
320+
for target_name, _module in quant_targets
318321
)
319322

320323
return list(missing_keys)
@@ -345,28 +348,28 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
345348
self.sparsity_compressor
346349
and self.sparsity_config.format != CompressionFormat.dense.value
347350
):
348-
sparse_targets: Set[str] = expand_target_names(
351+
sparse_targets = match_named_modules(
349352
model=model,
350353
targets=self.sparsity_config.targets,
351354
ignore=self.sparsity_config.ignore,
352355
)
353356
unexpected_keys.update(
354-
merge_names(target, param)
355-
for target in sparse_targets
357+
merge_names(target_name, param)
358+
for target_name, _module in sparse_targets
356359
for param in self.sparsity_compressor.compression_param_names
357360
)
358361

359362
# Identify unexpected keys from quantization compression
360363
if self.quantization_compressor:
361364
for scheme in self.quantization_config.config_groups.values():
362-
quant_targets: Set[str] = expand_target_names(
365+
quant_targets = match_named_modules(
363366
model=model,
364367
targets=scheme.targets,
365368
ignore=self.quantization_config.ignore,
366369
)
367370
unexpected_keys.update(
368-
merge_names(target, param)
369-
for target in quant_targets
371+
merge_names(target_name, param)
372+
for target_name, _module in quant_targets
370373
for param in self.quantization_compressor.compression_param_names
371374
if param != "weight"
372375
)
@@ -383,58 +386,65 @@ def compress_model(self, model: Module):
383386
:param model: model containing parameters to compress
384387
"""
385388
module_to_scheme = map_module_to_scheme(model)
386-
sparse_compression_targets: Set[str] = expand_target_names(
387-
model=model,
388-
targets=self.sparsity_config.targets if self.sparsity_config else [],
389-
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
390-
)
391-
392-
for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
393-
394-
if prefix in module_to_scheme or prefix in sparse_compression_targets:
395-
module_device = get_execution_device(module)
396-
is_meta = module_device.type == "meta"
397-
398-
exec_device = "meta" if is_meta else "cpu"
399-
onloading_device = "meta" if is_meta else module_device
400-
401-
# in the future, support compression on same device
402-
with align_module_device(module, execution_device=exec_device):
403-
state_dict = {
404-
f"{prefix}.{name}": param
405-
for name, param in module.named_parameters(recurse=False)
406-
}
407-
408-
# quantization first
409-
if prefix in module_to_scheme:
410-
state_dict = self.quantization_compressor.compress(
411-
state_dict,
412-
names_to_scheme=module_to_scheme,
413-
show_progress=False,
414-
compression_device=exec_device,
415-
)
389+
sparse_compression_targets = [
390+
module_name
391+
for module_name, _module in match_named_modules(
392+
model=model,
393+
targets=self.sparsity_config.targets if self.sparsity_config else [],
394+
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
395+
)
396+
]
397+
for prefix, module in tqdm(
398+
match_named_modules(
399+
model,
400+
[*sparse_compression_targets, *module_to_scheme.keys()],
401+
warn_on_fail=True,
402+
),
403+
desc="Compressing model",
404+
):
405+
module_device = get_execution_device(module)
406+
is_meta = module_device.type == "meta"
407+
408+
exec_device = "meta" if is_meta else "cpu"
409+
onloading_device = "meta" if is_meta else module_device
410+
411+
# in the future, support compression on same device
412+
with align_module_device(module, execution_device=exec_device):
413+
state_dict = {
414+
f"{prefix}.{name}": param
415+
for name, param in module.named_parameters(recurse=False)
416+
}
417+
418+
# quantization first
419+
if prefix in module_to_scheme:
420+
state_dict = self.quantization_compressor.compress(
421+
state_dict,
422+
names_to_scheme=module_to_scheme,
423+
show_progress=False,
424+
compression_device=exec_device,
425+
)
416426

417-
# sparsity second
418-
if prefix in sparse_compression_targets:
419-
state_dict = self.sparsity_compressor.compress(
420-
state_dict,
421-
compression_targets=sparse_compression_targets,
422-
show_progress=False,
423-
)
427+
# sparsity second
428+
if prefix in sparse_compression_targets:
429+
state_dict = self.sparsity_compressor.compress(
430+
state_dict,
431+
compression_targets=sparse_compression_targets,
432+
show_progress=False,
433+
)
424434

425-
# remove any existing parameters
426-
offload_device = get_offloaded_device(module)
427-
for name, _ in list(module.named_parameters(recurse=False)):
428-
delete_offload_parameter(module, name)
435+
# remove any existing parameters
436+
offload_device = get_offloaded_device(module)
437+
for name, _ in list(module.named_parameters(recurse=False)):
438+
delete_offload_parameter(module, name)
429439

430-
# replace with compressed parameters
431-
for name, value in state_dict.items():
432-
name = name.removeprefix(f"{prefix}.")
433-
value = value.to(onloading_device)
434-
param = torch.nn.Parameter(value, requires_grad=False)
435-
register_offload_parameter(module, name, param, offload_device)
440+
# replace with compressed parameters
441+
for name, value in state_dict.items():
442+
name = name.removeprefix(f"{prefix}.")
443+
value = value.to(onloading_device)
444+
param = torch.nn.Parameter(value, requires_grad=False)
445+
register_offload_parameter(module, name, param, offload_device)
436446

437-
module.quantization_status = QuantizationStatus.COMPRESSED
447+
module.quantization_status = QuantizationStatus.COMPRESSED
438448

439449
# TODO: consider sparse compression to also be compression
440450
if (
@@ -451,55 +461,64 @@ def decompress_model(self, model: Module):
451461
:param model: model containing parameters to compress
452462
"""
453463
module_to_scheme = map_module_to_scheme(model)
454-
sparse_compression_targets: Set[str] = expand_target_names(
455-
model=model,
456-
targets=self.sparsity_config.targets if self.sparsity_config else [],
457-
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
458-
)
459-
460-
for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
461-
if prefix in module_to_scheme or prefix in sparse_compression_targets:
462-
# in the future, support decompression on same device
463-
with align_module_device(module, execution_device="cpu"):
464-
state_dict = {
465-
f"{prefix}.{name}": param
466-
for name, param in module.named_parameters(recurse=False)
467-
}
468-
469-
# sparsity first
470-
if prefix in sparse_compression_targets:
471-
# sparse_compression_targets are automatically inferred by this fn
472-
generator = self.sparsity_compressor.decompress_from_state_dict(
464+
sparse_compression_targets = [
465+
module_name
466+
for module_name, _module in match_named_modules(
467+
model=model,
468+
targets=self.sparsity_config.targets if self.sparsity_config else [],
469+
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
470+
)
471+
]
472+
473+
for prefix, module in tqdm(
474+
match_named_modules(
475+
model,
476+
[*sparse_compression_targets, *module_to_scheme.keys()],
477+
warn_on_fail=True,
478+
),
479+
desc="Decompressing model",
480+
):
481+
# in the future, support decompression on same device
482+
with align_module_device(module, execution_device="cpu"):
483+
state_dict = {
484+
f"{prefix}.{name}": param
485+
for name, param in module.named_parameters(recurse=False)
486+
}
487+
488+
# sparsity first
489+
if prefix in sparse_compression_targets:
490+
# sparse_compression_targets are automatically inferred by this fn
491+
generator = self.sparsity_compressor.decompress_from_state_dict(
492+
state_dict,
493+
)
494+
# generates (param_path, param_val)
495+
# of compressed and unused params
496+
state_dict = {key: value for key, value in generator}
497+
498+
# quantization second
499+
if prefix in module_to_scheme:
500+
state_dict = (
501+
self.quantization_compressor.decompress_module_from_state_dict(
502+
prefix,
473503
state_dict,
504+
scheme=module_to_scheme[prefix],
474505
)
475-
# generates (param_path, param_val)
476-
# of compressed and unused params
477-
state_dict = {key: value for key, value in generator}
478-
479-
# quantization second
480-
if prefix in module_to_scheme:
481-
state_dict = (
482-
self.quantization_compressor.decompress_module_from_state_dict(
483-
prefix,
484-
state_dict,
485-
scheme=module_to_scheme[prefix],
486-
)
487-
)
506+
)
488507

489-
# remove any existing parameters
490-
exec_device = get_execution_device(module)
491-
offload_device = get_offloaded_device(module)
492-
for name, _ in list(module.named_parameters(recurse=False)):
493-
delete_offload_parameter(module, name)
508+
# remove any existing parameters
509+
exec_device = get_execution_device(module)
510+
offload_device = get_offloaded_device(module)
511+
for name, _ in list(module.named_parameters(recurse=False)):
512+
delete_offload_parameter(module, name)
494513

495-
# replace with decompressed parameters
496-
for name, value in state_dict.items():
497-
name = name.removeprefix(f"{prefix}.")
498-
value = value.to(exec_device)
499-
param = torch.nn.Parameter(value, requires_grad=False)
500-
register_offload_parameter(module, name, param, offload_device)
514+
# replace with decompressed parameters
515+
for name, value in state_dict.items():
516+
name = name.removeprefix(f"{prefix}.")
517+
value = value.to(exec_device)
518+
param = torch.nn.Parameter(value, requires_grad=False)
519+
register_offload_parameter(module, name, param, offload_device)
501520

502-
module.quantization_status = QuantizationStatus.FROZEN
521+
module.quantization_status = QuantizationStatus.FROZEN
503522

504523
# ----- state dict compression pathways ----- #
505524

@@ -535,11 +554,14 @@ def compress(
535554
)
536555

537556
if self.sparsity_compressor is not None:
538-
sparse_compression_targets: Set[str] = expand_target_names(
539-
model=model,
540-
targets=self.sparsity_config.targets,
541-
ignore=self.sparsity_config.ignore,
542-
)
557+
sparse_compression_targets: Set[str] = {
558+
module_name
559+
for module_name, _module in match_named_modules(
560+
model=model,
561+
targets=self.sparsity_config.targets,
562+
ignore=self.sparsity_config.ignore,
563+
)
564+
}
543565
state_dict = self.sparsity_compressor.compress(
544566
state_dict,
545567
compression_targets=sparse_compression_targets,
@@ -598,7 +620,6 @@ def decompress(self, model_path: str, model: Module):
598620
with override_quantization_status(
599621
self.quantization_config, QuantizationStatus.FROZEN
600622
):
601-
602623
names_to_scheme = apply_quantization_config(
603624
model, self.quantization_config
604625
)

0 commit comments

Comments (0)