
Commit 2a0d1e3

Update trl version to the latest release 0.11.4 -> 0.24.0 (#1000)
# What does this PR do?

This PR adapts our `NeuronSFTTrainer` to the latest `trl` release, updating the pinned version from `0.11.4` to `0.24.0`. This should make it possible to work on trainers that appeared after `0.11.4`, such as GRPO.
1 parent 1dfe4da · commit 2a0d1e3
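In user code, the change boils down to two argument renames when building the SFT config and trainer, as the docs and examples below show. A minimal before/after sketch, assuming the model, tokenizer, LoRA config, dataset, `format_dolly`, and `NeuronTrainingArguments` are already set up as in the existing Llama tutorial (the `optimum.neuron` import path is taken from those docs):

```python
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer

args = training_args.to_dict()

sft_config = NeuronSFTConfig(
    max_length=2048,  # trl 0.24.0 name; was max_seq_length under trl 0.11.4
    packing=True,
    **args,
)

trainer = NeuronSFTTrainer(
    args=sft_config,
    model=model,
    peft_config=lora_config,
    processing_class=tokenizer,  # trl 0.24.0 name; was tokenizer under trl 0.11.4
    train_dataset=dataset,
    formatting_func=lambda example: format_dolly(example, tokenizer),
)
trainer.train()
```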

11 files changed, +354 -263 lines


docs/source/training_tutorials/finetune_llama.mdx

Lines changed: 2 additions & 2 deletions
```diff
@@ -156,7 +156,7 @@ lora_config = LoraConfig(
 args = training_args.to_dict()
 
 sft_config = NeuronSFTConfig(
-    max_seq_length=2048,
+    max_length=2048,
     packing=True,
     **args,
 )
@@ -186,7 +186,7 @@ trainer = NeuronSFTTrainer(
     args=sft_config,
     model=model,
     peft_config=lora_config,
-    tokenizer=tokenizer,
+    processing_class=tokenizer,
     train_dataset=dataset,
     formatting_func=lambda example: format_dolly(example, tokenizer),
 )
```

docs/source/training_tutorials/finetune_qwen3.mdx

Lines changed: 2 additions & 2 deletions
```diff
@@ -164,7 +164,7 @@ lora_config = LoraConfig(
 args = training_args.to_dict()
 
 sft_config = NeuronSFTConfig(
-    max_seq_length=4096,
+    max_length=4096,
     packing=True,
     **args,
 )
@@ -181,7 +181,7 @@ dataset = preprocess_dataset_with_eos(tokenizer.eos_token)
     args=sft_config,
     model=model,
     peft_config=lora_config,
-    tokenizer=tokenizer,
+    processing_class=tokenizer,
     train_dataset=dataset,
     formatting_func=formatting_function,
 )
```

examples/training/llama/finetune_llama.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -80,7 +80,7 @@ def train(model_id, tokenizer, dataset, training_args):
     args = training_args.to_dict()
 
     sft_config = NeuronSFTConfig(
-        max_seq_length=2048,
+        max_length=2048,
         packing=True,
         **args,
     )
@@ -91,7 +91,7 @@ def train(model_id, tokenizer, dataset, training_args):
         args=sft_config,
         model=model,
         peft_config=lora_config,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         train_dataset=dataset,
         formatting_func=lambda example: format_dolly(example, tokenizer),
     )
```

examples/training/qwen3/finetune_qwen3.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -84,7 +84,7 @@ def train(model_id, tokenizer, dataset, training_args):
     args = training_args.to_dict()
 
     sft_config = NeuronSFTConfig(
-        max_seq_length=4096,
+        max_length=4096,
         packing=True,
         **args,
     )
@@ -98,7 +98,7 @@ def formatting_function(examples):
         args=sft_config,
         model=model,
         peft_config=lora_config,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         train_dataset=dataset,
         formatting_func=formatting_function,
     )
```

optimum/neuron/trainers/sft_config.py

Lines changed: 22 additions & 2 deletions
```diff
@@ -10,7 +10,7 @@
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# Seg the License for the specific language governing permissions and
+# See the License for the specific language governing permissions and
 # limitations under the License.
 
 from dataclasses import dataclass
@@ -32,4 +32,24 @@ def __init__(self, *args, **kwargs):
 
 @dataclass
 class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig):
-    pass
+    """
+    Configuration class for Neuron-optimized SFT training.
+
+    Inherits from both NeuronTrainingArguments (for Trainium-specific settings) and
+    trl's SFTConfig (for SFT-specific settings).
+
+    Key Neuron-specific behavior:
+    - padding_free is always set to False to avoid recompilation on Trainium devices
+    - All other SFT parameters from trl 0.24.0+ are supported
+    """
+
+    def __post_init__(self):
+        # Handle max_seq_length -> max_length migration for backward compatibility
+        if hasattr(self, "max_seq_length") and self.max_seq_length is not None:
+            self.max_length = self.max_seq_length
+
+        # Force padding_free to False for Neuron - critical for avoiding recompilation
+        # Neuron devices require fixed input shapes; padding_free flattening breaks this requirement
+        self.padding_free = False
+
+        super().__post_init__()
```
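For downstream scripts, the practical effect of this `__post_init__` is that a value still supplied under the old `max_seq_length` name is copied into `max_length`, and padding-free batching is always switched off so input shapes stay static on Trainium. A minimal sketch of that behavior, assuming the top-level `optimum.neuron` import path used in the tutorials and an environment where `NeuronTrainingArguments` can be instantiated:

```python
from optimum.neuron import NeuronSFTConfig  # import path assumed from the tutorials

cfg = NeuronSFTConfig(
    output_dir="sft_output",  # hypothetical output directory
    max_length=2048,          # trl 0.24.0 name for the old max_seq_length
    packing=True,
)

# __post_init__ forces padding_free to False so Trainium keeps static input
# shapes and avoids recompilation.
assert cfg.padding_free is False
```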
