Skip to content

Commit 8ce019d

Browse files
mvanhorn and claude committed
style: fix pylint warnings in estimation module
Refactor _count_parameters into smaller helpers to reduce local variable count. Convert dry_run_estimate to use **kwargs and extract helpers for config loading and result building. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
1 parent 860303b commit 8ce019d

File tree

1 file changed

+96
-59
lines changed

1 file changed

+96
-59
lines changed

auto_round/estimation.py

Lines changed: 96 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -38,53 +38,54 @@
3838
_SECS_PER_LAYER_PER_ITER = 0.12
3939

4040

41-
def _count_parameters(config):
41+
def _count_parameters(config): # pylint: disable=too-many-locals
4242
"""Estimate total parameter count from a transformers model config.
4343
4444
Uses hidden_size, intermediate_size, num_hidden_layers, and vocab_size
4545
to compute a rough parameter count. Falls back to a simple
4646
hidden_size^2 * num_layers heuristic when fields are missing.
4747
"""
4848
hidden = getattr(config, "hidden_size", None)
49-
intermediate = getattr(config, "intermediate_size", None)
5049
num_layers = getattr(config, "num_hidden_layers", None)
51-
vocab_size = getattr(config, "vocab_size", None)
52-
num_attention_heads = getattr(config, "num_attention_heads", None)
53-
num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)
54-
5550
if hidden is None or num_layers is None:
5651
return None
5752

58-
# Attention: Q, K, V projections + output projection
59-
head_dim = hidden // num_attention_heads if num_attention_heads else hidden
60-
q_params = hidden * hidden # Q projection
61-
k_params = hidden * (num_key_value_heads * head_dim if num_key_value_heads else hidden)
62-
v_params = k_params
63-
o_params = hidden * hidden # output projection
64-
attn_params = q_params + k_params + v_params + o_params
65-
66-
# FFN: gate + up + down projections (for gated architectures like LLaMA)
67-
if intermediate is not None:
68-
ffn_params = 3 * hidden * intermediate # gate_proj + up_proj + down_proj
69-
else:
70-
ffn_params = 4 * hidden * hidden # classic 4x expansion
53+
intermediate = getattr(config, "intermediate_size", None)
54+
vocab_size = getattr(config, "vocab_size", None)
55+
num_heads = getattr(config, "num_attention_heads", None)
56+
num_kv_heads = getattr(config, "num_key_value_heads", num_heads)
7157

72-
# Per-layer params (attention + ffn + layer norms)
58+
attn_params = _count_attention_params(hidden, num_heads, num_kv_heads)
59+
ffn_params = _count_ffn_params(hidden, intermediate)
7360
layer_params = attn_params + ffn_params + 2 * hidden # 2 layer norms
7461

7562
total = num_layers * layer_params
63+
total += _count_embedding_params(config, hidden, vocab_size)
64+
return total
7665

77-
# Embedding + LM head
78-
if vocab_size is not None:
79-
embedding_params = vocab_size * hidden
80-
# Most models tie embeddings and lm_head
81-
tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
82-
if tie_word_embeddings:
83-
total += embedding_params
84-
else:
85-
total += 2 * embedding_params
8666

87-
return total
67+
def _count_attention_params(hidden, num_heads, num_kv_heads):
68+
"""Count attention layer parameters (Q, K, V, O projections)."""
69+
head_dim = hidden // num_heads if num_heads else hidden
70+
kv_dim = num_kv_heads * head_dim if num_kv_heads else hidden
71+
return hidden * hidden + 2 * hidden * kv_dim + hidden * hidden
72+
73+
74+
def _count_ffn_params(hidden, intermediate):
75+
"""Count FFN layer parameters."""
76+
if intermediate is not None:
77+
return 3 * hidden * intermediate # gate + up + down
78+
return 4 * hidden * hidden # classic 4x expansion
79+
80+
81+
def _count_embedding_params(config, hidden, vocab_size):
82+
"""Count embedding and LM head parameters."""
83+
if vocab_size is None:
84+
return 0
85+
embedding_params = vocab_size * hidden
86+
if getattr(config, "tie_word_embeddings", True):
87+
return embedding_params
88+
return 2 * embedding_params
8889

8990

9091
def _format_bytes(num_bytes):
@@ -166,52 +167,81 @@ def estimate_time(num_layers, iters, nsamples, batch_size):
166167
return total_seconds
167168

168169

169-
def dry_run_estimate(model_name, scheme_bits, group_size, model_dtype="float16",
170-
batch_size=8, seqlen=2048, nsamples=128, iters=200,
171-
trust_remote_code=True, platform="hf"):
170+
_DRY_RUN_DEFAULTS = {
171+
"model_dtype": "float16",
172+
"batch_size": 8,
173+
"seqlen": 2048,
174+
"nsamples": 128,
175+
"iters": 200,
176+
"trust_remote_code": True,
177+
"platform": "hf",
178+
}
179+
180+
181+
def dry_run_estimate(model_name, scheme_bits, group_size, **kwargs):
172182
"""Run a dry-run estimation and return a dict of estimates.
173183
174184
Args:
175185
model_name: HuggingFace model name or local path.
176186
scheme_bits: Target quantization bit width (e.g. 4 for W4A16).
177187
group_size: Quantization group size.
178-
model_dtype: Original model data type string.
179-
batch_size: Calibration batch size.
180-
seqlen: Calibration sequence length.
181-
nsamples: Number of calibration samples.
182-
iters: Number of tuning iterations.
183-
trust_remote_code: Whether to trust remote code when loading config.
184-
platform: Platform to load model config from.
188+
**kwargs: Optional overrides - model_dtype, batch_size, seqlen,
189+
nsamples, iters, trust_remote_code, platform.
185190
186191
Returns:
187192
dict with keys: param_count, peak_vram_bytes, output_size_bytes,
188193
estimated_time_secs, and their formatted string versions.
189194
"""
190-
if platform == "model_scope":
191-
from modelscope import AutoConfig
192-
else:
193-
from transformers import AutoConfig
194-
195-
config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
195+
opts = {**_DRY_RUN_DEFAULTS, **kwargs}
196+
config = _load_model_config(model_name, opts)
196197

197198
param_count = _count_parameters(config)
198199
if param_count is None:
199-
logger.warning("Could not estimate parameter count from model config.")
200+
logger.warning(
201+
"Could not estimate parameter count from model config."
202+
)
200203
return None
201204

202-
hidden_size = getattr(config, "hidden_size", 4096)
203-
num_layers = getattr(config, "num_hidden_layers", 32)
205+
return _build_estimate_result(
206+
model_name, scheme_bits, group_size, param_count, config, opts
207+
)
204208

205-
dtype_bytes = DTYPE_BYTES.get(model_dtype, 2)
206209

207-
peak_vram = estimate_vram(param_count, dtype_bytes, batch_size, seqlen, hidden_size)
208-
output_size = estimate_output_size(param_count, scheme_bits, group_size)
209-
est_time = estimate_time(num_layers, iters, nsamples, batch_size)
210+
def _load_model_config(model_name, opts):
211+
"""Load model config from the specified platform."""
212+
auto_config = _load_auto_config(opts["platform"])
213+
return auto_config.from_pretrained(
214+
model_name, trust_remote_code=opts["trust_remote_code"]
215+
)
210216

217+
218+
def _build_estimate_result( # pylint: disable=too-many-arguments,too-many-positional-arguments
219+
model_name, scheme_bits, group_size, param_count, config, opts
220+
):
221+
"""Build the estimation result dictionary."""
222+
hidden_size = getattr(config, "hidden_size", 4096)
223+
num_layers = getattr(config, "num_hidden_layers", 32)
224+
dtype_bytes = DTYPE_BYTES.get(opts["model_dtype"], 2)
225+
226+
peak_vram = estimate_vram(
227+
param_count, dtype_bytes,
228+
opts["batch_size"], opts["seqlen"], hidden_size
229+
)
230+
output_size = estimate_output_size(
231+
param_count, scheme_bits, group_size
232+
)
233+
est_time = estimate_time(
234+
num_layers, opts["iters"], opts["nsamples"], opts["batch_size"]
235+
)
236+
237+
param_str = (
238+
f"{param_count / 1e9:.2f}B" if param_count >= 1e9
239+
else f"{param_count / 1e6:.1f}M"
240+
)
211241
return {
212242
"model_name": model_name,
213243
"param_count": param_count,
214-
"param_count_str": f"{param_count / 1e9:.2f}B" if param_count >= 1e9 else f"{param_count / 1e6:.1f}M",
244+
"param_count_str": param_str,
215245
"peak_vram_bytes": peak_vram,
216246
"peak_vram_str": _format_bytes(peak_vram),
217247
"output_size_bytes": output_size,
@@ -220,15 +250,22 @@ def dry_run_estimate(model_name, scheme_bits, group_size, model_dtype="float16",
220250
"estimated_time_str": _format_time(est_time),
221251
"scheme_bits": scheme_bits,
222252
"group_size": group_size,
223-
"model_dtype": model_dtype,
224-
"batch_size": batch_size,
225-
"seqlen": seqlen,
226-
"nsamples": nsamples,
227-
"iters": iters,
228253
"num_layers": num_layers,
254+
**{k: opts[k] for k in (
255+
"model_dtype", "batch_size", "seqlen", "nsamples", "iters"
256+
)},
229257
}
230258

231259

260+
def _load_auto_config(platform):
261+
"""Load the appropriate AutoConfig class for the platform."""
262+
if platform == "model_scope":
263+
from modelscope import AutoConfig # pylint: disable=import-outside-toplevel
264+
else:
265+
from transformers import AutoConfig # pylint: disable=import-outside-toplevel
266+
return AutoConfig
267+
268+
232269
def print_dry_run_report(estimates):
233270
"""Print a formatted dry-run estimation report to stdout."""
234271
if estimates is None:

0 commit comments

Comments (0)