diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 1dbef0b76..c7c5c0ce1 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -30,7 +30,8 @@ from forge.data.rewards import MathReward, ThinkingReward from forge.data_models.completion import Completion from forge.observability.metric_actors import get_or_create_metric_logger -from forge.observability.metrics import record_metric, Reduce + +# from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import LauncherConfig, ProvisionerConfig @@ -161,36 +162,36 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl reward_fn_name = getattr( reward_fn, "__name__", reward_fn.__class__.__name__ ) - # per function reward - record_metric( - f"reward/evaluate_response/sum_{reward_fn_name}_reward", - reward, - Reduce.SUM, - ) - record_metric( - f"reward/evaluate_response/avg_{reward_fn_name}_reward", - reward, - Reduce.MEAN, - ) - record_metric( - f"reward/evaluate_response/std_{reward_fn_name}_reward", - reward, - Reduce.STD, - ) - - # avg total reward - record_metric( - "reward/evaluate_response/avg_total_reward", - reward, - Reduce.MEAN, - ) - - # count fn calls - record_metric( - f"reward/evaluate_response/count_{reward_fn_name}_calls", - 1, - Reduce.SUM, - ) + # # per function reward + # record_metric( + # f"reward/evaluate_response/sum_{reward_fn_name}_reward", + # reward, + # Reduce.SUM, + # ) + # record_metric( + # f"reward/evaluate_response/avg_{reward_fn_name}_reward", + # reward, + # Reduce.MEAN, + # ) + # record_metric( + # f"reward/evaluate_response/std_{reward_fn_name}_reward", + # reward, + # Reduce.STD, + # ) + + # # avg total reward + # record_metric( + # "reward/evaluate_response/avg_total_reward", + # reward, + # Reduce.MEAN, + # ) + + # # count fn calls + # record_metric( + # f"reward/evaluate_response/count_{reward_fn_name}_calls", + # 1, + # Reduce.SUM, + # ) avg_reward = total_rewards / len(self.reward_functions) return avg_reward @@ -256,12 +257,12 @@ async def sample(self) -> dict[str, str] | None: sample = next(self._iterator) # Record dataset metrics - record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM) - record_metric( - "dataset/sample/avg_sample_len", - len(sample["request"]), - Reduce.MEAN, - ) + # record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM) + # record_metric( + # "dataset/sample/avg_sample_len", + # len(sample["request"]), + # Reduce.MEAN, + # ) return sample except StopIteration: @@ -304,9 +305,9 @@ async def main(cfg: DictConfig): else: provisioner = await init_provisioner() - metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger() - await mlogger.init_backends.call_one(metric_logging_cfg) + # metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) + # mlogger = await get_or_create_metric_logger() + # await mlogger.init_backends.call_one(metric_logging_cfg) # ---- Setup services ---- # @@ -414,9 +415,9 @@ async def continuous_rollouts(): # Log metrics rollout_count += 1 - record_metric( - "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM - ) + # record_metric( + # "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM + # ) t.stop() async def continuous_training(): @@ -458,7 +459,7 @@ async def continuous_training(): restart_tracer = True # Flush metrics every training step to WandB - await mlogger.flush.call_one(training_step) + # await mlogger.flush.call_one(training_step) print( f"Reached training limit ({max_steps} steps). Exiting continuous_training loop." diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 14e4871cf..2d0e7cc94 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -13,14 +13,14 @@ off_by_n: 1 # Off by one by default rollout_threads: 1 # Recommended to set equal to policy.num_replicas -# Observability configuration -metric_logging: - wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True - console: - reduce_across_ranks: True +# # Observability configuration +# metric_logging: +# wandb: +# project: "grpo-training" +# group: "grpo_exp_${oc.env:USER}" +# reduce_across_ranks: True +# console: +# reduce_across_ranks: True # Dataset configuration dataset: