Commit e5d7c0f

compat trl 0.15 (#5905)
1 parent fd08d4b

File tree

4 files changed: +9 -6 lines changed

swift/llm/model/register.py
swift/megatron/utils/convert.py
swift/trainers/rlhf_arguments.py
swift/trainers/rlhf_trainer/dpo_trainer.py

swift/llm/model/register.py

Lines changed: 2 additions & 2 deletions

@@ -91,7 +91,7 @@ def get_matched_model_group(self, model_name: str) -> Optional[ModelGroup]:
                 for key in ['ms_model_id', 'hf_model_id', 'model_path']:
                     value = getattr(model, key)
 
-                    if isinstance(value, str) and model_name == value.rsplit('/', 1)[-1]:
+                    if isinstance(value, str) and model_name == value.rsplit('/', 1)[-1].lower():
                         return model_group
 
     def check_requires(self, model_info=None):
@@ -435,7 +435,7 @@ def get_all_models() -> List[str]:
 
 
 def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]:
-    model_name = get_model_name(model_id_or_path)
+    model_name = get_model_name(model_id_or_path).lower()
     for model_type, model_meta in MODEL_MAPPING.items():
         model_group = ModelMeta.get_matched_model_group(model_meta, model_name)
         if model_group is not None:
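Both changes make model-name matching case-insensitive: get_matched_model_meta lowercases the queried name once, and get_matched_model_group compares it against the lowercased tail of each candidate's repo path. A minimal sketch of the resulting matching rule (the model IDs below are illustrative examples, not registry entries):

    def matches(model_name: str, model_id: str) -> bool:
        # Compare the lowercased query with the lowercased tail of the repo path.
        return model_name.lower() == model_id.rsplit('/', 1)[-1].lower()

    assert matches('qwen2.5-7b-instruct', 'Qwen/Qwen2.5-7B-Instruct')
    assert not matches('qwen2.5-7b-instruct', 'Qwen/Qwen2.5-14B-Instruct')

With this rule, differently cased spellings of the same repo tail resolve to the same model group.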

swift/megatron/utils/convert.py

Lines changed: 2 additions & 2 deletions

@@ -241,8 +241,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
 
 def convert_mcore2hf(args: ExportArguments) -> None:
     from swift.megatron import prepare_mcore_model, adapter_state_dict_context
-    hf_model, template = prepare_model_template(
-        args, load_model=args.to_hf, patch_offload=not args.test_convert_precision)
+    _, template = prepare_model_template(args, load_model=False)
     processor = template.processor
 
     megatron_model_meta = get_megatron_model_meta(args.model_type)
@@ -284,6 +283,7 @@ def convert_mcore2hf(args: ExportArguments) -> None:
         mg_model = peft_model.merge_and_unload()
     logger.info('Megatron model created successfully.')
     if args.to_hf:
+        hf_model = prepare_model_template(args, patch_offload=not args.test_convert_precision)[0]
         megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
         if args.test_convert_precision:
             test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
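The rewrite defers loading the HF model: convert_mcore2hf now builds only the template and processor up front (load_model=False) and materializes hf_model inside the if args.to_hf: branch, after the Megatron model has been built and merged, so the HF weights are not resident while the Megatron side is constructed. A runnable sketch of this deferral pattern, with all names purely illustrative rather than ms-swift APIs:

    class FakeArgs:
        to_hf = True

    def prepare_template(args):
        # Cheap: tokenizer/processor only, no model weights.
        print('loading template/processor')
        return object()

    def prepare_hf_model(args):
        # Expensive: full HF weights, loaded only on demand.
        print('loading HF model weights')
        return object()

    def convert(args):
        template = prepare_template(args)      # always needed
        # ... build and merge the Megatron model here ...
        if args.to_hf:
            hf_model = prepare_hf_model(args)  # deferred until export
            # ... Megatron -> HF weight conversion would use hf_model ...

    convert(FakeArgs())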

swift/trainers/rlhf_arguments.py

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import List
+from typing import List, Optional
 
 from trl import CPOConfig as HfCPOConfig
 from trl import DPOConfig as HfDPOConfig
@@ -15,7 +15,7 @@
 
 @dataclass
 class DPOConfig(SwiftArgumentsMixin, HfDPOConfig):
-    pass
+    ld_alpha: Optional[float] = None  # compat trl==0.15
 
 
 @dataclass
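Re-declaring ld_alpha on the subclass keeps the field present when the installed trl is 0.15, whose DPOConfig does not define it; on newer trl versions that do define it, the subclass declaration merely restates the same default. A minimal runnable sketch of the dataclass-inheritance behavior this relies on, using stand-in classes rather than the real trl config:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class OldUpstreamConfig:  # stands in for trl==0.15's DPOConfig: no ld_alpha
        beta: float = 0.1

    @dataclass
    class DPOConfig(OldUpstreamConfig):  # subclass supplies the missing field
        ld_alpha: Optional[float] = None

    cfg = DPOConfig(ld_alpha=0.5)  # accepted even with the old upstream base
    print(cfg.beta, cfg.ld_alpha)  # 0.1 0.5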

swift/trainers/rlhf_trainer/dpo_trainer.py

Lines changed: 3 additions & 0 deletions

@@ -7,6 +7,7 @@
 from accelerate.utils import gather_object
 from peft import PeftModel
 from transformers import PreTrainedModel
+from transformers.utils.versions import require_version
 from trl import DPOTrainer as HFDPOTrainer
 from trl.trainer.dpo_config import DPOConfig
 from trl.trainer.utils import RunningMoments, selective_log_softmax
@@ -70,6 +71,8 @@ def __init__(self,
 
         if 'bco_pair' in loss_types:
             self.running = RunningMoments(self.accelerator)
+        if self.args.ld_alpha is not None:
+            require_version('trl>=0.18', '`ld_alpha` requires that "trl>=0.18".')
         if self.template.packing:
             self.accelerator.gather_for_metrics = new_gather_function
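require_version is an existing transformers helper that checks an installed package against a pip-style version specifier and raises ImportError (with the hint appended) when the requirement is unmet, so setting ld_alpha on trl 0.15 fails fast with a clear message instead of being silently ignored. A small usage sketch:

    from transformers.utils.versions import require_version

    try:
        # Raises ImportError if the installed trl is older than 0.18.
        require_version('trl>=0.18', '`ld_alpha` requires that "trl>=0.18".')
    except ImportError as err:
        print(err)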
