nemo_rl/algorithms/distillation.py: 12 changes (10 additions, 2 deletions)
@@ -695,7 +695,10 @@ def distillation_train(
 print("▶ Computing teacher logprobs...", flush=True)
 with timer.time("teacher_logprob_inference"):
     teacher_topk = teacher_policy.get_topk_logits(
-        train_data, k=master_config["distillation"]["topk_logits_k"]
+        train_data,
+        k=master_config["distillation"]["topk_logits_k"],
+        timer=timer,
+        timer_tag_prefix="teacher_logprob_inference",
     )
 train_data["teacher_topk_logits"] = teacher_topk["topk_logits"]
 train_data["teacher_topk_indices"] = teacher_topk["topk_indices"]
@@ -708,7 +711,12 @@ def distillation_train(
 
 print("▶ Training policy...", flush=True)
 with timer.time("policy_training"):
-    train_results = student_policy.train(train_data, loss_fn)
+    train_results = student_policy.train(
+        train_data,
+        loss_fn,
+        timer=timer,
+        timer_tag_prefix="policy_training",
+    )
 
 is_last_step = (total_steps + 1 >= max_steps) or (
     (current_epoch + 1 == max_epochs)
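Both hunks above follow the same pattern, which the other files in this change repeat: the call site keeps its existing outer `timer.time(...)` block and additionally forwards `timer` together with a `timer_tag_prefix` into the worker call, presumably so the worker can record sub-phases under the parent tag. Below is a minimal sketch of that pattern, assuming a simple dict-backed context-manager timer; `SketchTimer`, `get_topk_logits_sketch`, and the sub-tag names are illustrative stand-ins, not the repository's actual `Timer` or worker API.

import time
from collections import defaultdict
from contextlib import contextmanager


class SketchTimer:
    """Illustrative stand-in for a timer: accumulates elapsed seconds per tag."""

    def __init__(self) -> None:
        self.totals: dict[str, float] = defaultdict(float)

    @contextmanager
    def time(self, tag: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[tag] += time.perf_counter() - start


def get_topk_logits_sketch(data, k, timer=None, timer_tag_prefix=""):
    """Hypothetical worker method showing how the forwarded kwargs could be used."""
    if timer is None:
        timer = SketchTimer()  # timing stays optional for callers that pass nothing
    with timer.time(f"{timer_tag_prefix}/prepare_batch"):
        pass  # placeholder: shard `data` and move it to the workers
    with timer.time(f"{timer_tag_prefix}/forward"):
        pass  # placeholder: run the teacher forward pass and select top-k
    return {"topk_logits": None, "topk_indices": None}


timer = SketchTimer()
with timer.time("teacher_logprob_inference"):  # outer tag, as in the diff above
    get_topk_logits_sketch(
        data=None,
        k=64,
        timer=timer,
        timer_tag_prefix="teacher_logprob_inference",
    )
print(dict(timer.totals))  # outer tag plus the prefixed sub-tags

Running the sketch reports the outer `teacher_logprob_inference` tag alongside the prefixed sub-tags, which is the kind of per-phase breakdown the forwarded prefix appears intended to enable.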
nemo_rl/algorithms/dpo.py: 2 changes (2 additions, 0 deletions)
@@ -572,6 +572,8 @@ def dpo_train(
 ## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch.
     gbs=master_config["policy"]["train_global_batch_size"] * 2,
     mbs=master_config["policy"]["train_micro_batch_size"] * 2,
+    timer=timer,
+    timer_tag_prefix="policy_training",
 )
 
 is_last_step = total_steps + 1 >= master_config["dpo"][
nemo_rl/algorithms/grpo.py: 38 changes (28 additions, 10 deletions)
@@ -1444,17 +1444,19 @@ def grpo_train(
         **extra_multimodal_data,
     }
 )
-train_data["prev_logprobs"] = policy.get_logprobs(logprob_data)[
-    "logprobs"
-]
+train_data["prev_logprobs"] = policy.get_logprobs(
+    logprob_data, timer=timer, timer_tag_prefix="policy_logprobs"
+)["logprobs"]
 
 if not master_config["grpo"].get(
     "skip_reference_policy_logprobs_calculation"
 ):
     train_data["reference_policy_logprobs"] = (
-        policy.get_reference_policy_logprobs(logprob_data)[
-            "reference_logprobs"
-        ]
+        policy.get_reference_policy_logprobs(
+            logprob_data,
+            timer=timer,
+            timer_tag_prefix="reference_logprobs",
+        )["reference_logprobs"]
     )
 
 del logprob_data
@@ -1468,7 +1470,12 @@ def grpo_train(
 
 print("▶ Training policy...", flush=True)
 with timer.time("policy_training"):
-    train_results = policy.train(train_data, loss_fn)
+    train_results = policy.train(
+        train_data,
+        loss_fn,
+        timer=timer,
+        timer_tag_prefix="policy_training",
+    )
 
 # Recompute KV scales after policy training if needed
 if sync_kv_scales:
@@ -2429,9 +2436,15 @@ def async_grpo_train(
 
 print("▶ Computing logprobs...")
 with timer.time("policy_and_reference_logprobs"):
-    fprop_logprobs = policy.get_logprobs(train_data)["logprobs"]
+    fprop_logprobs = policy.get_logprobs(
+        train_data,
+        timer=timer,
+        timer_tag_prefix="policy_logprob_inference",
+    )["logprobs"]
     reference_logprobs = policy.get_reference_policy_logprobs(
-        train_data
+        train_data,
+        timer=timer,
+        timer_tag_prefix="reference_logprob_inference",
     )["reference_logprobs"]
     train_data["prev_logprobs"] = fprop_logprobs
     train_data["reference_policy_logprobs"] = reference_logprobs
@@ -2443,7 +2456,12 @@ def async_grpo_train(
 
 print("▶ Training policy...")
 with timer.time("policy_training"):
-    train_results = policy.train(train_data, loss_fn)
+    train_results = policy.train(
+        train_data,
+        loss_fn,
+        timer=timer,
+        timer_tag_prefix="policy_training",
+    )
 
 print("🔄 Synchronizing policy weights to trajectory collector…")
 vllm_logger_metrics = None
nemo_rl/algorithms/sft.py: 24 changes (16 additions, 8 deletions)
@@ -294,13 +294,16 @@ def validate(
 val_data = maybe_pad_last_batch(val_data, dp_size, val_mbs)
 
 ## just run model fwd
-val_results = policy.train(
-    val_data,
-    loss_fn,
-    eval_mode=True,
-    gbs=val_data.size,
-    mbs=val_mbs,
-)
+with timer.time("policy_training"):
+    val_results = policy.train(
+        val_data,
+        loss_fn,
+        eval_mode=True,
+        gbs=val_data.size,
+        mbs=val_mbs,
+        timer=timer,
+        timer_tag_prefix="policy_training",
+    )
 
 if len(val_results["all_mb_metrics"]) == 0:
     warnings.warn(
@@ -452,7 +455,12 @@ def sft_train(
 
 print("▶ Taking a training step...")
 with timer.time("policy_training"):
-    train_results = policy.train(train_data, loss_fn)
+    train_results = policy.train(
+        train_data,
+        loss_fn,
+        timer=timer,
+        timer_tag_prefix="policy_training",
+    )
 
 is_last_step = total_steps + 1 >= master_config["sft"][
     "max_num_steps"
nemo_rl/models/policy/interfaces.py: 6 changes (5 additions, 1 deletion)
@@ -52,7 +52,9 @@ class PolicyInterface(ABC):
 
 @abstractmethod
 def get_logprobs(
-    self, data: BatchedDataDict[GenerationDatumSpec]
+    self,
+    data: BatchedDataDict[GenerationDatumSpec],
+    **kwargs: Any,
 ) -> BatchedDataDict[LogprobOutputSpec]:
     """Get logprobs of actions from observations.
 
@@ -70,6 +72,7 @@ def get_reference_policy_logprobs(
     self,
     data: BatchedDataDict[GenerationDatumSpec],
     micro_batch_size: Optional[int] = None,
+    **kwargs: Any,
 ) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
     """Get logprobs of actions from observations.
 
@@ -105,6 +108,7 @@ def train(
     *,
     gbs: Optional[int] = None,
     mbs: Optional[int] = None,
+    **kwargs: Any,
 ) -> dict[str, Any]:
     """Train the policy on a global batch of data.
 
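Because the abstract methods now accept `**kwargs: Any`, existing `PolicyInterface` implementations that ignore the extra arguments remain valid, while implementations that want per-phase timing can pop the hints out of `kwargs`. Below is a sketch of how an implementation might do that; `ExamplePolicy`, `_train_step`, and the `/train_step` sub-tag are assumptions for illustration, not the repository's actual classes.

from typing import Any, Callable, Optional


class ExamplePolicy:
    """Illustrative policy implementation; not one of the repository's classes."""

    def train(
        self,
        data: Any,
        loss_fn: Callable[..., Any],
        *,
        gbs: Optional[int] = None,
        mbs: Optional[int] = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        # Timing hints are optional; pop them so they don't leak further down.
        timer = kwargs.pop("timer", None)
        tag_prefix = kwargs.pop("timer_tag_prefix", "policy_training")
        if timer is not None:
            with timer.time(f"{tag_prefix}/train_step"):
                return self._train_step(data, loss_fn, gbs=gbs, mbs=mbs)
        return self._train_step(data, loss_fn, gbs=gbs, mbs=mbs)

    def _train_step(
        self,
        data: Any,
        loss_fn: Callable[..., Any],
        *,
        gbs: Optional[int],
        mbs: Optional[int],
    ) -> dict[str, Any]:
        return {"loss": 0.0, "all_mb_metrics": []}  # placeholder metrics

Pulling the hints from `kwargs` with `pop`, rather than adding required parameters, keeps the interface change backward compatible for implementations that do not time anything.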