
Commit d190278

Garbage collect on every train / ref model step (#209)
1 parent 791cb26 commit d190278

File tree

3 files changed: 12 additions, 6 deletions

apps/grpo/qwen3_1_7b.yaml
src/forge/actors/reference_model.py
src/forge/actors/trainer.py

apps/grpo/qwen3_1_7b.yaml

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@ trainer:
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
+    gc_freq: 1
   compile:
     enable: false
   parallelism:
@@ -83,6 +84,7 @@ ref_model:
   hf_assets_path: hf://${model}
   training:
     dtype: bfloat16
+    gc_freq: 1
   compile:
     enable: false
   parallelism:
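Both gc_freq: 1 entries flow into the Training section of each actor's engine config, and they control how often the engine's garbage-collection handler (the gc_handler used in the Python diffs below) actually calls gc.collect(). As a rough, hand-written sketch of what such a step-gated handler typically does (class and argument names here are illustrative, not forge's or torchtitan's actual implementation):

    import gc


    class StepGatedGC:
        """Illustrative step-gated GC handler: manual collection every `gc_freq` steps."""

        def __init__(self, gc_freq: int = 1000) -> None:
            self.gc_freq = gc_freq
            gc.disable()   # keep CPython from collecting at arbitrary points
            gc.collect()   # start from a clean slate

        def run(self, step: int) -> None:
            # With gc_freq: 1, as configured above, this fires on every step.
            if step > 0 and step % self.gc_freq == 0:
                gc.collect()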

src/forge/actors/reference_model.py

Lines changed: 4 additions & 1 deletion
@@ -37,7 +37,7 @@ class ReferenceModel(ForgeActor):
     compile: Compile = field(default_factory=Compile)
     training: Training = field(
         default_factory=Training
-    )  # Only needed in order to correctly set a lower dtype
+    )  # Needed in order to set attrs like dtype, garbage collection freq, etc.

     # Populated in setup
     # TODO: Commented out since engine_config parsing extracts from class members
@@ -61,6 +61,7 @@ def __post_init__(self):
         """
         self.rank = current_rank().rank
         self.size = math.prod(current_size().values())
+        self.step = 0

         env = {
             "RANK": str(self.rank),
@@ -83,6 +84,7 @@ async def setup(self):

     @endpoint
     async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+        self.engine.gc_handler.run(self.step)
         model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
         input_ids = input_ids.to("cuda")
@@ -106,6 +108,7 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         with self.engine.maybe_enable_amp:
             with torch.inference_mode():
                 logits = model_parts[0](input_ids)
+        self.step += 1
         if isinstance(logits, DTensor):
             logits = logits.full_tensor()
         return logits
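The reference model takes no optimizer steps, so it previously had no step counter at all; the new self.step exists purely to drive gc_handler.run and is incremented after each forward, so with gc_freq: 1 the handler is consulted once per forward call. The same pattern in isolation, reusing the illustrative StepGatedGC sketch from above (not forge's real handler):

    gc_handler = StepGatedGC(gc_freq=1)

    step = 0
    for _ in range(3):          # stand-in for successive forward() calls
        gc_handler.run(step)    # collect whenever the gate in run() opens
        # ... run inference on the batch ...
        step += 1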

src/forge/actors/trainer.py

Lines changed: 6 additions & 5 deletions
@@ -73,7 +73,7 @@ def __post_init__(self):
                 f"{f.name} should be a {f.type} type or a dict like object"
             )

-        self.current_step = 1  # fragile contract.
+        self.step = 1  # fragile contract.
         self.num_training_steps = self.training.steps
         self.gradient_accumulation_steps = 1
         self.rank = current_rank().rank
@@ -100,7 +100,7 @@ async def setup(self):
         for key in {"loss", "state_dict_key", "use_dcp"}:
             engine_config.pop(key)  # Not part of job config
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
-        self.engine.checkpointer.load(step=self.current_step)
+        self.engine.checkpointer.load(step=self.step)
         self.engine.optimizers.zero_grad()

     def forward_backward(
@@ -173,6 +173,7 @@ def forward_backward(
     def train_step(
         self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
     ) -> float:
+        self.engine.gc_handler.run(self.step)
         local_inputs = inputs[self.engine.dp_rank]
         local_targets = targets[self.engine.dp_rank]
         batch_to_device(local_inputs, self.engine.device)
@@ -192,10 +193,10 @@ def train_step(
         self.engine.optimizers.zero_grad()
         self.engine.lr_schedulers.step()

-        self.current_step += 1
+        self.step += 1
         self.engine.checkpointer.save(
-            curr_step=self.current_step,
-            last_step=self.current_step == self.num_training_steps,
+            curr_step=self.step,
+            last_step=self.step == self.num_training_steps,
         )

         return loss.item()
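On the trainer side the counter is renamed from current_step to step so both actors expose the same attribute, and gc_handler.run(self.step) is now the first thing train_step does. Gating collection on the step counter, rather than leaving it to CPython's automatic thresholds, is the usual way to keep GC pauses at predictable step boundaries and roughly aligned across ranks. A minimal sketch of the per-step order of operations implied by the diff, again using the illustrative handler from above (the elided middle and the surrounding variables are stand-ins, not the actual trainer code):

    num_training_steps = 10
    step = 1                                   # the trainer starts counting at 1
    gc_handler = StepGatedGC(gc_freq=1)

    gc_handler.run(step)                       # collect at the step boundary
    # ... forward/backward, optimizer step, zero_grad, lr scheduler ...
    step += 1
    last_step = step == num_training_steps     # flag passed to checkpointer.save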
