Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 55 additions & 35 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,8 @@ def __init__(
self.shared_cache_keys = get_shared_keys(self.model)

self.layer_config = layer_config

# should be set after loading model and set layer_config, cause some special scheme need these.
self.scheme, self.is_auto_scheme = self._parse_and_set_scheme(scheme, kwargs)

gguf_scheme_name = get_gguf_scheme(self.scheme)
# GGUF uses fp32 scale dtype as default
scale_dtype = kwargs.pop("scale_dtype", None)
if scale_dtype is None:
scale_dtype = "fp32" if gguf_scheme_name else "fp16"
self.scheme = scheme
self.scale_dtype = kwargs.pop("scale_dtype", None)

# Extra/legacy kwargs for backward compatibility
# Major version releases may pack them with extra configuration options
Expand All @@ -314,21 +307,12 @@ def __init__(
platform = "model_scope"
self.platform = platform
self.quant_lm_head = kwargs.pop("quant_lm_head", False)

self.ignore_layers = kwargs.pop("ignore_layers", "")
predefined_ignore_layers = get_predefined_ignore_layers(self.model)

if predefined_ignore_layers:
logger.info(f"Using predefined ignore_layers: {predefined_ignore_layers}")
tmp_str = ",".join(predefined_ignore_layers)
if self.ignore_layers == "":
self.ignore_layers = tmp_str
else:
self.ignore_layers += "," + tmp_str
self.supported_types = SUPPORTED_LAYER_TYPES
self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES
self.scale_dtype = convert_dtype_str2torch(scale_dtype)
self.low_cpu_mem_usage = low_cpu_mem_usage
self.block_forward = block_forward

if kwargs:
logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.")
Expand Down Expand Up @@ -360,16 +344,10 @@ def __init__(
self.device_map = device_map
if isinstance(self.device_map, str):
self.device_map = self.device_map.replace(" ", "")

self.device_list = parse_available_devices(device_map)

# Set device, must place after model loading
self.device = get_major_device(device_map)
set_non_auto_device_map(self.model, self.device_map)
self.device = get_major_device(self.device_map)

# Tuning hyperparameters
self.seed = seed
set_seed(self.seed)
self.amp = amp
self.enable_quanted_input = enable_quanted_input
self.enable_minmax_tuning = enable_minmax_tuning
Expand Down Expand Up @@ -448,7 +426,6 @@ def __init__(
if self.static_attention_dtype is not None:
logger.warning("The static attention dtype is experimental and currently has limited support.")

self._set_amp_dtype()
self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
if self.act_bits <= 8 and self.amp_dtype == torch.float16:
logger.warning("force to use bf16 to for quantization tuning when enabling activation quantization")
Expand All @@ -466,23 +443,18 @@ def __init__(

# after setting iters
self.enable_torch_compile = enable_torch_compile
self._adjust_torch_compile(enable_torch_compile)

self.block_forward = compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward
self.attention_mask = []
self.wrapper_block = wrapper_block

self._check_configs()
torch.set_printoptions(precision=3, sci_mode=True)

if isinstance(scheme, AutoScheme):
self.layer_config = self._gen_auto_scheme(model, scheme, dataset, self.device_map)

if is_hpex_available():
logger.info("habana_frameworks is available, import htcore explicitly.")
import habana_frameworks.torch.core as htcore # pylint: disable=E0401
import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401]

self.attention_mask = []

self.wrapper_block = wrapper_block
if self.enable_alg_ext:
try:
logger.warning_once("using algorithm extension for quantization.")
Expand All @@ -492,6 +464,48 @@ def __init__(
except (ImportError, ModuleNotFoundError):
logger.error("algorithm extension import error, fallback to default mode")

self._post_inited = False

def _post_init(self) -> None:
    """Run the deferred, one-time part of AutoRound initialization.

    ``quantize_and_save`` and ``save_quantized`` call this before doing any
    work. The first call resolves the quantization scheme, the scale dtype,
    the ignore-layer list, device placement, the random seed, the AMP dtype
    and torch.compile settings; every later call returns immediately because
    ``self._post_inited`` is already ``True``.

    Side effects (all on ``self``): rewrites ``scheme``, ``is_auto_scheme``,
    ``scale_dtype``, ``ignore_layers``, ``device_list`` and, depending on
    configuration, ``block_forward`` and ``layer_config``.
    """
    if self._post_inited:
        return

    # Keep the caller-supplied scheme: ``self.scheme`` is overwritten with the
    # parsed result just below, while the AutoScheme handling at the end has to
    # look at what the caller originally passed (as ``__init__`` used to do).
    raw_scheme = self.scheme

    # should be set after loading model and set layer_config, cause some special scheme need these.
    self.scheme, self.is_auto_scheme = self._parse_and_set_scheme(raw_scheme, {})

    # GGUF uses fp32 scale dtype as default.
    # Always run the conversion, including for a caller-supplied value, so that
    # ``scale_dtype`` is bound on every path and an explicit dtype is not left
    # unconverted.
    # NOTE(review): assumes convert_dtype_str2torch accepts whatever the caller
    # passed as ``scale_dtype``, exactly as the previous in-``__init__`` code did.
    scale_dtype = self.scale_dtype
    if scale_dtype is None:
        scale_dtype = "fp32" if get_gguf_scheme(self.scheme) else "fp16"
    self.scale_dtype = convert_dtype_str2torch(scale_dtype)

    # Append the model's predefined ignore layers to the user-provided,
    # comma-separated ``ignore_layers`` string.
    predefined_ignore_layers = get_predefined_ignore_layers(self.model)
    if predefined_ignore_layers:
        logger.info(f"Using predefined ignore_layers: {predefined_ignore_layers}")
        tmp_str = ",".join(predefined_ignore_layers)
        if self.ignore_layers == "":
            self.ignore_layers = tmp_str
        else:
            self.ignore_layers += "," + tmp_str

    # Set device, must place after model loading
    self._set_device(self.device_map)
    set_non_auto_device_map(self.model, self.device_map)
    self.device_list = parse_available_devices(self.device_map)

    set_seed(self.seed)
    self._set_amp_dtype()
    # May change ``self.enable_torch_compile``, so the flag is only read afterwards.
    self._adjust_torch_compile(self.enable_torch_compile)
    if self.enable_torch_compile:
        self.block_forward = compile_func(self.block_forward, self.device)

    # Generate the per-layer config from the original AutoScheme object rather
    # than from the parsed ``self.scheme``.
    if isinstance(raw_scheme, AutoScheme):
        self.layer_config = self._gen_auto_scheme(self.model, raw_scheme, self.dataset, self.device_map)

    # Only mark as done once every step above has succeeded.
    self._post_inited = True

def _gen_auto_scheme(
self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device]
) -> dict[str, dict]:
Expand Down Expand Up @@ -865,6 +879,9 @@ def quantize_and_save(
Raises:
ValueError: If an unsupported format is specified.
"""
# post init
self._post_init()

# Validate and process the specified formats
self.orig_output_dir = output_dir

Expand Down Expand Up @@ -3118,6 +3135,9 @@ def save_quantized(
Returns:
object: The compressed model object.
"""
# post init
self._post_init()

self.orig_output_dir = output_dir
if isinstance(format, str) and getattr(self, "formats", None) is None:
formats = get_formats(format, self)
Expand Down
Loading