Skip to content

Commit 8d3afda

Browse files
committed
initial commit
1 parent 8cb21be commit 8d3afda

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

apps/trainer/main.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Usage: python -m apps.trainer.main --config apps/trainer/trainer_config.yaml
8+
9+
import asyncio
10+
11+
import torch
12+
import torch.nn.functional as F
13+
14+
from forge.actors.trainer import RLTrainer
15+
from forge.cli.config import parse
16+
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
17+
from forge.controller.provisioner import init_provisioner, shutdown
18+
from forge.observability.metric_actors import get_or_create_metric_logger
19+
from forge.types import (
20+
Launcher,
21+
LauncherConfig,
22+
ProcessConfig,
23+
ProvisionerConfig,
24+
ServiceConfig,
25+
)
26+
from omegaconf import DictConfig
27+
28+
29+
def placeholder_loss_function(logits, targets):
    """Return the cross-entropy between flattened ``logits`` and ``targets``.

    All leading dimensions of ``logits`` are collapsed so that only its last
    dimension is kept, and ``targets`` is flattened to one dimension before
    both are handed to ``F.cross_entropy``.

    Args:
        logits: Tensor whose last dimension is passed to ``F.cross_entropy``
            as the class dimension.
        targets: Tensor that, once flattened, has one entry per row of the
            flattened ``logits``.

    Returns:
        The value returned by ``F.cross_entropy`` with its default arguments.
    """
    # reshape() instead of view(): view() raises on non-contiguous tensors
    # (e.g. after a transpose), while reshape() copies only when it has to.
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_targets = targets.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets)
31+
32+
33+
async def main(cfg: DictConfig):
    """Initialize only the trainer actor and keep it alive until interrupted.

    Sets up the provisioner from ``cfg``, initializes the metric logger
    backends, creates an ``RLTrainer`` actor that uses
    ``placeholder_loss_function`` as its loss, then idles until the run is
    interrupted. On exit the trainer, the metric logger and the provisioner
    are shut down, in that order.

    Args:
        cfg: Parsed configuration. Reads ``cfg.services``, ``cfg.actors``
            (including ``cfg.actors.trainer``), ``cfg.trainer``, and the
            optional launcher, job-name and ``metric_logging`` entries.
    """
    # Initialize provisioner
    await init_provisioner(
        ProvisionerConfig(
            launcher_config=LauncherConfig(
                launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)),
                job_name=cfg.get(JOB_NAME_KEY, None),
                services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
                actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
            )
        )
    )

    # Initialize metric logging (falls back to console logging when unset)
    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
    mlogger = await get_or_create_metric_logger()
    await mlogger.init_backends.call_one(metric_logging_cfg)

    # Initialize trainer only
    print("Initializing trainer...")
    trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor(
        **cfg.trainer, loss=placeholder_loss_function
    )

    print("Trainer initialized successfully!")
    print(f"Trainer configuration: {cfg.trainer}")

    # Keep the trainer running for demonstration
    # In a real scenario, you might want to expose endpoints or do other work here
    try:
        print("Trainer is running. Press Ctrl+C to shutdown...")
        while True:
            await asyncio.sleep(1)
    except (KeyboardInterrupt, asyncio.CancelledError):
        # When driven by asyncio.run() on Python 3.11+, Ctrl+C cancels the main
        # task, so the interrupt arrives here as CancelledError rather than
        # KeyboardInterrupt. This is the entry-point coroutine, so the
        # cancellation is absorbed to let the run end cleanly after cleanup.
        print("Shutting down trainer...")
    finally:
        # Tear down in reverse order of creation.
        await RLTrainer.shutdown(trainer)
        await mlogger.shutdown.call_one()
        await shutdown()
        print("Trainer shutdown complete.")
75+
76+
77+
if __name__ == "__main__":

    @parse
    def _main(cfg):
        # Build the coroutine first, then drive it to completion on a fresh
        # event loop.
        trainer_coro = main(cfg)
        asyncio.run(trainer_coro)

    # Called without arguments on purpose.
    _main()  # @parse grabs the cfg from CLI

0 commit comments

Comments
 (0)