| 
 | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 2 | +# All rights reserved.  | 
 | 3 | +#  | 
 | 4 | +# This source code is licensed under the BSD-style license found in the  | 
 | 5 | +# LICENSE file in the root directory of this source tree.  | 
 | 6 | + | 
 | 7 | +# Usage: python -m apps.trainer.main --config apps/trainer/trainer_config.yaml  | 
 | 8 | + | 
 | 9 | +import asyncio  | 
 | 10 | + | 
 | 11 | +import torch  | 
 | 12 | +import torch.nn.functional as F  | 
 | 13 | + | 
 | 14 | +from forge.actors.trainer import RLTrainer  | 
 | 15 | +from forge.cli.config import parse  | 
 | 16 | +from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY  | 
 | 17 | +from forge.controller.provisioner import init_provisioner, shutdown  | 
 | 18 | +from forge.observability.metric_actors import get_or_create_metric_logger  | 
 | 19 | +from forge.types import (  | 
 | 20 | +    Launcher,  | 
 | 21 | +    LauncherConfig,  | 
 | 22 | +    ProcessConfig,  | 
 | 23 | +    ProvisionerConfig,  | 
 | 24 | +    ServiceConfig,  | 
 | 25 | +)  | 
 | 26 | +from omegaconf import DictConfig  | 
 | 27 | + | 
 | 28 | + | 
 | 29 | +def placeholder_loss_function(logits, targets):  | 
 | 30 | +    return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))  | 
 | 31 | + | 
 | 32 | + | 
 | 33 | +async def main(cfg: DictConfig):  | 
 | 34 | +    """Main function that only initializes the trainer."""  | 
 | 35 | + | 
 | 36 | +    # Initialize provisioner  | 
 | 37 | +    await init_provisioner(  | 
 | 38 | +        ProvisionerConfig(  | 
 | 39 | +            launcher_config=LauncherConfig(  | 
 | 40 | +                launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)),  | 
 | 41 | +                job_name=cfg.get(JOB_NAME_KEY, None),  | 
 | 42 | +                services={k: ServiceConfig(**v) for k, v in cfg.services.items()},  | 
 | 43 | +                actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},  | 
 | 44 | +            )  | 
 | 45 | +        )  | 
 | 46 | +    )  | 
 | 47 | + | 
 | 48 | +    # Initialize metric logging  | 
 | 49 | +    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})  | 
 | 50 | +    mlogger = await get_or_create_metric_logger()  | 
 | 51 | +    await mlogger.init_backends.call_one(metric_logging_cfg)  | 
 | 52 | + | 
 | 53 | +    # Initialize trainer only  | 
 | 54 | +    print("Initializing trainer...")  | 
 | 55 | +    trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor(  | 
 | 56 | +        **cfg.trainer, loss=placeholder_loss_function  | 
 | 57 | +    )  | 
 | 58 | + | 
 | 59 | +    print("Trainer initialized successfully!")  | 
 | 60 | +    print(f"Trainer configuration: {cfg.trainer}")  | 
 | 61 | + | 
 | 62 | +    # Keep the trainer running for demonstration  | 
 | 63 | +    # In a real scenario, you might want to expose endpoints or do other work here  | 
 | 64 | +    try:  | 
 | 65 | +        print("Trainer is running. Press Ctrl+C to shutdown...")  | 
 | 66 | +        while True:  | 
 | 67 | +            await asyncio.sleep(1)  | 
 | 68 | +    except KeyboardInterrupt:  | 
 | 69 | +        print("Shutting down trainer...")  | 
 | 70 | +    finally:  | 
 | 71 | +        await RLTrainer.shutdown(trainer)  | 
 | 72 | +        await mlogger.shutdown.call_one()  | 
 | 73 | +        await shutdown()  | 
 | 74 | +        print("Trainer shutdown complete.")  | 
 | 75 | + | 
 | 76 | + | 
 | 77 | +if __name__ == "__main__":  | 
 | 78 | + | 
 | 79 | +    @parse  | 
 | 80 | +    def _main(cfg):  | 
 | 81 | +        asyncio.run(main(cfg))  | 
 | 82 | + | 
 | 83 | +    _main()  # @parse grabs the cfg from CLI  | 
0 commit comments