
Commit 98333e6

upgrade trl to 0.24.0 and liger to 0.6.3 (#3230)
* upgrade trl to 0.24.0
* fix reward collator init
* use newer DataCollatorForPreference instead
* DataCollatorForPreference doesn't use padding kwarg
* fix input id labels
* fix fbgemm-gpu version for pytorch versions
* tweak pinned deps
* transformers doesn't support hub 1.0 yet
* upgrade liger dep to 0.6.3
* set TORCH_CUDA_ARCH_LIST correctly
1 parent 9d4d39e commit 98333e6

5 files changed (+21, -14 lines)


cicd/Dockerfile.jinja

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 FROM axolotlai/axolotl-base:{{ BASE_TAG }}

-ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
+ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
 ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
 ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
 ENV CUDA="{{ CUDA }}"
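
The expanded arch list covers Volta through Hopper GPUs, with PTX kept for newer devices. As a quick check (a sketch, not part of this commit; it assumes the snippet runs inside the built image with torch importable and a CUDA device visible), you can confirm the local GPU's compute capability is covered:

import os
import torch

# Assumes at least one CUDA device is visible inside the built image.
arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", "").replace("+PTX", "").split()
major, minor = torch.cuda.get_device_capability(0)
local_arch = f"{major}.{minor}"
print(f"{local_arch} covered by TORCH_CUDA_ARCH_LIST: {local_arch in arch_list}")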

requirements.txt

Lines changed: 6 additions & 6 deletions
@@ -5,27 +5,27 @@ bitsandbytes==0.47.0
 triton>=3.0.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
-liger-kernel==0.6.1
+liger-kernel==0.6.3
 # END section

 packaging==23.2

-huggingface_hub>=0.33.0
+huggingface_hub>=0.36.0
 peft>=0.17.1
 tokenizers>=0.21.1
 transformers==4.57.1
 accelerate==1.10.1
 datasets==4.0.0
 deepspeed>=0.17.0
-trl==0.23.1
-hf_xet==1.1.5
-kernels==0.9.0
+trl==0.24.0
+hf_xet==1.2.0
+kernels>=0.9.0
 trackio

 optimum==1.16.2
 hf_transfer
 sentencepiece
-gradio==5.41.1
+gradio==5.49.1

 modal==1.0.2
 pydantic==2.10.6
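
A minimal sanity check (a sketch, not part of the repo) that an environment actually picked up the new pins; the distribution names follow requirements.txt:

from importlib.metadata import version

assert version("trl") == "0.24.0"
assert version("liger-kernel") == "0.6.3"
# huggingface_hub is only lower-bounded; per the commit message,
# transformers doesn't support hub 1.0 yet.
print("huggingface_hub:", version("huggingface_hub"))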

setup.py

Lines changed: 6 additions & 2 deletions
@@ -62,8 +62,12 @@ def parse_requirements(extras_require_map):
     else:
         raise ValueError("Invalid version format")

-    if (major, minor) >= (2, 8):
-        pass
+    if (major, minor) >= (2, 9):
+        extras_require_map.pop("fbgemm-gpu")
+        extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
+    elif (major, minor) >= (2, 8):
+        extras_require_map.pop("fbgemm-gpu")
+        extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
     elif (major, minor) >= (2, 7):
         _install_requires.pop(_install_requires.index(xformers_version))
         if patch == 0:
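
This hunk assumes major, minor, and patch have already been parsed from the installed torch version earlier in parse_requirements(). A minimal sketch of that parsing (variable names mirror the diff; the regex is an assumption, not the repo's exact code):

import re
from importlib.metadata import version

torch_version = version("torch")  # e.g. "2.9.0+cu128"
match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if match:
    major, minor = int(match.group(1)), int(match.group(2))
    patch = int(match.group(3) or 0)
else:
    raise ValueError("Invalid version format")

# Per the diff above: torch >= 2.9 maps the fbgemm-gpu extra to
# fbgemm-gpu-genai==1.4.1, torch 2.8.x maps it to fbgemm-gpu-genai==1.3.0.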

src/axolotl/core/builders/causal.py

Lines changed: 6 additions & 3 deletions
@@ -12,7 +12,7 @@
     EarlyStoppingCallback,
     Trainer,
 )
-from trl.trainer.utils import RewardDataCollatorWithPadding
+from trl.trainer.reward_trainer import DataCollatorForPreference

 from axolotl.core.builders.base import TrainerBuilderBase
 from axolotl.core.trainers import (
@@ -453,7 +453,7 @@ def build_collator(
                 BatchSamplerDataCollatorForSeq2Seq,
                 DataCollatorForSeq2Seq,
                 DataCollatorWithFlattening,
-                RewardDataCollatorWithPadding,
+                DataCollatorForPreference,
             ]
         ]
         collator_args = [self.tokenizer]
@@ -470,7 +470,10 @@ def build_collator(
             if kwargs and isinstance(kwargs, dict):
                 kwargs.update(collator_cls_and_kwargs[1])
         elif self.cfg.reward_model:
-            collator = RewardDataCollatorWithPadding
+            collator = DataCollatorForPreference
+            tokenizer = collator_args.pop(0)
+            kwargs["pad_token_id"] = tokenizer.pad_token_id
+            kwargs.pop("padding")
         elif use_batch_sampler_collator:
             # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
             # supported multipack models, or non-flash-attention llama
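
In trl 0.24.0 the preference collator is configured with a pad token id rather than a tokenizer plus a padding kwarg, which is why the builder pops the tokenizer off collator_args and drops "padding" from kwargs. A minimal sketch of the new construction, assuming a tokenizer with a pad token set (the model name is a placeholder):

from transformers import AutoTokenizer
from trl.trainer.reward_trainer import DataCollatorForPreference

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
tokenizer.pad_token = tokenizer.eos_token

# Old API: RewardDataCollatorWithPadding(tokenizer, padding=True)
# New API: pass only the pad token id.
collator = DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id)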

src/axolotl/prompt_strategies/bradley_terry/chat_template.py

Lines changed: 2 additions & 2 deletions
@@ -71,10 +71,10 @@ def _tokenize_single_prompt(self, prompt):
         ]

         return {
-            "input_ids_chosen": chosen_tokenized["input_ids"],
+            "chosen_input_ids": chosen_tokenized["input_ids"],
             "attention_mask_chosen": chosen_tokenized["attention_mask"],
             "labels_chosen": 1.0,
-            "input_ids_rejected": rejected_tokenized["input_ids"],
+            "rejected_input_ids": rejected_tokenized["input_ids"],
             "attention_mask_rejected": rejected_tokenized["attention_mask"],
             "labels_rejected": 0.0,
         }
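
After the rename, each tokenized Bradley-Terry example carries chosen_input_ids / rejected_input_ids, matching the key names the new preference collator looks for. An illustrative record with toy token ids (not real tokenizer output):

example = {
    "chosen_input_ids": [1, 2, 3],
    "attention_mask_chosen": [1, 1, 1],
    "labels_chosen": 1.0,
    "rejected_input_ids": [4, 5],
    "attention_mask_rejected": [1, 1],
    "labels_rejected": 0.0,
}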
