|
5 | 5 |
|
6 | 6 | from swift.llm import MODEL_MAPPING |
7 | 7 | from swift.trainers.arguments import GRPOArgumentsMixin, RLHFArgumentsMixin |
8 | | -from swift.utils import get_logger, is_master, set_default_ddp_config |
| 8 | +from swift.utils import get_logger, is_master, is_mp, set_default_ddp_config |
9 | 9 | from .train_args import TrainArguments |
10 | 10 |
|
11 | 11 | logger = get_logger() |
@@ -155,7 +155,6 @@ def __post_init__(self): |
155 | 155 | def _init_grpo(self): |
156 | 156 | if self.rlhf_type == 'grpo': |
157 | 157 | if self.use_vllm: |
158 | | - os.environ['USE_FAST_INFERENCE'] = '1' |
159 | 158 | set_default_ddp_config() |
160 | 159 | if self.async_generate or not self.use_vllm: |
161 | 160 | self.sleep_level = 0 |
@@ -255,7 +254,9 @@ def _check_grpo(self): |
255 | 254 | trl_version = version.parse(trl.__version__) |
256 | 255 | assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. ' |
257 | 256 | 'Please update it by running: pip install -U trl') |
258 | | - |
| 257 | + if is_mp() and self.use_vllm: |
| 258 | + raise ValueError('GRPO with vLLM is not compatible with `device_map`. ' |
| 259 | + 'Please set NPROC_PER_NODE equal to num_processes.') |
259 | 260 | if self.use_liger_kernel: |
260 | 261 | assert trl_version >= version.parse('0.18') |
261 | 262 | if self.delta is not None: |
@@ -308,25 +309,6 @@ def _deprecated_warning(self): |
308 | 309 | if self.rlhf_type != 'grpo': |
309 | 310 | return |
310 | 311 |
|
311 | | - if self.tensor_parallel_size is not None: |
312 | | - logger.warning( |
313 | | - "The parameter 'tensor_parallel_size' has been deprecated and will be removed in version 3.6. " |
314 | | - "It is recommended to use 'vllm_tensor_parallel_size' instead.") |
315 | | - self.vllm_tensor_parallel_size = self.tensor_parallel_size |
316 | | - |
317 | | - if self.vllm_device is not None: |
318 | | - logger.warning("The parameter 'vllm_device' has been deprecated and will be removed in version 3.6. ") |
319 | | - |
320 | | - if self.vllm_max_num_seqs is not None: |
321 | | - logger.warning("The parameter 'vllm_max_num_seqs' is automatically set, " |
322 | | - 'and has been deprecated and will be removed in version 3.6. ') |
323 | | - |
324 | | - if self.num_infer_workers is not None: |
325 | | - logger.warning( |
326 | | - "The parameter 'num_infer_workers' has been deprecated and will be removed in version 3.6. " |
327 | | - 'If you wish to use colocate mode, please use `vllm_mode colocate` instead. ' |
328 | | - 'If you wish to use async mode, please use `vllm_mode server` and external vLLM server instead.') |
329 | | - |
330 | 312 | if self.multi_turn_func: |
331 | 313 | logger.warning("The parameter 'multi_turn_func' has been deprecated and will be removed in version 3.7. " |
332 | 314 | "Please use 'multi_turn_scheduler' instead") |
|
0 commit comments