Skip to content

Commit baeb35b

Browse files
author
Hossein Kavianihamedani
committed
Submitting an interactive notebook to run SFT
1 parent 3653453 commit baeb35b

File tree

9 files changed

+2124
-1466
lines changed

9 files changed

+2124
-1466
lines changed

apps/sft_v2/NOTEBOOK_GUIDE.md

Lines changed: 847 additions & 0 deletions
Large diffs are not rendered by default.

apps/sft_v2/README_NOTEBOOK.md

Lines changed: 0 additions & 435 deletions
This file was deleted.

apps/sft_v2/actor.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Abstract Actor class for training/inference actors in Forge.
9+
10+
This provides a base class that can be extended for different types of actors
11+
(e.g., Trainer, Evaluator, Inferencer, etc.)
12+
"""
13+
14+
import logging
15+
import math
16+
import os
17+
from abc import ABC, abstractmethod
18+
from typing import Any, Optional
19+
20+
import torch
21+
from forge.controller import ForgeActor
22+
from monarch.actor import current_rank, current_size
23+
from omegaconf import DictConfig, OmegaConf
24+
from torch import nn
25+
from torchtitan.components.loss import LossFunction
26+
from torchtitan.components.lr_scheduler import LRSchedulersContainer
27+
from torchtitan.components.optimizer import OptimizersContainer
28+
from torchtitan.distributed import ParallelDims
29+
from torchtitan.experiments.forge.engine import ForgeEngine
30+
from torchtitan.experiments.forge.job_config import ForgeJobConfig
31+
32+
# Placeholder aliases: the component types used in annotations below are bound to ``Any`` in this module.
Checkpointer = Any
33+
Dataloader = Any
34+
MetricLogger = Any
35+
Profiler = Any
36+
Tokenizer = Any
37+
38+
# Module-level logger named after this module; its level is set to INFO right after creation.
logger = logging.getLogger(__name__)
39+
logger.setLevel(logging.INFO)
40+
41+
42+
class BaseForgeActor(ForgeActor, ForgeEngine, ABC):
    """
    Common foundation for Forge actors (e.g. trainers, evaluators).

    The initializer merges the caller's configuration over the
    ``ForgeJobConfig`` defaults, exports the distributed environment
    variables for the current rank, and then hands the merged configuration
    to the parent initializer. Concrete actors must implement ``setup``,
    ``run``, ``cleanup`` and ``__repr__``.
    """

    job_config: ForgeJobConfig
    parallel_dims: ParallelDims
    model: list[nn.Module]
    loss_fn: Optional[LossFunction]
    optimizer: Optional[OptimizersContainer]
    lr_scheduler: Optional[LRSchedulersContainer]
    checkpointer: Optional[Checkpointer]
    tokenizer: Optional[Tokenizer]
    metric_logger: Optional[MetricLogger]
    profiler: Optional[Profiler]
    device: torch.device

    def __init__(self, config: DictConfig):
        """
        Build the actor from a user-supplied configuration.

        Args:
            config: Job settings, merged on top of the ``ForgeJobConfig``
                defaults before being passed to the parent initializer.
        """
        defaults = ForgeJobConfig().to_dict()
        merged_config = OmegaConf.merge(defaults, config)

        # Bookkeeping attributes are assigned before delegating to the parent.
        self.current_step = 0
        self.metric_logger = None
        self.gradient_accumulation_steps = 1

        # Rank of this actor and the total actor count (product of all
        # extents reported by ``current_size()``).
        rank_point = current_rank()
        self._rank = rank_point.rank
        extents = current_size().values()
        self._size = math.prod(extents)

        self._init_dist()
        super().__init__(merged_config)

    def _init_dist(self):
        """
        Export torchrun-style environment variables for this rank.

        Every rank-like variable is derived from ``self._rank`` and every
        size-like variable from ``self._size``; the resulting mapping is
        written into ``os.environ`` and logged at INFO level.
        """
        rank = str(self._rank)
        world = str(self._size)
        env = {
            "RANK": rank,
            "LOCAL_RANK": rank,
            "LOCAL_WORLD_SIZE": world,
            # NOTE(review): GROUP_RANK is populated with the world size, not a
            # rank -- looks unintended, but kept as-is; confirm expected value.
            "GROUP_RANK": world,
            "GROUP_WORLD_SIZE": world,
            "ROLE_RANK": rank,
            "ROLE_WORLD_SIZE": world,
            "ROLE_NAME": "rank",
            "WORLD_SIZE": world,
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        }
        os.environ.update(env)
        logger.info(f"Initialized distributed environment: {env}")

    @abstractmethod
    async def setup(self):
        """
        Prepare the actor before it runs (load data, checkpoint, etc.).

        Concrete actor classes must override this coroutine.
        """

    @abstractmethod
    async def run(self):
        """
        Execute the actor's main logic.

        Concrete actor classes must override this coroutine.
        """

    @abstractmethod
    async def cleanup(self):
        """
        Release resources (close checkpointer, logger, etc.).

        Concrete actor classes must override this coroutine.
        """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a string representation of the actor."""

0 commit comments

Comments
 (0)