Skip to content

Commit 6cb5d66

Browse files
committed
Merge branch 'main' into phonebook
2 parents 5d3ff60 + 99d3937 commit 6cb5d66

File tree

4 files changed

+29
-123
lines changed

4 files changed

+29
-123
lines changed

lm_engine/containers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __setitem__(self, index: int, model: nn.Module) -> None:
2828
def __len__(self) -> int:
2929
return len(self.model_list)
3030

31-
def __str__(self):
31+
def __str__(self) -> str:
3232
return str(self.model_list)
3333

3434

lm_engine/data/debug.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,7 @@ def _get_example(self, token_id: int) -> dict:
6262
return example
6363

6464
def __getitem__(self, index: int) -> dict:
65-
if self._static_examples:
66-
example = self._example
67-
else:
68-
example = self._get_example(index)
69-
70-
return example
65+
return self._example if self._static_examples else self._get_example(index)
7166

7267
def __len__(self) -> int:
7368
return self._length

lm_engine/finetune.py

Lines changed: 4 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) 2025, Mayank Mishra
33
# **************************************************
44

5-
from contextlib import AbstractContextManager, nullcontext
5+
from contextlib import nullcontext
66

77
import torch
88
from torch.distributed.tensor.parallel import loss_parallel
@@ -19,124 +19,18 @@
1919
from .kernels import enable_kernels
2020
from .model_wrapper import get_model_container
2121
from .optimization import get_learning_rate, get_optimizer_container, get_scheduler_container
22+
from .pretrain import train_step_without_pipeline_parallel
2223
from .train_utils import all_reduce_metrics_tracker, get_torch_profiler, track_metrics
2324
from .utils import (
2425
ExperimentsTracker,
2526
MetricsTrackingDict,
2627
ProcessGroupManager,
2728
StepTracker,
2829
init_distributed,
29-
is_torchao_available,
3030
setup_tf32,
3131
)
3232

3333

34-
if is_torchao_available():
35-
from .distributed import FP8Manager
36-
37-
38-
def train_step_without_pipeline_parallel(
39-
model_container: ModelContainer,
40-
optimizer_container: OptimizerContainer,
41-
lr_scheduler_container: LRSchedulerContainer,
42-
train_dataloader: ResumableDataLoader,
43-
gradient_clipping: float,
44-
forward_context: AbstractContextManager,
45-
backward_context: AbstractContextManager,
46-
sync_every_gradient_accumulation_step: bool,
47-
) -> MetricsTrackingDict:
48-
"""runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary
49-
50-
Args:
51-
model_container (ModelContainer): container of models
52-
optimizer_container (OptimizerContainer): container of optimizers
53-
lr_scheduler_container (LRSchedulerContainer): container of learning rate schedulers
54-
train_dataloader (ResumableDataLoader): training dataloader
55-
gradient_accumulation_steps (int): gradient accumulation steps
56-
gradient_clipping (float): gradient clipping value
57-
forward_context (AbstractContextManager): a context that is used for every model forward call
58-
backward_context (AbstractContextManager): a context that is used for every model backward call
59-
sync_every_gradient_accumulation_step (bool): whether to sync on every gradient accumulation step
60-
61-
Returns:
62-
MetricsTrackingDict: metrics to track
63-
"""
64-
65-
model = model_container[0]
66-
67-
fsdp_algorithm = 2 if hasattr(model, "set_requires_gradient_sync") else 1
68-
69-
no_sync = nullcontext
70-
if not sync_every_gradient_accumulation_step:
71-
if fsdp_algorithm == 1:
72-
no_sync = model.no_sync
73-
else:
74-
model.set_requires_gradient_sync(False)
75-
76-
metrics_tracker = MetricsTrackingDict({})
77-
grad_norm = None
78-
optimizer_container.zero_grad()
79-
80-
gradient_accumulation_steps = StepTracker.get_gradient_accumulation_steps()
81-
82-
# note the effect of gradient accumulation division is already in the lm_loss_multiplier
83-
batches = [get_next_batch(train_dataloader) for _ in range(gradient_accumulation_steps)]
84-
lm_loss_multiplier = 1 / sum([(batch["labels"] != -100).sum() for batch in batches])
85-
86-
with no_sync():
87-
for batch in batches[:-1]:
88-
with forward_context():
89-
loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)
90-
91-
# compute gradients
92-
with backward_context():
93-
loss_micro_step_dict["loss"].backward()
94-
95-
with torch.inference_mode():
96-
metrics_tracker = metrics_tracker + loss_micro_step_dict
97-
98-
if fsdp_algorithm == 2:
99-
model.set_requires_gradient_sync(True)
100-
101-
batch = batches[-1]
102-
with forward_context():
103-
loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)
104-
105-
# compute gradients
106-
with backward_context():
107-
loss_micro_step_dict["loss"].backward()
108-
109-
with torch.inference_mode():
110-
metrics_tracker = metrics_tracker + loss_micro_step_dict
111-
112-
if gradient_clipping is not None:
113-
if fsdp_algorithm == 1:
114-
grad_norm = model.clip_grad_norm_(gradient_clipping)
115-
else:
116-
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
117-
118-
if is_torchao_available():
119-
FP8Manager.sync_float8_amax_and_scale_history([model])
120-
121-
optimizer_container.step()
122-
lr_scheduler_container.step()
123-
124-
if is_torchao_available():
125-
FP8Manager.precompute_float8_dynamic_scale_for_fsdp([model])
126-
127-
with torch.inference_mode():
128-
metrics_tracker["grad_norm"] = (
129-
torch.tensor(0, device=torch.cuda.current_device()) if grad_norm is None else grad_norm
130-
)
131-
132-
for key in metrics_tracker:
133-
metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key])
134-
135-
metrics_tracker = all_reduce_metrics_tracker(metrics_tracker)
136-
137-
return metrics_tracker
138-
139-
14034
def train(
14135
args: TrainingArgs,
14236
model_container: ModelContainer,
@@ -200,6 +94,8 @@ def train(
20094
forward_context=forward_context,
20195
backward_context=backward_context,
20296
sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step,
97+
lm_loss_multiplier=None,
98+
tuning_method=args.tuning_args.tuning_method,
20399
)
204100

205101
metrics_tracker = metrics_tracker + loss_step_dict

lm_engine/pretrain.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def train_step_without_pipeline_parallel(
150150
backward_context: AbstractContextManager,
151151
sync_every_gradient_accumulation_step: bool,
152152
lm_loss_multiplier: float,
153+
tuning_method: TuningMethod,
153154
) -> MetricsTrackingDict:
154155
"""runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary
155156
@@ -163,11 +164,13 @@ def train_step_without_pipeline_parallel(
163164
backward_context (AbstractContextManager): a context that is used for every model backward call
164165
sync_every_gradient_accumulation_step (bool): whether to sync on every gradient accumulation step
165166
lm_loss_multiplier (int): lm loss multiplier
167+
tuning_method (TuningMethod): tuning method for the current run
166168
167169
Returns:
168170
MetricsTrackingDict: metrics to track
169171
"""
170172

173+
assert len(model_container) == 1
171174
model = model_container[0]
172175

173176
fsdp_algorithm = 2 if hasattr(model, "set_requires_gradient_sync") else 1
@@ -185,31 +188,40 @@ def train_step_without_pipeline_parallel(
185188

186189
gradient_accumulation_steps = StepTracker.get_gradient_accumulation_steps()
187190

191+
if tuning_method == TuningMethod.full_finetuning:
192+
assert lm_loss_multiplier is None
193+
194+
# note the effect of gradient accumulation division is already in the lm_loss_multiplier
195+
batches = [get_next_batch(train_dataloader) for _ in range(gradient_accumulation_steps)]
196+
lm_loss_multiplier = gradient_accumulation_steps / sum([(batch["labels"] != -100).sum() for batch in batches])
197+
else:
198+
batches = None
199+
188200
with no_sync():
189-
for _ in range(gradient_accumulation_steps - 1):
190-
batch = get_next_batch(train_dataloader)
201+
for step in range(gradient_accumulation_steps - 1):
202+
batch = get_next_batch(train_dataloader) if batches is None else batches[step]
191203
with forward_context():
192204
loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)
193205

194206
# compute gradients
195207
with backward_context():
196-
loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
197-
loss_micro_step_scaled.backward()
208+
loss_micro_step: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
209+
loss_micro_step.backward()
198210

199211
with torch.inference_mode():
200212
metrics_tracker = metrics_tracker + loss_micro_step_dict
201213

202214
if fsdp_algorithm == 2:
203215
model.set_requires_gradient_sync(True)
204216

205-
batch = get_next_batch(train_dataloader)
217+
batch = get_next_batch(train_dataloader) if batches is None else batches[-1]
206218
with forward_context():
207219
loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)
208220

209221
# compute gradients
210222
with backward_context():
211-
loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
212-
loss_micro_step_scaled.backward()
223+
loss_micro_step: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
224+
loss_micro_step.backward()
213225

214226
with torch.inference_mode():
215227
metrics_tracker = metrics_tracker + loss_micro_step_dict
@@ -233,7 +245,9 @@ def train_step_without_pipeline_parallel(
233245
metrics_tracker = metrics_tracker / gradient_accumulation_steps
234246

235247
metrics_tracker["grad_norm"] = (
236-
torch.tensor(0, device=torch.cuda.current_device()) if grad_norm is None else grad_norm
248+
torch.zeros((1,), device=torch.cuda.current_device(), dtype=torch.float32)
249+
if grad_norm is None
250+
else grad_norm
237251
)
238252

239253
for key in metrics_tracker:
@@ -394,6 +408,7 @@ def train(
394408
backward_context=backward_context,
395409
sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step,
396410
lm_loss_multiplier=1 / (micro_batch_size * sequence_length),
411+
tuning_method=args.tuning_args.tuning_method,
397412
)
398413

399414
metrics_tracker = metrics_tracker + loss_step_dict

0 commit comments

Comments (0)