|
5 | 5 |
|
6 | 6 | from swift.llm import MODEL_MAPPING
|
7 | 7 | from swift.trainers.arguments import GRPOArgumentsMixin, RLHFArgumentsMixin
|
8 |
| -from swift.utils import get_logger, is_master, set_default_ddp_config |
| 8 | +from swift.utils import get_logger, is_master, is_mp, set_default_ddp_config |
9 | 9 | from .train_args import TrainArguments
|
10 | 10 |
|
11 | 11 | logger = get_logger()
|
@@ -155,7 +155,6 @@ def __post_init__(self):
|
155 | 155 | def _init_grpo(self):
|
156 | 156 | if self.rlhf_type == 'grpo':
|
157 | 157 | if self.use_vllm:
|
158 |
| - os.environ['USE_FAST_INFERENCE'] = '1' |
159 | 158 | set_default_ddp_config()
|
160 | 159 | if self.async_generate or not self.use_vllm:
|
161 | 160 | self.sleep_level = 0
|
@@ -255,7 +254,9 @@ def _check_grpo(self):
|
255 | 254 | trl_version = version.parse(trl.__version__)
|
256 | 255 | assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
|
257 | 256 | 'Please update it by running: pip install -U trl')
|
258 |
| - |
| 257 | + if is_mp() and self.use_vllm: |
| 258 | + raise ValueError('GRPO with vLLM is not compatible with `device_map`. ' |
| 259 | + 'Please set NPROC_PER_NODE equal to num_processes.') |
259 | 260 | if self.use_liger_kernel:
|
260 | 261 | assert trl_version >= version.parse('0.18')
|
261 | 262 | if self.delta is not None:
|
@@ -308,25 +309,6 @@ def _deprecated_warning(self):
|
308 | 309 | if self.rlhf_type != 'grpo':
|
309 | 310 | return
|
310 | 311 |
|
311 |
| - if self.tensor_parallel_size is not None: |
312 |
| - logger.warning( |
313 |
| - "The parameter 'tensor_parallel_size' has been deprecated and will be removed in version 3.6. " |
314 |
| - "It is recommended to use 'vllm_tensor_parallel_size' instead.") |
315 |
| - self.vllm_tensor_parallel_size = self.tensor_parallel_size |
316 |
| - |
317 |
| - if self.vllm_device is not None: |
318 |
| - logger.warning("The parameter 'vllm_device' has been deprecated and will be removed in version 3.6. ") |
319 |
| - |
320 |
| - if self.vllm_max_num_seqs is not None: |
321 |
| - logger.warning("The parameter 'vllm_max_num_seqs' is automatically set, " |
322 |
| - 'and has been deprecated and will be removed in version 3.6. ') |
323 |
| - |
324 |
| - if self.num_infer_workers is not None: |
325 |
| - logger.warning( |
326 |
| - "The parameter 'num_infer_workers' has been deprecated and will be removed in version 3.6. " |
327 |
| - 'If you wish to use colocate mode, please use `vllm_mode colocate` instead. ' |
328 |
| - 'If you wish to use async mode, please use `vllm_mode server` and external vLLM server instead.') |
329 |
| - |
330 | 312 | if self.multi_turn_func:
|
331 | 313 | logger.warning("The parameter 'multi_turn_func' has been deprecated and will be removed in version 3.7. "
|
332 | 314 | "Please use 'multi_turn_scheduler' instead")
|
|
0 commit comments