Skip to content

Commit 7fa1089

Browse files
authored
Atropos support (axolotl-ai-cloud#2666) [skip ci]
* allow peft+liger+grpo and custom vllm serve for atropos support
* set trainer class for RL
1 parent 80304c2 commit 7fa1089

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

src/axolotl/cli/args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ class VllmServeCliArgs:
8282
"hardware support this feature."
8383
},
8484
)
85+
serve_module: Optional[str] = field(
86+
default=None,
87+
metadata={
88+
"help": "Module to serve. If not set, the default module will be used."
89+
},
90+
)
8591

8692

8793
@dataclass

src/axolotl/cli/vllm_serve.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Union
77

88
from trl.scripts.vllm_serve import ScriptArguments
9-
from trl.scripts.vllm_serve import main as vllm_serve_main
109

1110
from axolotl.cli.config import load_cfg
1211

@@ -28,6 +27,9 @@ def do_vllm_serve(
2827
cfg = load_cfg(config)
2928
model = cfg.base_model
3029

30+
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
31+
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
32+
3133
tensor_parallel_size = (
3234
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
3335
)

src/axolotl/core/trainer_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,10 @@ def build(self, total_num_steps):
11971197
else:
11981198
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
11991199

1200+
if self.cfg.plugins:
1201+
plugin_manager = PluginManager.get_instance()
1202+
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
1203+
12001204
sig = inspect.signature(trainer_cls)
12011205
if "tokenizer" in sig.parameters.keys():
12021206
trainer_kwargs["tokenizer"] = self.tokenizer

src/axolotl/utils/schemas/config.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,18 +1149,18 @@ def check_kto_config(cls, data):
11491149

11501150
return data
11511151

1152-
@model_validator(mode="before")
1153-
@classmethod
1154-
def check_grpo_peft_liger(cls, data):
1155-
if (
1156-
data.get("rl") == "grpo"
1157-
and data.get("trl", {})
1158-
and data.get("trl").get("use_liger_loss")
1159-
and data.get("adapter")
1160-
):
1161-
raise ValueError("PEFT + GRPO + Liger is not yet supported")
1162-
return data
1163-
1152+
# @model_validator(mode="before")
1153+
# @classmethod
1154+
# def check_grpo_peft_liger(cls, data):
1155+
# if (
1156+
# data.get("rl") == "grpo"
1157+
# and data.get("trl", {})
1158+
# and data.get("trl").get("use_liger_loss")
1159+
# and data.get("adapter")
1160+
# ):
1161+
# raise ValueError("PEFT + GRPO + Liger is not yet supported")
1162+
# return data
1163+
#
11641164
@model_validator(mode="before")
11651165
@classmethod
11661166
def check_grpo_liger_sequence_parallel(cls, data):

0 commit comments

Comments
 (0)