Skip to content

Commit 3c13ec1

Browse files
committed
polish(pu): polish exp-name, add prompt_length log
1 parent 3a17c82 commit 3c13ec1

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

zoo/jericho/priorzero/models/actor.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,12 +263,24 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
263263

264264
self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
265265

266-
# Calculate response length statistics
266+
# Calculate response length statistics (action tokens)
267267
response_lengths = micro_batch['action_mask'].sum(dim=1).float()
268268
avg_response_length = response_lengths.mean().item()
269269
max_response_length = response_lengths.max().item()
270270
min_response_length = response_lengths.min().item()
271271

272+
# Calculate prompt length statistics (total - action tokens)
273+
total_lengths = micro_batch['attention_mask'].sum(dim=1).float()
274+
prompt_lengths = total_lengths - response_lengths
275+
avg_prompt_length = prompt_lengths.mean().item()
276+
max_prompt_length = prompt_lengths.max().item()
277+
min_prompt_length = prompt_lengths.min().item()
278+
279+
# Calculate total sequence length statistics
280+
avg_total_length = total_lengths.mean().item()
281+
max_total_length = total_lengths.max().item()
282+
min_total_length = total_lengths.min().item()
283+
272284
# Calculate log_probs statistics
273285
valid_log_probs = action_log_probs[micro_batch['action_mask'] > 0]
274286
avg_log_prob = valid_log_probs.mean().item() if valid_log_probs.numel() > 0 else 0.0
@@ -288,10 +300,18 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
288300
# "approx_kl": approx_kl.detach().float().mean().item(),
289301
"cur_old_kl": approx_kl.detach().float().mean().item(),
290302
"iter": self.train_iter,
291-
# Response length statistics
303+
# Response length statistics (action tokens)
292304
"response_length_avg": avg_response_length,
293305
"response_length_max": max_response_length,
294306
"response_length_min": min_response_length,
307+
# Prompt length statistics (context tokens)
308+
"prompt_length_avg": avg_prompt_length,
309+
"prompt_length_max": max_prompt_length,
310+
"prompt_length_min": min_prompt_length,
311+
# Total sequence length statistics
312+
"total_length_avg": avg_total_length,
313+
"total_length_max": max_total_length,
314+
"total_length_min": min_total_length,
295315
# Log prob and ratio statistics
296316
"log_prob_avg": avg_log_prob,
297317
"ratio_avg": avg_ratio,

zoo/jericho/priorzero/priorzero_config.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,40 @@ def get_priorzero_config(
412412
print(f" - Tensor Parallel Size: {llm_config.vllm_tensor_parallel_size}")
413413
print(f" - GPU Memory Utilization: {llm_config.gpu_memory_utilization}")
414414

415+
# Auto-generate exp_name with key configuration info if not provided
416+
if exp_name is None:
417+
# Extract key configuration parameters
418+
adv_type = llm_config.advantage_type
419+
adv_type_short = {
420+
'advantage': 'adv',
421+
'target_reward': 'tgt-rew',
422+
'advantage_batch_norm': 'adv-bn',
423+
'advantage_running_norm': 'adv-rn',
424+
}.get(adv_type, adv_type)
425+
426+
# Prior temperature schedule info
427+
prior_temp_cfg = llm_config.prior_temp_schedule
428+
if prior_temp_cfg.enable:
429+
prior_temp_str = f"pt-{prior_temp_cfg.schedule_type[:3]}-{prior_temp_cfg.init_temperature:.1f}to{prior_temp_cfg.final_temperature:.1f}"
430+
else:
431+
prior_temp_str = "pt-off"
432+
433+
# CoT info
434+
cot_str = "cot" if use_cot else "nocot"
435+
436+
# Format reward info
437+
fmt_rew_str = "fmt" if llm_config.reward_func.format_reward else "nofmt"
438+
439+
# Build exp_name
440+
exp_name = (
441+
f"data_priorzero/pz_{env_id}_{model_key}_"
442+
f"{cot_str}_{adv_type_short}_{prior_temp_str}_{fmt_rew_str}_seed{seed}"
443+
)
444+
445+
# Update config with generated exp_name
446+
main_config.exp_name = exp_name
447+
print(f"\n[Config] Auto-generated exp_name: {exp_name}\n")
448+
415449
return main_config, create_config, llm_config
416450

417451

zoo/jericho/priorzero/priorzero_entry_sync_ddp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def main():
350350
print(f"Quick Test: {args.quick_test}")
351351
print(f"{'='*80}\n")
352352

353-
# use_cot = True
353+
# use_cot = True
354354
if args.quick_test:
355355
logger.info("Using quick test configuration")
356356
main_cfg, create_cfg, llm_cfg = get_priorzero_debug_config(
@@ -359,9 +359,11 @@ def main():
359359
model_key=model_key,
360360
)
361361
else:
362+
# Generate exp_name with key configuration info
363+
# This will be called after get_priorzero_config, so we'll modify it there
362364
main_cfg, create_cfg, llm_cfg = get_priorzero_config(
363365
args.env_id, args.seed, use_cot=args.use_cot,
364-
exp_name=f'data_priorzero/priorzero_ddp_ppo_{args.env_id}_use_cot_{args.use_cot}_{model_key}_with_fmtReward_seed0',
366+
exp_name=None, # Will be auto-generated with config info
365367
model_key=model_key,
366368
multi_gpu=True
367369
)

0 commit comments

Comments (0)