Commit 085bc51

Merge branch 'main' into release/3.6

2 parents 2f8c8a5 + b41c78d

File tree: 23 files changed (+362, -396 lines)

swift/llm/argument/rlhf_args.py

Lines changed: 4 additions & 22 deletions
@@ -5,7 +5,7 @@
 from swift.llm import MODEL_MAPPING
 from swift.trainers.arguments import GRPOArgumentsMixin, RLHFArgumentsMixin
-from swift.utils import get_logger, is_master, set_default_ddp_config
+from swift.utils import get_logger, is_master, is_mp, set_default_ddp_config
 from .train_args import TrainArguments

 logger = get_logger()
@@ -155,7 +155,6 @@ def __post_init__(self):
     def _init_grpo(self):
         if self.rlhf_type == 'grpo':
             if self.use_vllm:
-                os.environ['USE_FAST_INFERENCE'] = '1'
                 set_default_ddp_config()
             if self.async_generate or not self.use_vllm:
                 self.sleep_level = 0
@@ -255,7 +254,9 @@ def _check_grpo(self):
         trl_version = version.parse(trl.__version__)
         assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
                                                       'Please update it by running: pip install -U trl')
-
+        if is_mp() and self.use_vllm:
+            raise ValueError('GRPO with vLLM is not compatible with `device_map`. '
+                             'Please set NPROC_PER_NODE equal to num_processes.')
         if self.use_liger_kernel:
             assert trl_version >= version.parse('0.18')
         if self.delta is not None:
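For context, the added guard fails fast when the model has been sharded across GPUs via `device_map` (model parallelism) while vLLM rollouts are requested. A standalone sketch with hypothetical names; the real check lives in `_check_grpo` and the `model_parallel` flag stands in for `swift.utils.is_mp()`:

def check_grpo_vllm(use_vllm: bool, model_parallel: bool) -> None:
    # model_parallel stands in for is_mp() (assumption: is_mp() reports
    # device_map-style model sharding across devices).
    if model_parallel and use_vllm:
        raise ValueError('GRPO with vLLM is not compatible with `device_map`. '
                         'Please set NPROC_PER_NODE equal to num_processes.')

check_grpo_vllm(use_vllm=True, model_parallel=False)  # OK: data-parallel launch
try:
    check_grpo_vllm(use_vllm=True, model_parallel=True)
except ValueError as e:
    print(e)  # raised: incompatible configuration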
@@ -308,25 +309,6 @@ def _deprecated_warning(self):
         if self.rlhf_type != 'grpo':
             return

-        if self.tensor_parallel_size is not None:
-            logger.warning(
-                "The parameter 'tensor_parallel_size' has been deprecated and will be removed in version 3.6. "
-                "It is recommended to use 'vllm_tensor_parallel_size' instead.")
-            self.vllm_tensor_parallel_size = self.tensor_parallel_size
-
-        if self.vllm_device is not None:
-            logger.warning("The parameter 'vllm_device' has been deprecated and will be removed in version 3.6. ")
-
-        if self.vllm_max_num_seqs is not None:
-            logger.warning("The parameter 'vllm_max_num_seqs' is automatically set, "
-                           'and has been deprecated and will be removed in version 3.6. ')
-
-        if self.num_infer_workers is not None:
-            logger.warning(
-                "The parameter 'num_infer_workers' has been deprecated and will be removed in version 3.6. "
-                'If you wish to use colocate mode, please use `vllm_mode colocate` instead. '
-                'If you wish to use async mode, please use `vllm_mode server` and external vLLM server instead.')
-
         if self.multi_turn_func:
             logger.warning("The parameter 'multi_turn_func' has been deprecated and will be removed in version 3.7. "
                            "Please use 'multi_turn_scheduler' instead")

swift/megatron/train/trainers/trainer.py

Lines changed: 4 additions & 2 deletions
@@ -46,8 +46,10 @@ def initialize_megatron(*_args, **kwargs):
         else:
             raise ValueError(
                 'You are using a streaming training dataset. Please explicitly specify `--train_iters`.')
-    if val_dataset is not None and args.eval_iters < 0:
-        if hasattr(val_dataset, '__len__'):
+    if args.eval_iters < 0:
+        if val_dataset is None:
+            args.eval_iters = 0
+        elif hasattr(val_dataset, '__len__'):
             dataset_sample = len(val_dataset) // step_batch_size * step_batch_size
             args.eval_iters = max(dataset_sample // args.global_batch_size, 1)
         else:
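With this change, a negative `eval_iters` now resolves three ways: 0 when there is no validation set, a size-derived value for sized datasets, and the existing fallback for datasets without `__len__`. A worked example of the sized-dataset arithmetic, with hypothetical sizes:

micro_batch_size, data_parallel_size = 2, 4
step_batch_size = micro_batch_size * data_parallel_size        # 8
global_batch_size = 64
val_len = 1000                                                 # len(val_dataset)

dataset_sample = val_len // step_batch_size * step_batch_size  # 1000
eval_iters = max(dataset_sample // global_batch_size, 1)       # 15
print(eval_iters)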

swift/megatron/train/utils.py

Lines changed: 1 addition & 2 deletions
@@ -15,8 +15,7 @@ def swift_datasets_provider(train_val_test_num_samples):
     nonlocal val_dataset
     args = get_args()
     data_parallel_size = mpu.get_data_parallel_world_size()
-    step_batch_size = \
-        args.micro_batch_size * data_parallel_size
+    step_batch_size = args.micro_batch_size * data_parallel_size
     # To avoid errors caused by the validation set being insufficient to complete a single step.
     if val_dataset is not None and len(val_dataset) < step_batch_size:
         val_dataset = None
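As the comment notes, a validation set smaller than one step's worth of samples is dropped outright. A tiny illustration with assumed sizes:

step_batch_size = 4 * 2        # micro_batch_size * data_parallel_size
val_dataset = list(range(5))   # stand-in for a 5-sample validation set
if val_dataset is not None and len(val_dataset) < step_batch_size:
    val_dataset = None         # 5 < 8, so evaluation is skipped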

swift/trainers/arguments.py

Lines changed: 0 additions & 4 deletions
@@ -155,14 +155,11 @@ class GRPOArgumentsMixin:
     top_k: int = 50
     top_p: float = 0.9
     repetition_penalty: float = 1.
-    num_infer_workers: Optional[int] = None  # deprecated
     # vllm
     vllm_mode: Literal['server', 'colocate'] = 'colocate'
     # internal vllm (colocate)
-    vllm_device: Optional[List[str]] = None  # deprecated
     vllm_gpu_memory_utilization: float = 0.9
     vllm_max_model_len: Optional[int] = None
-    vllm_max_num_seqs: Optional[int] = None  # deprecated
     vllm_enforce_eager: bool = False
     vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None  # '{"image": 5, "video": 2}'
     vllm_enable_prefix_caching: bool = True
@@ -195,7 +192,6 @@ class GRPOArgumentsMixin:
     ref_model_mixup_alpha: float = 0.6

     async_generate: bool = False
-    tensor_parallel_size: Optional[int] = None  # deprecated

     sleep_level: int = 0
     move_model_batches: Optional[int] = None
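After this cleanup, colocate-mode rollout is driven by the surviving `vllm_*` fields. A hedged sketch of those knobs as keyword arguments (hypothetical values; field names match the mixin above, and the receiving argument class is assumed):

grpo_vllm_kwargs = dict(
    vllm_mode='colocate',
    vllm_gpu_memory_utilization=0.9,
    vllm_max_model_len=8192,
    vllm_enforce_eager=False,
    vllm_limit_mm_per_prompt={'image': 5, 'video': 2},
    vllm_enable_prefix_caching=True,
)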

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 46 additions & 12 deletions
@@ -1094,8 +1094,20 @@ def _apply_chat_template_to_messages_list(self, messages_list: InputsType):
             InferRequest.remove_response(messages)
             template_inputs, _ = StdTemplateInputs.from_dict({'messages': messages})
             res_context_list, _, _ = self.template._swift_encode(template_inputs)
-            prompts_text.append(''.join(elem for elem in res_context_list if isinstance(elem, str)))

+            # Check each element's type and convert it to text.
+            processed_context = []
+            for context in res_context_list:
+                if isinstance(context, str):
+                    processed_context.append(context)
+                elif isinstance(context, list) and all(isinstance(x, int) for x in context):
+                    # Decode token IDs back to text.
+                    decoded_text = self.template.tokenizer.decode(context)
+                    processed_context.append(decoded_text)
+                else:
+                    # Fall back to the string representation for any other type.
+                    processed_context.append(str(context))
+            prompts_text.append(''.join(processed_context))
         return prompts_text

     @profiling_decorator
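The new branch matters for templates whose `_swift_encode` output interleaves strings with raw token-ID lists, which the old one-liner silently dropped. A minimal illustration of the conversion, assuming a Hugging Face tokenizer and made-up token IDs (the trainer uses `self.template.tokenizer` instead):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')  # assumed model
res_context_list = ['<|im_start|>user\n', [9707], '<|im_end|>\n']        # made-up IDs

processed = []
for elem in res_context_list:
    if isinstance(elem, str):
        processed.append(elem)
    elif isinstance(elem, list) and all(isinstance(x, int) for x in elem):
        processed.append(tokenizer.decode(elem))  # token IDs -> text
    else:
        processed.append(str(elem))
print(''.join(processed))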
@@ -1421,7 +1433,7 @@ def _process_infer_requests_images(self, infer_requests: InputsType):
             return

     def old_policy(self):
-        return self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps
+        return self.num_iterations > 1 or self.args.gradient_accumulation_steps % self.args.steps_per_generation != 0

     @property
     def _queue(self):
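The fix changes when rollouts count as off-policy: with a single GRPO iteration, the old policy is needed exactly when `gradient_accumulation_steps` is not a multiple of `steps_per_generation`. A standalone sketch of the corrected predicate:

def old_policy(num_iterations: int, gradient_accumulation_steps: int,
               steps_per_generation: int) -> bool:
    return num_iterations > 1 or gradient_accumulation_steps % steps_per_generation != 0

print(old_policy(1, 8, 4))  # False: generations divide evenly into optimizer steps
print(old_policy(1, 6, 4))  # True: 6 % 4 != 0, some batches train on stale rollouts
print(old_policy(2, 8, 4))  # True: multiple iterations always reuse rollouts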
@@ -1580,18 +1592,40 @@ def is_async_generate_eval_rollout_done(self):
     def is_async_generate_train_rollout_done(self):
         return not self.train_queue.empty()

-    def inputs_to_rolloutrequest(self, inputs: InputsType) -> RolloutInferRequest:
+    def inputs_to_rolloutrequest(self, inputs: InputsType) -> List[RolloutInferRequest]:
+        """Convert a list of inputs to a list of RolloutInferRequest objects.
+
+        If an input contains a 'data_dict' key, it is used as the base for the new data_dict.
+        For other keys that overlap with keys in data_dict, the values from data_dict take
+        priority; non-overlapping keys are added to data_dict.
+
+        Args:
+            inputs: List of input dictionaries
+
+        Returns:
+            List of RolloutInferRequest objects
+        """
         request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects']
-        infer_requests = [
-            RolloutInferRequest(
-                **{
-                    **{k: request[k]
-                       for k in request_keys if k in request}, 'data_dict':
-                    {k: request[k]
-                     for k in request if k not in request_keys}
-                }) for request in inputs
-        ]
+        infer_requests = []
+
+        for request in inputs:
+            # Use the incoming 'data_dict' as the base if present.
+            base_data_dict = {}
+            if 'data_dict' in request:
+                if isinstance(request['data_dict'], dict):
+                    base_data_dict = request['data_dict']
+                else:
+                    raise ValueError('data_dict exists but is not a dictionary')
+
+            # Collect all non-request-key items as extra fields.
+            extra_data = {k: request[k] for k in request if k not in request_keys and k != 'data_dict'}
+
+            # Merge, keeping keys from base_data_dict as priority.
+            final_data_dict = {**extra_data, **base_data_dict}
+
+            # Create the RolloutInferRequest instance.
+            req_args = {k: request[k] for k in request_keys if k in request}
+            infer_requests.append(RolloutInferRequest(**req_args, data_dict=final_data_dict))

         return infer_requests
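A usage sketch of the merge rule with a hypothetical input: loose extra fields land in `data_dict`, but keys already present in the incoming `data_dict` win on conflict (`request_keys` mirrors the list in the method above):

request = {
    'messages': [{'role': 'user', 'content': 'hi'}],
    'solution': '42',                 # extra field, folded into data_dict
    'data_dict': {'solution': '43'},  # pre-existing data_dict takes priority
}
request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects']
extra = {k: v for k, v in request.items() if k not in request_keys and k != 'data_dict'}
final_data_dict = {**extra, **request['data_dict']}
print(final_data_dict)  # {'solution': '43'}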

swift/ui/llm_grpo/external_rollout.py

Lines changed: 1 addition & 2 deletions
@@ -110,7 +110,7 @@ class LLMRollout(BaseUI):

     @classmethod
     def do_build_ui(cls, base_tab: Type['BaseUI']):
-        with gr.Accordion(elem_id='llm_rollout', visible=False):
+        with gr.Accordion(elem_id='llm_rollout', open=False, visible=False):
             default_device = 'cpu'
             device_count = get_device_count()
             if device_count > 0:
@@ -119,7 +119,6 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
             with gr.Row():
                 gr.Textbox(elem_id='tensor_parallel_size', lines=1, value='1', scale=4)
                 gr.Textbox(elem_id='data_parallel_size', lines=1, value='1', scale=4)
-                gr.Textbox(elem_id='max_model_len', lines=1, value='', scale=4)
                 gr.Slider(elem_id='gpu_memory_utilization', minimum=0.0, maximum=1.0, step=0.05, value=0.9, scale=4)
             with gr.Row(equal_height=True):
                 gr.Dropdown(

swift/ui/llm_grpo/external_runtime.py

Lines changed: 3 additions & 2 deletions
@@ -56,8 +56,8 @@ class RolloutRuntime(Runtime):
                 'en': 'Logging content'
             },
             'info': {
-                'zh': '如果日志无更新请再次点击"展示日志内容"',
-                'en': 'Please press "Show log" if the log content is not updating'
+                'zh': '如果日志无更新请再次点击"展示rollout状态"',
+                'en': 'Please press "Show running status" if the log content is not updating'
             }
         },
         'rollout_running_tasks': {
@@ -90,6 +90,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
         with gr.Blocks():
             with gr.Row(equal_height=True):
                 gr.Dropdown(elem_id='rollout_running_tasks', scale=10, allow_custom_value=True)
+            with gr.Row(equal_height=True):
                 gr.Button(elem_id='rollout_refresh_tasks', scale=1, variant='primary')
                 gr.Button(elem_id='rollout_show_log', scale=1, variant='primary')
                 gr.Button(elem_id='rollout_stop_show_log', scale=1)
