4 changes: 2 additions & 2 deletions examples/training/llama/finetune_llama.py
@@ -80,7 +80,7 @@ def train(model_id, tokenizer, dataset, training_args):
args = training_args.to_dict()

sft_config = NeuronSFTConfig(
max_seq_length=2048,
max_length=2048,
packing=True,
**args,
)
@@ -91,7 +91,7 @@ def train(model_id, tokenizer, dataset, training_args):
args=sft_config,
model=model,
peft_config=lora_config,
tokenizer=tokenizer,
processing_class=tokenizer,
train_dataset=dataset,
formatting_func=lambda example: format_dolly(example, tokenizer),
)
4 changes: 2 additions & 2 deletions examples/training/qwen3/finetune_qwen3.py
@@ -84,7 +84,7 @@ def train(model_id, tokenizer, dataset, training_args):
args = training_args.to_dict()

sft_config = NeuronSFTConfig(
max_seq_length=4096,
max_length=4096,
packing=True,
**args,
)
@@ -98,7 +98,7 @@ def formatting_function(examples):
args=sft_config,
model=model,
peft_config=lora_config,
tokenizer=tokenizer,
processing_class=tokenizer,
train_dataset=dataset,
formatting_func=formatting_function,
)
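Both example scripts pick up the same two trl renames: `SFTConfig.max_seq_length` becomes `max_length`, and the trainer's `tokenizer` argument becomes `processing_class`. As a hedged migration sketch (the top-level import path is an assumption, and `model`, `tokenizer`, `dataset`, `training_args`, `lora_config`, and `format_dolly` are the objects defined in the example scripts themselves), a migrated call site looks like:

```python
# Migration sketch only; surrounding objects come from the example scripts,
# and the import path is assumed rather than taken from this PR.
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer

sft_config = NeuronSFTConfig(
    max_length=2048,  # previously max_seq_length=2048
    packing=True,
    **training_args.to_dict(),
)

trainer = NeuronSFTTrainer(
    args=sft_config,
    model=model,
    peft_config=lora_config,
    processing_class=tokenizer,  # previously tokenizer=tokenizer
    train_dataset=dataset,
    formatting_func=lambda example: format_dolly(example, tokenizer),
)
trainer.train()
```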
3 changes: 2 additions & 1 deletion examples/training/qwen3/finetune_qwen3.sh
@@ -13,7 +13,8 @@ TP_DEGREE=8
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=2
MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name
# MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name
MODEL_NAME="Qwen/Qwen3-0.6B" # Change this to the desired model name
OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-finetuned"
DISTRIBUTED_ARGS="--nproc_per_node $PROCESSES_PER_NODE"
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
24 changes: 22 additions & 2 deletions optimum/neuron/trainers/sft_config.py
@@ -10,7 +10,7 @@
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# Seg the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
@@ -32,4 +32,24 @@ def __init__(self, *args, **kwargs):

@dataclass
class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig):
pass
"""
Configuration class for Neuron-optimized SFT training.

Inherits from both NeuronTrainingArguments (for Trainium-specific settings) and
trl's SFTConfig (for SFT-specific settings).

Key Neuron-specific behavior:
- padding_free is always set to False to avoid recompilation on Trainium devices
- All other SFT parameters from trl 0.24.0+ are supported
"""

def __post_init__(self):
# Handle max_seq_length -> max_length migration for backward compatibility
if hasattr(self, "max_seq_length") and self.max_seq_length is not None:
self.max_length = self.max_seq_length

# Force padding_free to False for Neuron - critical for avoiding recompilation
# Neuron devices require fixed input shapes; padding_free flattening breaks this requirement
self.padding_free = False

super().__post_init__()
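The new `__post_init__` is a standard dataclass migration shim: copy the deprecated field onto its replacement, then pin values the backend cannot support before delegating to the parent. A self-contained sketch of that pattern (standalone toy classes, not the actual optimum-neuron or trl types):

```python
from dataclasses import dataclass


@dataclass
class ToySFTConfig:
    max_seq_length: int | None = None  # legacy name
    max_length: int = 1024             # current name
    padding_free: bool = True


@dataclass
class ToyNeuronSFTConfig(ToySFTConfig):
    def __post_init__(self):
        # Backward compatibility: honor the legacy kwarg if it was set.
        if self.max_seq_length is not None:
            self.max_length = self.max_seq_length
        # Neuron needs fixed input shapes, so padding-free packing is always disabled.
        self.padding_free = False


cfg = ToyNeuronSFTConfig(max_seq_length=2048)
assert cfg.max_length == 2048 and cfg.padding_free is False
```

The real class additionally forwards to `super().__post_init__()` so trl's own validation still runs.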
293 changes: 177 additions & 116 deletions optimum/neuron/trainers/sft_trainer.py

Large diffs are not rendered by default.

36 changes: 24 additions & 12 deletions optimum/neuron/trainers/transformers.py
@@ -936,26 +936,29 @@ def get_batch_samples(

return batch_samples, num_items_in_batch

def train_step(
self, model: nn.Module, inputs: dict[str, Any], num_items_in_batch: int | torch.Tensor | None = None
) -> torch.Tensor:
manager = self.autocast_smart_context_manager()

def compute_loss(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs: bool = False,
num_items_in_batch: torch.Tensor | None = None,
):
if isinstance(model, NxDPPModel):
with manager:
loss = model.run_train(**inputs)
loss = model.run_train(**inputs)

# When using pipeline parallelism, the loss is only computed on the last stage.
# So we set the loss to zero on other stages.
if self.pp_rank != self.pp_size - 1:
dtype = torch.bfloat16 if self.args.bf16 else torch.float32
loss = torch.tensor(0, dtype=dtype).to(xm.xla_device())

# PP does not return any outputs except the loss
outputs = {"loss": loss}
else:
if num_items_in_batch is not None:
inputs = dict(**inputs, reduction="sum")

with manager:
outputs = model(**inputs)
outputs = model(**inputs)

if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
@@ -970,8 +973,17 @@ def train_step(
else:
loss = loss / self.args.gradient_accumulation_steps

# Backward pass
self.accelerator.backward(loss)
return (loss, outputs) if return_outputs else loss

def training_step(
self, model: nn.Module, inputs: dict[str, Any], num_items_in_batch: int | torch.Tensor | None = None
) -> torch.Tensor:
manager = self.autocast_smart_context_manager()
with manager:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

# Backward pass
self.accelerator.backward(loss)

return loss

@@ -1102,7 +1114,7 @@ def train(
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

loss_step = self.train_step(self.model, inputs, num_items_in_batch=num_items_in_batch)
loss_step = self.training_step(self.model, inputs, num_items_in_batch=num_items_in_batch)
self.running_loss += loss_step.detach()

if do_sync_step:
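The trainer refactor in `transformers.py` splits the old `train_step` into the standard `transformers.Trainer` pair: `compute_loss` builds the (gradient-accumulation-scaled) loss and optional outputs, and `training_step` wraps it in autocast and runs the backward pass. A stripped-down, plain-PyTorch sketch of that contract (no XLA, pipeline parallelism, or accelerate; the summed-loss handling is simplified):

```python
import torch
import torch.nn as nn


class ToyTrainer:
    """Minimal illustration of the compute_loss / training_step split."""

    def __init__(self, gradient_accumulation_steps: int = 2):
        self.gradient_accumulation_steps = gradient_accumulation_steps

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(**inputs)
        loss = outputs["loss"]
        if num_items_in_batch is not None:
            # Simplified: the real trainer asks the model for a summed loss and
            # normalizes by the token count across the accumulated micro-batches.
            loss = loss / num_items_in_batch
        else:
            loss = loss / self.gradient_accumulation_steps
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        # Autocast only covers the forward pass; backward stays in training_step.
        with torch.autocast("cpu", dtype=torch.bfloat16):
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
        loss.backward()  # the real trainer calls self.accelerator.backward(loss)
        return loss.detach()


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x, labels):
        # Cast predictions to float32 so the loss dtype is consistent under autocast.
        pred = self.linear(x).squeeze(-1).float()
        return {"loss": nn.functional.mse_loss(pred, labels)}


trainer, model = ToyTrainer(), ToyModel()
batch = {"x": torch.randn(8, 4), "labels": torch.randn(8)}
print(trainer.training_step(model, batch))
```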
2 changes: 1 addition & 1 deletion optimum/neuron/trainers/trl_utils.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

TRL_VERSION = "0.11.4"
TRL_VERSION = "0.24.0"
2 changes: 0 additions & 2 deletions optimum/neuron/utils/__init__.py
@@ -85,7 +85,6 @@
"patch_within_function",
"replace_class_in_inheritance_hierarchy",
],
"trl_utils": ["NeuronSFTConfig", "NeuronORPOConfig"],
}

if TYPE_CHECKING:
@@ -155,7 +154,6 @@
patch_within_function,
replace_class_in_inheritance_hierarchy,
)
from .trl_utils import NeuronORPOConfig, NeuronSFTConfig
else:
import sys

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -75,7 +75,7 @@ quality = [
"isort",
]
training = [
"trl == 0.11.4",
"trl == 0.23.1",
"peft == 0.17.0",
"evaluate == 0.4.3",
]
8 changes: 4 additions & 4 deletions tests/training/test_neuron_sft_trainer.py
@@ -77,7 +77,7 @@ def format_dolly(sample):
args = args.to_dict()
sft_config = NeuronSFTConfig(
# Using a small sequence-length since we are not validating the outputs.
max_seq_length=128,
max_length=128,
packing=packing,
dataset_num_proc=1,
**args,
@@ -86,7 +86,7 @@
# Create Trainer instance
trainer = NeuronSFTTrainer(
model=model,
tokenizer=tokenizer,
processing_class=tokenizer,
train_dataset=dataset,
formatting_func=format_dolly,
args=sft_config,
@@ -172,7 +172,7 @@ def format_dolly(sample):

args = args.to_dict()
sft_config = NeuronSFTConfig(
max_seq_length=128,
max_length=128,
packing=False, # No packing for PEFT test simplicity
dataset_num_proc=1,
**args,
@@ -181,7 +181,7 @@
# Create SFT Trainer instance with PEFT model
trainer = NeuronSFTTrainer(
model=base_model,
tokenizer=tokenizer,
processing_class=tokenizer,
train_dataset=dataset,
formatting_func=format_dolly,
args=sft_config,