|
2 | 2 | # Copyright (c) 2025, Mayank Mishra |
3 | 3 | # ************************************************** |
4 | 4 |
|
5 | | -from contextlib import AbstractContextManager, nullcontext |
| 5 | +from contextlib import nullcontext |
6 | 6 |
|
7 | 7 | import torch |
8 | 8 | from torch.distributed.tensor.parallel import loss_parallel |
|
19 | 19 | from .kernels import enable_kernels |
20 | 20 | from .model_wrapper import get_model_container |
21 | 21 | from .optimization import get_learning_rate, get_optimizer_container, get_scheduler_container |
| 22 | +from .pretrain import train_step_without_pipeline_parallel |
22 | 23 | from .train_utils import all_reduce_metrics_tracker, get_torch_profiler, track_metrics |
23 | 24 | from .utils import ( |
24 | 25 | ExperimentsTracker, |
25 | 26 | MetricsTrackingDict, |
26 | 27 | ProcessGroupManager, |
27 | 28 | StepTracker, |
28 | 29 | init_distributed, |
29 | | - is_torchao_available, |
30 | 30 | setup_tf32, |
31 | 31 | ) |
32 | 32 |
|
33 | 33 |
|
34 | | -if is_torchao_available(): |
35 | | - from .distributed import FP8Manager |
36 | | - |
37 | | - |
def train_step_without_pipeline_parallel(
    model_container: ModelContainer,
    optimizer_container: OptimizerContainer,
    lr_scheduler_container: LRSchedulerContainer,
    train_dataloader: ResumableDataLoader,
    gradient_clipping: float,
    forward_context: AbstractContextManager,
    backward_context: AbstractContextManager,
    sync_every_gradient_accumulation_step: bool,
) -> MetricsTrackingDict:
    """Run one optimizer step: accumulate gradients over all microbatches, clip, and step.

    Runs backpropagation over ``StepTracker.get_gradient_accumulation_steps()`` microbatches
    and applies the gradient once at the gradient-accumulation boundary. Only the first model
    in the container is used (no pipeline parallelism).

    Args:
        model_container (ModelContainer): container of models; only index 0 is trained here
        optimizer_container (OptimizerContainer): container of optimizers
        lr_scheduler_container (LRSchedulerContainer): container of learning rate schedulers
        train_dataloader (ResumableDataLoader): training dataloader
        gradient_clipping (float): max gradient norm; ``None`` disables clipping
        forward_context (AbstractContextManager): a context that is used for every model forward call
        backward_context (AbstractContextManager): a context that is used for every model backward call
        sync_every_gradient_accumulation_step (bool): whether to sync gradients across ranks on
            every microbatch instead of only on the last one

    Returns:
        MetricsTrackingDict: metrics to track, all-reduced across data-parallel ranks,
            including ``grad_norm``
    """

    model = model_container[0]

    # FSDP2 wrappers expose set_requires_gradient_sync(); FSDP1 wrappers expose no_sync().
    # We detect which API is available rather than being told explicitly.
    fsdp_algorithm = 2 if hasattr(model, "set_requires_gradient_sync") else 1

    # Default: gradients sync on every backward. When deferring sync to the last microbatch,
    # FSDP1 uses the no_sync() context manager while FSDP2 toggles a flag on the model.
    no_sync = nullcontext
    if not sync_every_gradient_accumulation_step:
        if fsdp_algorithm == 1:
            no_sync = model.no_sync
        else:
            model.set_requires_gradient_sync(False)

    metrics_tracker = MetricsTrackingDict({})
    grad_norm = None
    optimizer_container.zero_grad()

    gradient_accumulation_steps = StepTracker.get_gradient_accumulation_steps()

    # note the effect of gradient accumulation division is already in the lm_loss_multiplier
    # All microbatches are fetched up front so the loss can be normalized by the total number
    # of non-ignored (label != -100) tokens across the whole accumulation window.
    batches = [get_next_batch(train_dataloader) for _ in range(gradient_accumulation_steps)]
    lm_loss_multiplier = 1 / sum([(batch["labels"] != -100).sum() for batch in batches])

    # All but the last microbatch run under no_sync (a nullcontext when syncing every step).
    with no_sync():
        for batch in batches[:-1]:
            with forward_context():
                loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)

            # compute gradients
            with backward_context():
                loss_micro_step_dict["loss"].backward()

            # inference_mode avoids autograd tracking while accumulating metric tensors
            with torch.inference_mode():
                metrics_tracker = metrics_tracker + loss_micro_step_dict

    # Re-enable gradient sync so the final microbatch's backward performs the all-reduce.
    if fsdp_algorithm == 2:
        model.set_requires_gradient_sync(True)

    # Last microbatch runs outside no_sync so gradients are synchronized exactly once.
    batch = batches[-1]
    with forward_context():
        loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)

    # compute gradients
    with backward_context():
        loss_micro_step_dict["loss"].backward()

    with torch.inference_mode():
        metrics_tracker = metrics_tracker + loss_micro_step_dict

    # FSDP1 requires the wrapper's own clip_grad_norm_ (sharded-aware); FSDP2 works with
    # the plain torch.nn.utils version.
    if gradient_clipping is not None:
        if fsdp_algorithm == 1:
            grad_norm = model.clip_grad_norm_(gradient_clipping)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)

    # FP8 (torchao) bookkeeping: amax/scale history must be synced after backward and
    # before the optimizer step.
    if is_torchao_available():
        FP8Manager.sync_float8_amax_and_scale_history([model])

    optimizer_container.step()
    lr_scheduler_container.step()

    # Dynamic FP8 scales for FSDP are precomputed from the freshly-updated weights,
    # so this must run after optimizer.step().
    if is_torchao_available():
        FP8Manager.precompute_float8_dynamic_scale_for_fsdp([model])

    with torch.inference_mode():
        # Report 0 when clipping was disabled so the key is always present.
        metrics_tracker["grad_norm"] = (
            torch.tensor(0, device=torch.cuda.current_device()) if grad_norm is None else grad_norm
        )

        # Convert any DTensors to plain tensors before the cross-rank all-reduce.
        for key in metrics_tracker:
            metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key])

        metrics_tracker = all_reduce_metrics_tracker(metrics_tracker)

    return metrics_tracker
139 | | - |
140 | 34 | def train( |
141 | 35 | args: TrainingArgs, |
142 | 36 | model_container: ModelContainer, |
@@ -200,6 +94,8 @@ def train( |
200 | 94 | forward_context=forward_context, |
201 | 95 | backward_context=backward_context, |
202 | 96 | sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step, |
| 97 | + lm_loss_multiplier=None, |
| 98 | + tuning_method=args.tuning_args.tuning_method, |
203 | 99 | ) |
204 | 100 |
|
205 | 101 | metrics_tracker = metrics_tracker + loss_step_dict |
|
0 commit comments