Skip to content

Commit 4934f4c

Browse files
Maxusmusti, NikhilNayak-debug, and RobotSail
authored
Adding GPT OSS Support (#646)
* Adding dequantized load support for gpt_oss models Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * Update gpt oss saving with requantization Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * Adjust data processing for gpt format Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * fix for exact quantization algorithm to replicate OpenAI quantized weights * Speedup replicate implementation Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * router freezing for gpt oss Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * Add corrected loss, aux loss support, and batching updates Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * Cleanup unnecessary test files Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * Fix linting and review feedback Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * Add linting skip for mxfp4 import Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * Fix unit tests with mock configs Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * Switch to mini trainer sampler Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> * remove dead code + add defaults (#653) * Refactors train loop, adds padded batch packer, other fixes (#654) * addition of padded batch packer + simplified train loop * update tests + linting * fix tests x2 --------- Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com> Co-authored-by: Nikhil Nayak <nikhilnayak268@gmail.com> Co-authored-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
1 parent 2a02acc commit 4934f4c

21 files changed

+2962
-901
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ wheel>=0.43
33
pyyaml
44
py-cpuinfo
55
torch>=2.6.0
6-
transformers>=4.45.2
6+
transformers>=4.55.0
77

88
datasets>=2.15.0
99
numba

src/instructlab/training/accelerator.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
# Standard
22
from copy import deepcopy
3-
from typing import Callable, Optional
3+
from functools import partial
4+
from typing import Optional
5+
import logging
46

57
# Third Party
68
from accelerate import Accelerator as TransformersAccel
9+
from accelerate.utils import DeepSpeedPlugin, FullyShardedDataParallelPlugin
10+
from peft.utils.other import fsdp_auto_wrap_policy
11+
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
12+
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
13+
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
714
from torch.utils.data import DataLoader
815
from transformers import get_scheduler
916
import torch
@@ -13,10 +20,13 @@
1320
DeepSpeedOptions,
1421
DistributedBackend,
1522
)
23+
from instructlab.training.utils import get_module_class_from_name
1624

1725
# Local
1826
from .model import Model
1927

28+
logger = logging.getLogger(__name__)
29+
2030

2131
class Accelerator:
2232
def __init__(
@@ -32,6 +42,7 @@ def __init__(
3242
deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool] = False,
3343
deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
3444
fsdp_cpu_offload_params: Optional[bool] = False,
45+
fsdp_use_orig_params: Optional[bool] = False,
3546
):
3647
self.samples_per_gpu = samples_per_gpu
3748
self.save_samples = save_samples
@@ -48,7 +59,8 @@ def __init__(
4859
deepspeed_cpu_offload_optimizer_ratio
4960
)
5061
self.fsdp_cpu_offload_params = fsdp_cpu_offload_params
51-
62+
self.fsdp_use_orig_params = fsdp_use_orig_params
63+
self.lr_scheduler = None
5264
if self.distributed_framework == DistributedBackend.DEEPSPEED:
5365
# Standard
5466
accel_args = {
@@ -84,7 +96,6 @@ def prepare_with_optimizer(
8496
num_epochs: int,
8597
num_warmup_steps: int,
8698
):
87-
self.lr_scheduler: Callable
8899
self.setup_lr_scheduler(
89100
optimizer=optimizer,
90101
lr_scheduler=lr_scheduler,
@@ -120,19 +131,6 @@ def __getattr__(self, name):
120131
return getattr(self.accelerator, name)
121132

122133
def get_fsdp_config(self):
123-
# Standard
124-
from functools import partial
125-
126-
# Third Party
127-
from accelerate.utils import FullyShardedDataParallelPlugin
128-
from peft.utils.other import fsdp_auto_wrap_policy
129-
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
130-
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
131-
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
132-
133-
# First Party
134-
from instructlab.training.utils import get_module_class_from_name
135-
136134
is_lora = self.model.lora_config is not None
137135
block_name = self.model._no_split_modules[0]
138136

@@ -158,20 +156,16 @@ def get_fsdp_config(self):
158156
backward_prefetch=prefetch_policy,
159157
sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
160158
cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
159+
use_orig_params=self.fsdp_use_orig_params,
160+
# TODO(osilkin): expose switch for fp32 reduction
161161
)
162162

163-
# `use_orig_params` must be disabled when using LoRA and FSDP together
164-
# Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
165-
if self.model.lora_config is not None:
166-
fsdp_plugin.use_orig_params = False
167-
168163
return fsdp_plugin
169164

170165
def get_ds_plugin(
171166
self, world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions
172167
):
173168
# Third Party
174-
from accelerate.utils import DeepSpeedPlugin
175169

176170
ds_config = {
177171
"train_batch_size": samples_per_gpu * world_size * grad_accum,
@@ -248,3 +242,12 @@ def setup_fsdp(
248242
fsdp_cpu_offload_params=fsdp_cpu_offload_params,
249243
save_samples=save_samples,
250244
)
245+
246+
def take_optimizer_step(self, max_grad_norm: float = 1.0):
    """
    Clip gradients, take an optimizer step, advance the learning-rate
    scheduler, and zero the gradients for the next accumulation window.

    Args:
        max_grad_norm: Maximum gradient norm used for clipping. Defaults
            to 1.0, the value previously hard-coded in this method, so
            existing callers are unaffected.
    """
    # `clip_grad_norm_` is resolved via this class's __getattr__, which
    # delegates to the wrapped accelerate Accelerator instance.
    self.clip_grad_norm_(self.model.parameters(), max_grad_norm)
    self.optimizer.step()
    self.lr_scheduler.step()
    # clear gradients so the next batch starts a fresh accumulation window
    self.optimizer.zero_grad()
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""
3+
Batch loss management for distributed training.
4+
5+
This module provides utilities for managing loss computation, accumulation,
6+
and reduction across distributed training environments.
7+
"""
8+
9+
# Standard
10+
from dataclasses import dataclass
11+
import logging
12+
13+
# Third Party
14+
import torch
15+
import torch.distributed
16+
17+
# First Party
18+
from instructlab.training.accelerator import Accelerator
19+
from instructlab.training.model import Model
20+
from instructlab.training.type_definitions import CollatedItem, ModelInputs
21+
22+
logger = logging.getLogger("instructlab.training")
23+
24+
25+
@dataclass
26+
class BatchMetrics:
27+
"""Metrics collected during batch processing."""
28+
29+
total_samples: int
30+
total_length: int
31+
num_loss_counted_tokens: int
32+
accumulated_loss: torch.Tensor
33+
accumulated_aux_loss: torch.Tensor | None
34+
grad_accum_steps: int
35+
num_minibatches: int
36+
37+
38+
class BatchLossManager:
    """
    Coordinates loss computation and metrics collection for a batch of
    minibatches in a distributed training run.

    Responsibilities:
      - run forward/backward over every minibatch in a batch
      - accumulate the main (and optional auxiliary) losses
      - sum rank-local counters across all ranks
      - produce an average loss suitable for metrics logging
    """

    def __init__(self, model, accelerator, world_size: int, local_rank: int):
        """
        Initialize the BatchLossManager.

        Args:
            model: The model being trained.
            accelerator: Accelerator instance managing distributed state.
            world_size: Number of distributed processes.
            local_rank: Rank of this process on the local node.
        """
        self.model: Model = model
        self.accelerator: Accelerator = accelerator
        self.world_size: int = world_size
        self.local_rank: int = local_rank
        self.torch_device = torch.device("cuda", local_rank)

    def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]:
        """
        Run forward/backward over each minibatch, accumulating gradients.

        Args:
            batch: List of collated minibatches making up one batch.

        Returns:
            tuple: (BatchMetrics, average loss across all ranks)
        """
        # batch-level constant — identical on every minibatch in the batch
        loss_counted_tokens = batch[0]["batch_num_loss_counted_tokens"]

        samples_seen = 0
        tokens_seen = 0
        loss_sum = 0.0
        aux_loss_sum = 0.0
        steps_taken = 0

        for minibatch in batch:
            # rank-local counters for this minibatch
            samples_seen += minibatch["num_samples"]
            tokens_seen += minibatch["total_length"]

            # forward + backward (gradients accumulate across minibatches)
            inputs = self._prepare_model_inputs(minibatch)
            scaled_loss, raw_losses = self.model.compute_loss(
                inputs, self.world_size, loss_counted_tokens
            )
            self.accelerator.backward(scaled_loss)

            steps_taken += 1
            loss_sum += raw_losses.main_loss
            if raw_losses.aux_loss is not None:
                aux_loss_sum += raw_losses.aux_loss

        # sum rank-local counters across every rank
        samples_seen, tokens_seen = self._reduce_metrics(samples_seen, tokens_seen)

        # average (normalized) loss across all ranks for logging
        avg_loss_across_ranks = self._compute_average_loss(
            loss_sum, aux_loss_sum, loss_counted_tokens
        )

        metrics = BatchMetrics(
            total_samples=int(samples_seen),
            total_length=int(tokens_seen),
            num_loss_counted_tokens=int(loss_counted_tokens),
            accumulated_loss=loss_sum,
            accumulated_aux_loss=aux_loss_sum,
            grad_accum_steps=steps_taken,
            num_minibatches=len(batch),
        )

        return metrics, avg_loss_across_ranks

    def _prepare_model_inputs(self, mb: CollatedItem) -> ModelInputs:
        """Build ModelInputs from a collated minibatch, moving tensors to GPU."""
        inputs = ModelInputs(
            input_ids=mb["input_ids"].to(device=self.torch_device),
            labels=mb["labels"].to(device=self.torch_device),
        )

        # optional fields are forwarded only when the collator produced them
        for key in ("attention_mask", "position_ids"):
            if key in mb:
                inputs[key] = mb[key].to(device=self.torch_device)

        return inputs

    def _reduce_metrics(
        self, batch_total_samples: int, batch_total_length: int
    ) -> tuple[int, int]:
        """Sum the rank-local sample/length counters across all ranks."""
        local_counts = torch.tensor(
            [batch_total_samples, batch_total_length],
            dtype=torch.int32,
            device=self.accelerator.device,
        )

        summed = self.accelerator.reduce(local_counts, reduction="sum")
        return summed[0].item(), summed[1].item()

    def _compute_average_loss(
        self,
        accumulated_loss: torch.Tensor,
        accumulated_aux_loss: torch.Tensor | None,
        batch_num_loss_counted_tokens: int,
    ) -> float:
        """Compute the batch loss averaged across all ranks, for logging."""
        # scale the accumulated loss by world size over the counted tokens
        batch_loss = (
            accumulated_loss * self.world_size / batch_num_loss_counted_tokens
        )
        # gpt-oss models contribute an extra auxiliary loss term
        if self.model.is_gpt_oss and accumulated_aux_loss is not None:
            batch_loss = batch_loss + accumulated_aux_loss

        # mean-reduce the scalar across ranks
        avg_loss_across_ranks = self.accelerator.reduce(
            torch.tensor(
                batch_loss.detach().item(), device=self.accelerator.device
            ),
            reduction="mean",
        ).item()

        return avg_loss_across_ranks

0 commit comments

Comments
 (0)