3838_SECS_PER_LAYER_PER_ITER = 0.12
3939
4040
def _count_parameters(config):
    """Estimate total parameter count from a transformers model config.

    Uses hidden_size, intermediate_size, num_hidden_layers, and vocab_size
    to compute a rough parameter count. Falls back to a simple
    hidden_size^2 * num_layers heuristic when fields are missing.

    Args:
        config: A transformers/modelscope model config object; fields are
            read with getattr so partial configs are tolerated.

    Returns:
        Estimated parameter count (int), or None when hidden_size or
        num_hidden_layers is absent.
    """
    hidden = getattr(config, "hidden_size", None)
    num_layers = getattr(config, "num_hidden_layers", None)
    if hidden is None or num_layers is None:
        return None

    intermediate = getattr(config, "intermediate_size", None)
    vocab_size = getattr(config, "vocab_size", None)
    num_heads = getattr(config, "num_attention_heads", None)
    # Configs without grouped-query attention omit num_key_value_heads;
    # fall back to full multi-head attention.
    num_kv_heads = getattr(config, "num_key_value_heads", num_heads)

    attn_params = _count_attention_params(hidden, num_heads, num_kv_heads)
    ffn_params = _count_ffn_params(hidden, intermediate)
    layer_params = attn_params + ffn_params + 2 * hidden  # 2 layer norms

    total = num_layers * layer_params
    total += _count_embedding_params(config, hidden, vocab_size)
    return total
67+ def _count_attention_params (hidden , num_heads , num_kv_heads ):
68+ """Count attention layer parameters (Q, K, V, O projections)."""
69+ head_dim = hidden // num_heads if num_heads else hidden
70+ kv_dim = num_kv_heads * head_dim if num_kv_heads else hidden
71+ return hidden * hidden + 2 * hidden * kv_dim + hidden * hidden
72+
73+
74+ def _count_ffn_params (hidden , intermediate ):
75+ """Count FFN layer parameters."""
76+ if intermediate is not None :
77+ return 3 * hidden * intermediate # gate + up + down
78+ return 4 * hidden * hidden # classic 4x expansion
79+
80+
81+ def _count_embedding_params (config , hidden , vocab_size ):
82+ """Count embedding and LM head parameters."""
83+ if vocab_size is None :
84+ return 0
85+ embedding_params = vocab_size * hidden
86+ if getattr (config , "tie_word_embeddings" , True ):
87+ return embedding_params
88+ return 2 * embedding_params
8889
8990
9091def _format_bytes (num_bytes ):
@@ -166,52 +167,81 @@ def estimate_time(num_layers, iters, nsamples, batch_size):
166167 return total_seconds
167168
168169
169- def dry_run_estimate (model_name , scheme_bits , group_size , model_dtype = "float16" ,
170- batch_size = 8 , seqlen = 2048 , nsamples = 128 , iters = 200 ,
171- trust_remote_code = True , platform = "hf" ):
170+ _DRY_RUN_DEFAULTS = {
171+ "model_dtype" : "float16" ,
172+ "batch_size" : 8 ,
173+ "seqlen" : 2048 ,
174+ "nsamples" : 128 ,
175+ "iters" : 200 ,
176+ "trust_remote_code" : True ,
177+ "platform" : "hf" ,
178+ }
179+
180+
def dry_run_estimate(model_name, scheme_bits, group_size, **kwargs):
    """Run a dry-run estimation and return a dict of estimates.

    Args:
        model_name: HuggingFace model name or local path.
        scheme_bits: Target quantization bit width (e.g. 4 for W4A16).
        group_size: Quantization group size.
        **kwargs: Optional overrides - model_dtype, batch_size, seqlen,
            nsamples, iters, trust_remote_code, platform. Anything not
            supplied falls back to _DRY_RUN_DEFAULTS.

    Returns:
        dict with keys: param_count, peak_vram_bytes, output_size_bytes,
        estimated_time_secs, and their formatted string versions; None when
        the parameter count cannot be estimated from the config.
    """
    opts = {**_DRY_RUN_DEFAULTS, **kwargs}
    config = _load_model_config(model_name, opts)

    param_count = _count_parameters(config)
    if param_count is None:
        logger.warning("Could not estimate parameter count from model config.")
        return None

    return _build_estimate_result(
        model_name, scheme_bits, group_size, param_count, config, opts
    )
def _load_model_config(model_name, opts):
    """Load the model config from the platform selected in *opts*.

    Args:
        model_name: Model name or local path.
        opts: Resolved option dict (see _DRY_RUN_DEFAULTS) supplying
            "platform" and "trust_remote_code".
    """
    auto_config = _load_auto_config(opts["platform"])
    return auto_config.from_pretrained(
        model_name, trust_remote_code=opts["trust_remote_code"]
    )
210216
217+
def _build_estimate_result(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    model_name, scheme_bits, group_size, param_count, config, opts
):
    """Build the estimation result dictionary.

    Combines the VRAM, output-size, and tuning-time estimates with the
    input settings into one flat dict for reporting. Missing config fields
    fall back to common defaults (hidden_size=4096, num_hidden_layers=32).
    """
    hidden_size = getattr(config, "hidden_size", 4096)
    num_layers = getattr(config, "num_hidden_layers", 32)
    dtype_bytes = DTYPE_BYTES.get(opts["model_dtype"], 2)

    peak_vram = estimate_vram(
        param_count, dtype_bytes, opts["batch_size"], opts["seqlen"], hidden_size
    )
    output_size = estimate_output_size(param_count, scheme_bits, group_size)
    est_time = estimate_time(
        num_layers, opts["iters"], opts["nsamples"], opts["batch_size"]
    )

    # Human-readable count: billions at/above 1e9, otherwise millions.
    if param_count >= 1e9:
        param_str = f"{param_count / 1e9:.2f}B"
    else:
        param_str = f"{param_count / 1e6:.1f}M"

    return {
        "model_name": model_name,
        "param_count": param_count,
        "param_count_str": param_str,
        "peak_vram_bytes": peak_vram,
        "peak_vram_str": _format_bytes(peak_vram),
        "output_size_bytes": output_size,
        "output_size_str": _format_bytes(output_size),
        "estimated_time_secs": est_time,
        "estimated_time_str": _format_time(est_time),
        "scheme_bits": scheme_bits,
        "group_size": group_size,
        "num_layers": num_layers,
        **{k: opts[k] for k in (
            "model_dtype", "batch_size", "seqlen", "nsamples", "iters"
        )},
    }
230258
231259
def _load_auto_config(platform):
    """Return the AutoConfig class for *platform*.

    "model_scope" selects the ModelScope backend; anything else uses
    HuggingFace transformers. Imports are deferred so the unused backend
    need not be installed.
    """
    if platform == "model_scope":
        from modelscope import AutoConfig  # pylint: disable=import-outside-toplevel
    else:
        from transformers import AutoConfig  # pylint: disable=import-outside-toplevel
    return AutoConfig
267+
268+
232269def print_dry_run_report (estimates ):
233270 """Print a formatted dry-run estimation report to stdout."""
234271 if estimates is None :
0 commit comments