@@ -256,7 +256,154 @@ def __init__(
         print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer')
         print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
         print('==' * 20)
+        self._log_model_parameters(world_model_cfg.obs_type)
 
+    def _log_model_parameters(self, obs_type: str) -> None:
+        """
+        Overview:
+            Log a detailed parameter-count breakdown for every major model component:
+            encoder, transformer backbone, prediction heads, embeddings, and, when present, the decoder.
+        Arguments:
+            - obs_type (:obj:`str`): The type of observation ('vector', 'image', or 'image_memory').
+        """
+        from ding.utils import get_rank
+
+        # Only print from rank 0 to avoid duplicate logs in DDP
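+        # ding.utils.get_rank falls back to rank 0 when torch.distributed is not initialized, so single-process runs still log.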
+        if get_rank() != 0:
+            return
+
+        print('=' * 80)
+        print('MODEL PARAMETER STATISTICS'.center(80))
+        print('=' * 80)
+
+        # --- Total Model Parameters ---
+        total_params = sum(p.numel() for p in self.parameters())
+        total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        print(f'\n{"TOTAL MODEL":<40} {total_params:>15,} parameters')
+        print(f'{" └─ Trainable":<40} {total_trainable:>15,} parameters')
+        print(f'{" └─ Frozen":<40} {total_params - total_trainable:>15,} parameters')
+
+        # --- World Model Components ---
+        print(f'\n{"-" * 80}')
+        print(f'{"WORLD MODEL BREAKDOWN":<40}')
+        print(f'{"-" * 80}')
+
+        wm_params = sum(p.numel() for p in self.world_model.parameters())
+        wm_trainable = sum(p.numel() for p in self.world_model.parameters() if p.requires_grad)
+        print(f'{"World Model Total":<40} {wm_params:>15,} parameters')
+        print(f'{" └─ Trainable":<40} {wm_trainable:>15,} parameters ({100 * wm_trainable / wm_params:.1f}%)')
+
+        # --- Encoder ---
+        encoder_params = sum(p.numel() for p in self.tokenizer.encoder.parameters())
+        encoder_trainable = sum(p.numel() for p in self.tokenizer.encoder.parameters() if p.requires_grad)
+        print(f'\n{"1. ENCODER (Tokenizer)":<40} {encoder_params:>15,} parameters')
+        print(f'{" └─ Trainable":<40} {encoder_trainable:>15,} parameters ({100 * encoder_trainable / encoder_params:.1f}%)')
+
+        # --- Transformer Backbone ---
+        transformer_params = sum(p.numel() for p in self.world_model.transformer.parameters())
+        transformer_trainable = sum(p.numel() for p in self.world_model.transformer.parameters() if p.requires_grad)
+        print(f'\n{"2. TRANSFORMER BACKBONE":<40} {transformer_params:>15,} parameters')
+        print(f'{" └─ Trainable":<40} {transformer_trainable:>15,} parameters ({100 * transformer_trainable / transformer_params:.1f}%)')
+
+        # --- Prediction Heads (Detailed Breakdown) ---
+        print(f'\n{"3. PREDICTION HEADS":<40}')
+
+        # Access head_dict from world_model
+        if hasattr(self.world_model, 'head_dict'):
+            head_dict = self.world_model.head_dict
+
+            # Calculate total heads parameters
+            total_heads_params = sum(p.numel() for module in head_dict.values() for p in module.parameters())
+            total_heads_trainable = sum(p.numel() for module in head_dict.values() for p in module.parameters() if p.requires_grad)
+            print(f'{" Total (All Heads)":<40} {total_heads_params:>15,} parameters')
+            print(f'{" └─ Trainable":<40} {total_heads_trainable:>15,} parameters ({100 * total_heads_trainable / total_heads_params:.1f}%)')
+
+            # Breakdown by head type
+            head_names_map = {
+                'head_policy_multi_task': 'Policy Head',
+                'head_value_multi_task': 'Value Head',
+                'head_rewards_multi_task': 'Reward Head',
+                'head_observations_multi_task': 'Next Latent (Obs) Head'
+            }
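+            # Heads present in head_dict but not listed above still count toward the total; they are only omitted from the per-type breakdown.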
+
+            print(f'\n{" Breakdown by Head Type:":<40}')
+            for head_key, head_name in head_names_map.items():
+                if head_key in head_dict:
+                    head_module = head_dict[head_key]
+                    head_params = sum(p.numel() for p in head_module.parameters())
+                    head_trainable = sum(p.numel() for p in head_module.parameters() if p.requires_grad)
+
+                    # Count number of task-specific heads (for ModuleList)
+                    if isinstance(head_module, nn.ModuleList):
+                        num_heads = len(head_module)
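+                        # Integer average per head; this assumes the task-specific heads share the same architecture.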
+                        params_per_head = head_params // num_heads if num_heads > 0 else 0
+                        print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters')
+                        print(f'{" └─ " + f"{num_heads} task-specific heads":<38} {params_per_head:>15,} params/head')
+                    else:
+                        print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters')
+                        print(f'{" └─ Shared across tasks":<38}')
+
+        # --- Positional & Task Embeddings ---
+        print(f'\n{"4. EMBEDDINGS":<40}')
+
+        if hasattr(self.world_model, 'pos_emb'):
+            pos_emb_params = sum(p.numel() for p in self.world_model.pos_emb.parameters())
+            pos_emb_trainable = sum(p.numel() for p in self.world_model.pos_emb.parameters() if p.requires_grad)
+            print(f'{" ├─ Positional Embedding":<40} {pos_emb_params:>15,} parameters')
+            if pos_emb_trainable == 0:
+                print(f'{" └─ (Frozen)":<40}')
+
+        if hasattr(self.world_model, 'task_emb') and self.world_model.task_emb is not None:
+            task_emb_params = sum(p.numel() for p in self.world_model.task_emb.parameters())
+            task_emb_trainable = sum(p.numel() for p in self.world_model.task_emb.parameters() if p.requires_grad)
+            print(f'{" ├─ Task Embedding":<40} {task_emb_params:>15,} parameters')
+            print(f'{" └─ Trainable":<40} {task_emb_trainable:>15,} parameters')
+
+        if hasattr(self.world_model, 'act_embedding_table'):
+            act_emb_params = sum(p.numel() for p in self.world_model.act_embedding_table.parameters())
+            act_emb_trainable = sum(p.numel() for p in self.world_model.act_embedding_table.parameters() if p.requires_grad)
+            print(f'{" └─ Action Embedding":<40} {act_emb_params:>15,} parameters')
+            print(f'{" └─ Trainable":<40} {act_emb_trainable:>15,} parameters')
+        # --- Decoder (if applicable) ---
+        # Default to 0 so the exclusion below stays well-defined when these components are absent.
+        decoder_params = 0
+        lpips_params = 0
+        if obs_type in ['vector', 'image_memory'] and self.tokenizer.decoder_network is not None:
+            print(f'\n{"5. DECODER":<40}')
+            decoder_params = sum(p.numel() for p in self.tokenizer.decoder_network.parameters())
+            decoder_trainable = sum(p.numel() for p in self.tokenizer.decoder_network.parameters() if p.requires_grad)
+            print(f'{" Decoder Network":<40} {decoder_params:>15,} parameters')
+            print(f'{" └─ Trainable":<40} {decoder_trainable:>15,} parameters')
+
+        if obs_type == 'image_memory' and hasattr(self.tokenizer, 'lpips'):
+            lpips_params = sum(p.numel() for p in self.tokenizer.lpips.parameters())
+            print(f'{" LPIPS Loss Network":<40} {lpips_params:>15,} parameters')
+
+            # World model parameters excluding decoder and LPIPS
+            params_without_decoder = wm_params - decoder_params - lpips_params
+            print(f'\n{" World Model (exc. Decoder & LPIPS)":<40} {params_without_decoder:>15,} parameters')
+
+        # --- Summary Table ---
+        print(f'\n{"=" * 80}')
+        print(f'{"SUMMARY":<40}')
+        print(f'{"=" * 80}')
+        print(f'{"Component":<30} {"Total Params":>15} {"Trainable":>15} {"% of Total":>15}')
+        print(f'{"-" * 80}')
+
+        components = [
+            ("Encoder", encoder_params, encoder_trainable),
+            ("Transformer", transformer_params, transformer_trainable),
+        ]
+
+        if hasattr(self.world_model, 'head_dict'):
+            components.append(("Prediction Heads", total_heads_params, total_heads_trainable))
+
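+        # Only the major components are listed, so the percentages below are not expected to sum to 100%.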
+        for name, total, trainable in components:
+            pct = 100 * total / total_params if total_params > 0 else 0
+            print(f'{name:<30} {total:>15,} {trainable:>15,} {pct:>14.1f}%')
+
+        print(f'{"=" * 80}')
+        print(f'{"TOTAL":<30} {total_params:>15,} {total_trainable:>15,} {"100.0%":>15}')
+        print(f'{"=" * 80}\n')
+
     #@profile
     def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput:
         """