Skip to content

Commit bdbbc71

Browse files
authored
[grpo] deprecated params for 3.6 (#4848)
* deprecated params for 3.6
* check mp for grpo
1 parent 801c45a commit bdbbc71

File tree

3 files changed: 5 additions (+5) and 28 deletions (-28).

swift/llm/argument/rlhf_args.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from swift.llm import MODEL_MAPPING
77
from swift.trainers.arguments import GRPOArgumentsMixin, RLHFArgumentsMixin
8-
from swift.utils import get_logger, is_master, set_default_ddp_config
8+
from swift.utils import get_logger, is_master, is_mp, set_default_ddp_config
99
from .train_args import TrainArguments
1010

1111
logger = get_logger()
@@ -155,7 +155,6 @@ def __post_init__(self):
155155
def _init_grpo(self):
156156
if self.rlhf_type == 'grpo':
157157
if self.use_vllm:
158-
os.environ['USE_FAST_INFERENCE'] = '1'
159158
set_default_ddp_config()
160159
if self.async_generate or not self.use_vllm:
161160
self.sleep_level = 0
@@ -255,7 +254,9 @@ def _check_grpo(self):
255254
trl_version = version.parse(trl.__version__)
256255
assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
257256
'Please update it by running: pip install -U trl')
258-
257+
if is_mp() and self.use_vllm:
258+
raise ValueError('GRPO with vLLM is not compatible with `device_map`. '
259+
'Please set NPROC_PER_NODE equal to num_processes.')
259260
if self.use_liger_kernel:
260261
assert trl_version >= version.parse('0.18')
261262
if self.delta is not None:
@@ -308,25 +309,6 @@ def _deprecated_warning(self):
308309
if self.rlhf_type != 'grpo':
309310
return
310311

311-
if self.tensor_parallel_size is not None:
312-
logger.warning(
313-
"The parameter 'tensor_parallel_size' has been deprecated and will be removed in version 3.6. "
314-
"It is recommended to use 'vllm_tensor_parallel_size' instead.")
315-
self.vllm_tensor_parallel_size = self.tensor_parallel_size
316-
317-
if self.vllm_device is not None:
318-
logger.warning("The parameter 'vllm_device' has been deprecated and will be removed in version 3.6. ")
319-
320-
if self.vllm_max_num_seqs is not None:
321-
logger.warning("The parameter 'vllm_max_num_seqs' is automatically set, "
322-
'and has been deprecated and will be removed in version 3.6. ')
323-
324-
if self.num_infer_workers is not None:
325-
logger.warning(
326-
"The parameter 'num_infer_workers' has been deprecated and will be removed in version 3.6. "
327-
'If you wish to use colocate mode, please use `vllm_mode colocate` instead. '
328-
'If you wish to use async mode, please use `vllm_mode server` and external vLLM server instead.')
329-
330312
if self.multi_turn_func:
331313
logger.warning("The parameter 'multi_turn_func' has been deprecated and will be removed in version 3.7. "
332314
"Please use 'multi_turn_scheduler' instead")

swift/trainers/arguments.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,11 @@ class GRPOArgumentsMixin:
155155
top_k: int = 50
156156
top_p: float = 0.9
157157
repetition_penalty: float = 1.
158-
num_infer_workers: Optional[int] = None # deprecated
159158
# vllm
160159
vllm_mode: Literal['server', 'colocate'] = 'colocate'
161160
# internal vllm (colocate)
162-
vllm_device: Optional[List[str]] = None # deprecated
163161
vllm_gpu_memory_utilization: float = 0.9
164162
vllm_max_model_len: Optional[int] = None
165-
vllm_max_num_seqs: Optional[int] = None # deprecated
166163
vllm_enforce_eager: bool = False
167164
vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}'
168165
vllm_enable_prefix_caching: bool = True
@@ -195,7 +192,6 @@ class GRPOArgumentsMixin:
195192
ref_model_mixup_alpha: float = 0.6
196193

197194
async_generate: bool = False
198-
tensor_parallel_size: Optional[int] = None # deprecated
199195

200196
sleep_level: int = 0
201197
move_model_batches: Optional[int] = None

swift/utils/env.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ def is_dist():
7171
def is_mp() -> bool:
7272
if use_torchacc():
7373
return False
74-
if strtobool(os.environ.get('USE_FAST_INFERENCE', 'false')):
75-
return False
74+
7675
from swift.utils import get_device_count
7776
n_gpu = get_device_count()
7877
local_world_size = get_dist_setting()[3]

0 commit comments

Comments
 (0)