7 files changed, +27 −6 lines changed

@@ -482,6 +482,7 @@ def dpo_train(
     losses = train_results["loss"]
     metrics = {
         "loss": train_results["loss"].numpy(),
+        "grad_norm": train_results["grad_norm"].numpy(),
     }
     metrics.update(train_results["all_mb_metrics"])
     for k, v in metrics.items():
@@ -576,6 +576,7 @@ def grpo_train(
     metrics = {
         "loss": train_results["loss"].numpy(),
         "reward": rewards.numpy(),
+        "grad_norm": train_results["grad_norm"].numpy(),
     }
     metrics.update(train_results["all_mb_metrics"])
     for k, v in metrics.items():
@@ -486,6 +486,7 @@ def sft_train(
     losses = train_results["loss"]
     metrics = {
         "loss": train_results["loss"].numpy(),
+        "grad_norm": train_results["grad_norm"].numpy(),
     }
     metrics.update(train_results["all_mb_metrics"])
     for k, v in metrics.items():
@@ -369,6 +369,7 @@ def train(
             total_norm=grad_norm,
             dtype=torch.float32,
         )
+        grad_norm = torch.tensor([grad_norm])
 
         # Update parameters
         self.optimizer.step()
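The single added line in this hunk wraps a plain float in a one-element tensor. A minimal sketch of why that matters downstream, under the assumption (suggested by the other hunks) that this backend reports the clipped gradient norm as a Python float while the training loops call `.numpy()` on whatever the worker returns:

```python
import torch

# Hypothetical value: some backends hand back the total gradient norm
# as a plain Python float after clipping.
grad_norm = 1.2345

# Wrapping it in a one-element tensor gives every backend the same return
# type, so callers can uniformly do train_results["grad_norm"].numpy().
grad_norm = torch.tensor([grad_norm])
print(grad_norm.numpy())  # -> [1.2345]
```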
@@ -310,10 +310,20 @@ def train(
         all_mb_metrics.append(loss_metrics)
 
         # Clip gradients
+        grad_norm = None
         if not eval_mode:
-            torch.nn.utils.clip_grad_norm_(
-                self.model.parameters(), max_norm=self.cfg["max_grad_norm"]
-            )
+            if isinstance(self.model, FullyShardedDataParallel):
+                # when using FSDP1, use FSDP's clip_grad_norm_
+                # to ensure grad norm is being computed over all parameters
+                # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
+                grad_norm = self.model.clip_grad_norm_(
+                    max_norm=self.cfg["max_grad_norm"]
+                )
+            else:
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(), max_norm=self.cfg["max_grad_norm"]
+                )
+            grad_norm = grad_norm.cpu()
 
         # Update parameters
         self.optimizer.step()
@@ -336,6 +346,7 @@ def train(
         metrics = {
             "global_loss": global_loss.cpu(),
             "local_loss": local_loss.cpu(),
+            "grad_norm": grad_norm,
             "rank": torch.distributed.get_rank(),
             "all_mb_metrics": dict(mb_metrics),
         }
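The clipping change above amounts to the following pattern. This is a standalone sketch under assumed names (`model`, `max_grad_norm`), not the repository's exact code: with FSDP1 the parameters are sharded across ranks, so `torch.nn.utils.clip_grad_norm_` would only see the local shard, while FSDP's own `clip_grad_norm_` computes the norm over all parameters and returns it as a tensor.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel


def clip_gradients(model: torch.nn.Module, max_grad_norm: float) -> torch.Tensor:
    """Clip gradients in place and return the total grad norm as a CPU tensor."""
    if isinstance(model, FullyShardedDataParallel):
        # FSDP1 shards parameters, so use its collective clip_grad_norm_.
        grad_norm = model.clip_grad_norm_(max_norm=max_grad_norm)
    else:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=max_grad_norm
        )
    # Move to CPU so the value can be logged alongside other metrics.
    return grad_norm.cpu()
```

Returning the norm, rather than discarding it as the removed lines did, is what lets the training loops surface `grad_norm` as a metric.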
@@ -177,8 +177,10 @@ def train(
         results = self.worker_group.get_all_worker_results(futures)
 
         # Aggregate the results
-        aggregated_results = {}
-        aggregated_results["loss"] = results[0]["global_loss"]
+        aggregated_results = {
+            "loss": results[0]["global_loss"],
+            "grad_norm": results[0]["grad_norm"],
+        }
 
         # Aggregate metrics across all workers
         all_mb_metrics = defaultdict(list)
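For context, a hedged sketch of the aggregation shape this hunk implies, with hypothetical worker results: the scalar `global_loss` and `grad_norm` come from rank 0 (presumably because the total norm is identical across ranks), while the micro-batch loop at the end is an assumption about how `all_mb_metrics` is filled.

```python
from collections import defaultdict

import torch

# Hypothetical per-worker results mirroring the keys used in this diff.
results = [
    {"global_loss": torch.tensor(0.52), "grad_norm": torch.tensor(1.3),
     "all_mb_metrics": {"loss": [0.50, 0.54]}},
    {"global_loss": torch.tensor(0.52), "grad_norm": torch.tensor(1.3),
     "all_mb_metrics": {"loss": [0.51, 0.53]}},
]

# Rank 0's values stand in for the group-wide scalars.
aggregated_results = {
    "loss": results[0]["global_loss"],
    "grad_norm": results[0]["grad_norm"],
}

# Assumed aggregation: collect every worker's micro-batch metrics per key.
all_mb_metrics = defaultdict(list)
for r in results:
    for k, v in r["all_mb_metrics"].items():
        all_mb_metrics[k].extend(v)
```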
@@ -24,7 +24,11 @@
 def mock_components():
     # Create mock components
     policy = MagicMock()
-    policy.train.return_value = {"loss": torch.tensor(0.5), "all_mb_metrics": {}}
+    policy.train.return_value = {
+        "loss": torch.tensor(0.5),
+        "grad_norm": torch.tensor(1.0),
+        "all_mb_metrics": {},
+    }
 
     # Create a proper message log structure with token_ids
     mock_batch = {