
Commit e0a498e

add _log_model_parameters and polish LN
1 parent 1b52f03 commit e0a498e

File tree

2 files changed: +149 -2 lines changed

lzero/model/unizero_model_multitask.py

Lines changed: 147 additions & 0 deletions
@@ -256,7 +256,154 @@ def __init__(
         print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer')
         print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
         print('==' * 20)
+        self._log_model_parameters(world_model_cfg.obs_type)

+    def _log_model_parameters(self, obs_type: str) -> None:
+        """
+        Overview:
+            Log a detailed parameter-count breakdown for every model component:
+            encoder, transformer backbone, prediction heads, embeddings, and
+            (when present) the decoder.
+        Arguments:
+            - obs_type (:obj:`str`): The type of observation ('vector', 'image', or 'image_memory').
+        """
+        from ding.utils import get_rank
+
+        # Only print from rank 0 to avoid duplicate logs in DDP.
+        if get_rank() != 0:
+            return
+
+        print('=' * 80)
+        print('MODEL PARAMETER STATISTICS'.center(80))
+        print('=' * 80)
+
+        # --- Total model parameters ---
+        total_params = sum(p.numel() for p in self.parameters())
+        total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        print(f'\n{"TOTAL MODEL":<40} {total_params:>15,} parameters')
+        print(f'{" └─ Trainable":<40} {total_trainable:>15,} parameters')
+        print(f'{" └─ Frozen":<40} {total_params - total_trainable:>15,} parameters')
+
+        # --- World model components ---
+        print(f'\n{"-" * 80}')
+        print(f'{"WORLD MODEL BREAKDOWN":<40}')
+        print(f'{"-" * 80}')
+
+        wm_params = sum(p.numel() for p in self.world_model.parameters())
+        wm_trainable = sum(p.numel() for p in self.world_model.parameters() if p.requires_grad)
+        print(f'{"World Model Total":<40} {wm_params:>15,} parameters')
+        print(f'{" └─ Trainable":<40} {wm_trainable:>15,} parameters ({100 * wm_trainable / wm_params:.1f}%)')
+
+        # --- Encoder ---
+        encoder_params = sum(p.numel() for p in self.tokenizer.encoder.parameters())
+        encoder_trainable = sum(p.numel() for p in self.tokenizer.encoder.parameters() if p.requires_grad)
+        print(f'\n{"1. ENCODER (Tokenizer)":<40} {encoder_params:>15,} parameters')
+        print(f'{" └─ Trainable":<40} {encoder_trainable:>15,} parameters ({100 * encoder_trainable / encoder_params:.1f}%)')
+
+        # --- Transformer backbone ---
+        transformer_params = sum(p.numel() for p in self.world_model.transformer.parameters())
+        transformer_trainable = sum(p.numel() for p in self.world_model.transformer.parameters() if p.requires_grad)
+        print(f'\n{"2. TRANSFORMER BACKBONE":<40} {transformer_params:>15,} parameters')
+        print(f'{" └─ Trainable":<40} {transformer_trainable:>15,} parameters ({100 * transformer_trainable / transformer_params:.1f}%)')
+
+        # --- Prediction heads (detailed breakdown) ---
+        print(f'\n{"3. PREDICTION HEADS":<40}')
+
+        # Access head_dict from the world model.
+        if hasattr(self.world_model, 'head_dict'):
+            head_dict = self.world_model.head_dict
+
+            # Total parameters across all heads.
+            total_heads_params = sum(p.numel() for module in head_dict.values() for p in module.parameters())
+            total_heads_trainable = sum(p.numel() for module in head_dict.values() for p in module.parameters() if p.requires_grad)
+            print(f'{" Total (All Heads)":<40} {total_heads_params:>15,} parameters')
+            print(f'{" └─ Trainable":<40} {total_heads_trainable:>15,} parameters ({100 * total_heads_trainable / total_heads_params:.1f}%)')
+
+            # Breakdown by head type.
+            head_names_map = {
+                'head_policy_multi_task': 'Policy Head',
+                'head_value_multi_task': 'Value Head',
+                'head_rewards_multi_task': 'Reward Head',
+                'head_observations_multi_task': 'Next Latent (Obs) Head',
+            }
+
+            print(f'\n{" Breakdown by Head Type:":<40}')
+            for head_key, head_name in head_names_map.items():
+                if head_key in head_dict:
+                    head_module = head_dict[head_key]
+                    head_params = sum(p.numel() for p in head_module.parameters())
+                    head_trainable = sum(p.numel() for p in head_module.parameters() if p.requires_grad)
+
+                    # Count the number of task-specific heads (for nn.ModuleList).
+                    if isinstance(head_module, nn.ModuleList):
+                        num_heads = len(head_module)
+                        params_per_head = head_params // num_heads if num_heads > 0 else 0
+                        print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters')
+                        print(f'{" └─ " + f"{num_heads} task-specific heads":<38} {params_per_head:>15,} params/head')
+                    else:
+                        print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters')
+                        print(f'{" └─ Shared across tasks":<38}')
+
+        # --- Positional & task embeddings ---
+        print(f'\n{"4. EMBEDDINGS":<40}')
+
+        if hasattr(self.world_model, 'pos_emb'):
+            pos_emb_params = sum(p.numel() for p in self.world_model.pos_emb.parameters())
+            pos_emb_trainable = sum(p.numel() for p in self.world_model.pos_emb.parameters() if p.requires_grad)
+            print(f'{" ├─ Positional Embedding":<40} {pos_emb_params:>15,} parameters')
+            if pos_emb_trainable == 0:
+                print(f'{" └─ (Frozen)":<40}')
+
+        if hasattr(self.world_model, 'task_emb') and self.world_model.task_emb is not None:
+            task_emb_params = sum(p.numel() for p in self.world_model.task_emb.parameters())
+            task_emb_trainable = sum(p.numel() for p in self.world_model.task_emb.parameters() if p.requires_grad)
+            print(f'{" ├─ Task Embedding":<40} {task_emb_params:>15,} parameters')
+            print(f'{" └─ Trainable":<40} {task_emb_trainable:>15,} parameters')
+
+        if hasattr(self.world_model, 'act_embedding_table'):
+            act_emb_params = sum(p.numel() for p in self.world_model.act_embedding_table.parameters())
+            act_emb_trainable = sum(p.numel() for p in self.world_model.act_embedding_table.parameters() if p.requires_grad)
+            print(f'{" └─ Action Embedding":<40} {act_emb_params:>15,} parameters')
+            print(f'{" └─ Trainable":<40} {act_emb_trainable:>15,} parameters')
+
+        # --- Decoder (if applicable) ---
+        if obs_type in ['vector', 'image_memory'] and self.tokenizer.decoder_network is not None:
+            print(f'\n{"5. DECODER":<40}')
+            decoder_params = sum(p.numel() for p in self.tokenizer.decoder_network.parameters())
+            decoder_trainable = sum(p.numel() for p in self.tokenizer.decoder_network.parameters() if p.requires_grad)
+            print(f'{" Decoder Network":<40} {decoder_params:>15,} parameters')
+            print(f'{" └─ Trainable":<40} {decoder_trainable:>15,} parameters')
+
+            if obs_type == 'image_memory' and hasattr(self.tokenizer, 'lpips'):
+                lpips_params = sum(p.numel() for p in self.tokenizer.lpips.parameters())
+                print(f'{" LPIPS Loss Network":<40} {lpips_params:>15,} parameters')
+
+                # World-model parameters excluding the decoder and the LPIPS loss network.
+                params_without_decoder = wm_params - decoder_params - lpips_params
+                print(f'\n{" World Model (exc. Decoder & LPIPS)":<40} {params_without_decoder:>15,} parameters')
+
+        # --- Summary table ---
+        print(f'\n{"=" * 80}')
+        print(f'{"SUMMARY":<40}')
+        print(f'{"=" * 80}')
+        print(f'{"Component":<30} {"Total Params":>15} {"Trainable":>15} {"% of Total":>15}')
+        print(f'{"-" * 80}')
+
+        components = [
+            ("Encoder", encoder_params, encoder_trainable),
+            ("Transformer", transformer_params, transformer_trainable),
+        ]
+
+        if hasattr(self.world_model, 'head_dict'):
+            components.append(("Prediction Heads", total_heads_params, total_heads_trainable))
+
+        for name, total, trainable in components:
+            pct = 100 * total / total_params if total_params > 0 else 0
+            print(f'{name:<30} {total:>15,} {trainable:>15,} {pct:>14.1f}%')
+
+        print(f'{"=" * 80}')
+        print(f'{"TOTAL":<30} {total_params:>15,} {total_trainable:>15,} {"100.0%":>15}')
+        print(f'{"=" * 80}\n')
     #@profile
     def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput:
         """

zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py

Lines changed: 2 additions & 2 deletions
@@ -124,7 +124,7 @@ def generate_configs(env_id_list, env_configurations, collector_env_num, n_episo
                      total_batch_size, num_layers, model_name, replay_ratio, norm_type):
     configs = []
     # ===== only for debug =====
-    exp_name_prefix = f'data_lz/data_unizero_jericho_mt_20250513/jericho_moe8_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/'
+    exp_name_prefix = f'data_scalezero/jericho_mt_moe8_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/'

     action_space_size_list = [v[0] for _, v in env_configurations.items()]
     max_steps_list = [v[1] for _, v in env_configurations.items()]
@@ -184,7 +184,7 @@ def create_env_manager():
     # Model name or path - configurable according to the predefined model paths or names
     model_name: str = 'BAAI/bge-base-en-v1.5'
     replay_ratio = 0.1
-    norm_type = 'BN'
+    norm_type = 'LN'

     collector_env_num = 4
     n_episode = 4
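
The config change flips norm_type from 'BN' to 'LN', in line with the "polish LN" part of the commit message. As a rough illustration of how such a flag is commonly consumed when building layers (a hypothetical sketch, not this repo's actual builder):

import torch.nn as nn

def build_norm(norm_type: str, num_features: int) -> nn.Module:
    # 'LN' normalizes across features per sample; 'BN' across the batch per feature.
    if norm_type == 'LN':
        return nn.LayerNorm(num_features)
    if norm_type == 'BN':
        return nn.BatchNorm1d(num_features)
    raise ValueError(f'unsupported norm_type: {norm_type}')

norm = build_norm('LN', 768)  # mirrors the norm_type = 'LN' setting above

LayerNorm is the usual choice for transformer-style models and is independent of batch statistics, which also makes it friendlier to small per-GPU batches under DDP than BatchNorm.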
