File tree Expand file tree Collapse file tree 4 files changed +25
-13
lines changed Expand file tree Collapse file tree 4 files changed +25
-13
lines changed Original file line number Diff line number Diff line change @@ -82,6 +82,12 @@ class VllmServeCliArgs:
82
82
"hardware support this feature."
83
83
},
84
84
)
85
+ serve_module : Optional [str ] = field (
86
+ default = None ,
87
+ metadata = {
88
+ "help" : "Module to serve. If not set, the default module will be used."
89
+ },
90
+ )
85
91
86
92
87
93
@dataclass
Original file line number Diff line number Diff line change 6
6
from typing import Union
7
7
8
8
from trl .scripts .vllm_serve import ScriptArguments
9
- from trl .scripts .vllm_serve import main as vllm_serve_main
10
9
11
10
from axolotl .cli .config import load_cfg
12
11
@@ -28,6 +27,9 @@ def do_vllm_serve(
28
27
cfg = load_cfg (config )
29
28
model = cfg .base_model
30
29
30
+ serve_module = cli_args .get ("serve_module" , "trl.scripts.vllm_serve" )
31
+ vllm_serve_main = getattr (__import__ (serve_module , fromlist = ["main" ]), "main" )
32
+
31
33
tensor_parallel_size = (
32
34
cli_args .get ("tensor_parallel_size" ) or cfg .vllm .tensor_parallel_size
33
35
)
Original file line number Diff line number Diff line change @@ -1197,6 +1197,10 @@ def build(self, total_num_steps):
1197
1197
else :
1198
1198
raise ValueError (f"Unsupported RL: { self .cfg .rl } " )
1199
1199
1200
+ if self .cfg .plugins :
1201
+ plugin_manager = PluginManager .get_instance ()
1202
+ trainer_cls = plugin_manager .get_trainer_cls (self .cfg )
1203
+
1200
1204
sig = inspect .signature (trainer_cls )
1201
1205
if "tokenizer" in sig .parameters .keys ():
1202
1206
trainer_kwargs ["tokenizer" ] = self .tokenizer
Original file line number Diff line number Diff line change @@ -1149,18 +1149,18 @@ def check_kto_config(cls, data):
1149
1149
1150
1150
return data
1151
1151
1152
- @model_validator (mode = "before" )
1153
- @classmethod
1154
- def check_grpo_peft_liger (cls , data ):
1155
- if (
1156
- data .get ("rl" ) == "grpo"
1157
- and data .get ("trl" , {})
1158
- and data .get ("trl" ).get ("use_liger_loss" )
1159
- and data .get ("adapter" )
1160
- ):
1161
- raise ValueError ("PEFT + GRPO + Liger is not yet supported" )
1162
- return data
1163
-
1152
+ # @model_validator(mode="before")
1153
+ # @classmethod
1154
+ # def check_grpo_peft_liger(cls, data):
1155
+ # if (
1156
+ # data.get("rl") == "grpo"
1157
+ # and data.get("trl", {})
1158
+ # and data.get("trl").get("use_liger_loss")
1159
+ # and data.get("adapter")
1160
+ # ):
1161
+ # raise ValueError("PEFT + GRPO + Liger is not yet supported")
1162
+ # return data
1163
+ #
1164
1164
@model_validator (mode = "before" )
1165
1165
@classmethod
1166
1166
def check_grpo_liger_sequence_parallel (cls , data ):
You can’t perform that action at this time.
0 commit comments