diff --git a/apps/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py similarity index 98% rename from apps/rl_trainer/main.py rename to tests/sandbox/rl_trainer/main.py index 8473cc16d..1441bb9e3 100644 --- a/apps/rl_trainer/main.py +++ b/tests/sandbox/rl_trainer/main.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Usage: python -m apps.rl_trainer.main --config apps/grpo/qwen3_32b.yaml +# Usage: python -m tests.sandbox.rl_trainer.main --config apps/grpo/qwen3_32b.yaml import asyncio diff --git a/apps/toy_rl/__init__.py b/tests/sandbox/toy_rl/__init__.py similarity index 100% rename from apps/toy_rl/__init__.py rename to tests/sandbox/toy_rl/__init__.py diff --git a/apps/toy_rl/sumdigits-tp.yaml b/tests/sandbox/toy_rl/sumdigits-tp.yaml similarity index 100% rename from apps/toy_rl/sumdigits-tp.yaml rename to tests/sandbox/toy_rl/sumdigits-tp.yaml diff --git a/apps/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py similarity index 96% rename from apps/toy_rl/sumdigits.py rename to tests/sandbox/toy_rl/sumdigits.py index 57971e9b9..14b5f6ebe 100644 --- a/apps/toy_rl/sumdigits.py +++ b/tests/sandbox/toy_rl/sumdigits.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Usage: python -m apps.toy_rl.sumdigits --config apps/toy_rl/sumdigits.yaml +# Usage: python -m tests.sandbox.toy_rl.sumdigits --config tests/sandbox/toy_rl/sumdigits.yaml import asyncio import random @@ -23,9 +23,10 @@ from forge.cli.config import parse from forge.controller.actor import ForgeActor from forge.controller.provisioner import shutdown - from forge.losses.grpo_loss import SimpleGRPOLoss -from forge.util.metric_logging import get_metric_logger +from forge.observability.metric_actors import get_or_create_metric_logger + +from forge.observability.metrics import record_metric, Reduce from forge.util.ops import selective_log_softmax from monarch.actor import endpoint from omegaconf import DictConfig @@ -220,7 +221,6 @@ def __init__(self, model_name, device: torch.device | None = None): self.model = AutoModelForCausalLM.from_pretrained( model_name, - dtype=torch.bfloat16, trust_remote_code=True, ).to(self.device) self.model.eval() @@ -267,7 +267,6 @@ async def setup(self): self.model = AutoModelForCausalLM.from_pretrained( self.model_name, - dtype=torch.bfloat16, trust_remote_code=True, ).to(self.device) self.model.train() @@ -463,15 +462,14 @@ async def main(cfg: DictConfig): max_res_tokens = cfg.max_res_tokens # TODO: delete this logic after we are confident on the vllm weight sync long term fix PR #184 policy_tp_size = cfg.policy.engine_config.tensor_parallel_size - mlogger = get_metric_logger( - "wandb", - freq=1, - project="sumdigits-training", - ) # ---- Setup services ---- # print(f"{cfg.policy=}") print(f"{cfg.services.policy=}") + + metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) + mlogger = await get_or_create_metric_logger() + await mlogger.init_backends.call_one(metric_logging_cfg) await ts.initialize() ( dataloader, @@ -533,9 +531,9 @@ async def continuous_rollouts(): avg_response_len = ( sum(len(e.response_tokens) for e in group.episodes) / group_size ) - mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count) + record_metric("avg_response_len/rollout", avg_response_len, Reduce.MEAN) avg_reward = sum(e.reward for e in group.episodes) / group_size - mlogger.log("avg_reward/rollout", avg_reward, rollout_count) + record_metric("avg_reward/rollout", avg_reward, Reduce.MEAN) rollout_count += 1 @@ -550,7 +548,7 @@ async def continuous_training(): else: loss = await trainer.train_step.call_one(batch[0]) training_step += 1 - mlogger.log("loss/training_step", loss, training_step) + record_metric("loss/training_step", loss, Reduce.MEAN) print(f"loss/training_step: {loss} at training step {training_step}") await trainer.push_weights.call(training_step) await policy.update_weights.fanout(training_step) diff --git a/apps/toy_rl/sumdigits.yaml b/tests/sandbox/toy_rl/sumdigits.yaml similarity index 100% rename from apps/toy_rl/sumdigits.yaml rename to tests/sandbox/toy_rl/sumdigits.yaml diff --git a/apps/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py similarity index 100% rename from apps/toy_rl/toy_metrics/main.py rename to tests/sandbox/toy_rl/toy_metrics/main.py diff --git a/apps/vllm/deepseek_r1.yaml b/tests/sandbox/vllm/deepseek_r1.yaml similarity index 85% rename from apps/vllm/deepseek_r1.yaml rename to tests/sandbox/vllm/deepseek_r1.yaml index 7a0c2ad2d..252b20a3f 100644 --- a/apps/vllm/deepseek_r1.yaml +++ b/tests/sandbox/vllm/deepseek_r1.yaml @@ -1,4 +1,4 @@ -# >>> python -m apps.vllm.main --config apps/vllm/deepseek_r1.yaml +# >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/deepseek_r1.yaml # NOTE - this won't work until we have proper HostMesh support policy: diff --git a/apps/vllm/llama3_8b.yaml b/tests/sandbox/vllm/llama3_8b.yaml similarity index 83% rename from apps/vllm/llama3_8b.yaml rename to tests/sandbox/vllm/llama3_8b.yaml index c4bc141bf..0e9a00607 100644 --- a/apps/vllm/llama3_8b.yaml +++ b/tests/sandbox/vllm/llama3_8b.yaml @@ -1,4 +1,4 @@ -# >>> python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml +# >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/llama3_8b.yaml policy: engine_config: diff --git a/apps/vllm/main.py b/tests/sandbox/vllm/main.py similarity index 96% rename from apps/vllm/main.py rename to tests/sandbox/vllm/main.py index 3167817c7..b425af324 100644 --- a/apps/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -6,7 +6,7 @@ """To run: export HF_HUB_DISABLE_XET=1 -python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml +python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/llama3_8b.yaml """ import asyncio diff --git a/apps/vllm/qwen2_5_32b.yaml b/tests/sandbox/vllm/qwen2_5_32b.yaml similarity index 81% rename from apps/vllm/qwen2_5_32b.yaml rename to tests/sandbox/vllm/qwen2_5_32b.yaml index 72d55781b..3edfaa9d3 100644 --- a/apps/vllm/qwen2_5_32b.yaml +++ b/tests/sandbox/vllm/qwen2_5_32b.yaml @@ -1,4 +1,4 @@ -# >>> python -m apps.vllm.main --config apps/vllm/qwen2_5_32b.yaml +# >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/qwen2_5_32b.yaml policy: engine_config: